Discrete Models for Identification Algorithm Development

In causal inference we are often concerned with identification, that is, determining whether a causal query is a function of the observed data distribution. Algorithms have been developed to address this question in a variety of settings. When developing new identification algorithms, the first concern is soundness: whether the algorithm correctly returns the causal effect.

In this notebook we outline some helpful tools intended to assist in the development of such algorithms.

We’ll first go through setting up causal models using ananke.models.BayesianNetwork, which supports both numerical and symbolic representations of conditional probability distributions through pgmpy.factors.discrete.TabularCPD and ananke.factors.SymCPD respectively.

Given this causal model, we can directly introduce an intervention, and then compute the causal query in the intervened distribution. Then, we can pass the observed data distribution from the causal model into a proposed identification algorithm, and verify that the result agrees with the computed truth.

This development pattern is useful both for checking the correctness of proposed identification algorithms and as a concrete implementation for understanding how causal identification works.

[27]:

import sympy as sp

from pgmpy.inference import VariableElimination
from pgmpy.models import BayesianNetwork

from ananke.graphs import ADMG
from ananke.identification import OneLineID
from ananke.models import bayesian_network
from ananke.estimation import empirical_plugin
from ananke.identification import oracle
from ananke.inference.variable_elimination import variable_elimination


Front-door graph

We first consider the simple case of the front-door graph. In this example we demonstrate the API which implements the ID algorithm (Shpitser and Pearl, 2006).

We initialize an ADMG representing this graph, and provide variable cardinalities.

[28]:

di_edges = [("A", "M"), ("M", "Y")]
bi_edges = [("A", "Y")]
graph = ADMG(vertices={"A": 2, "M": 2, "Y": 2}, di_edges=di_edges, bi_edges=bi_edges)
graph.draw()

[28]:


This graph is an ADMG, in which bidirected edges represent unmeasured confounding. To continue, we convert the ADMG into its canonical DAG, which replaces each bidirected edge with a single hidden variable.
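The edge-level bookkeeping behind this conversion can be sketched in plain Python. The function below is illustrative only (not the ananke implementation): it replaces each bidirected edge X <-> Y with a hidden common cause U_{X,Y} -> X, U_{X,Y} -> Y.

```python
def canonical_dag_edges(di_edges, bi_edges):
    """Replace each bidirected edge X <-> Y with a hidden common
    cause: U_{X,Y} -> X and U_{X,Y} -> Y (illustrative sketch)."""
    new_di_edges = list(di_edges)
    hidden_vertices = []
    for x, y in bi_edges:
        u = f"U_{x},{y}"
        hidden_vertices.append(u)
        new_di_edges.append((u, x))
        new_di_edges.append((u, y))
    return new_di_edges, hidden_vertices

edges, hidden = canonical_dag_edges(
    di_edges=[("A", "M"), ("M", "Y")], bi_edges=[("A", "Y")]
)
# The bidirected edge A <-> Y is now the pair U_{A,Y} -> A, U_{A,Y} -> Y.
```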

[29]:

dag = graph.canonical_dag(cardinality=2)

[30]:

dag.draw()

[30]:


Next, we generate conditional probability distributions consistent with this graph. This provides a causal model with factorization

$p(A, M, Y, U_{A, Y}) = p(Y | M, U_{A,Y}) p(A | U_{A, Y}) p(M | A) p(U_{A, Y})$

We represent this as net.

[31]:

cpds = bayesian_network.generate_random_cpds(graph=dag, dir_conc=10)
net = bayesian_network.BayesianNetwork(graph=dag, cpds=cpds)
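The idea behind generating random CPDs can be reproduced with the standard library alone: each column of a conditional probability table is a draw from a symmetric Dirichlet, which `random.gammavariate` yields after normalization. The helper name below is hypothetical; `dir_conc` above plausibly plays the role of the concentration parameter, where larger values pull columns toward uniform.

```python
import random

def random_dirichlet_column(k, conc=10.0):
    """Draw one length-k probability column from a symmetric
    Dirichlet(conc, ..., conc) via normalized Gamma samples."""
    draws = [random.gammavariate(conc, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

random.seed(0)
# e.g. one column of p(M | A=a) for binary M.
column = random_dirichlet_column(k=2, conc=10.0)
```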


We are next interested in the causal model in which $do(A=1)$ has been implemented. Then,

$p(Y(a=1), M(a=1), U_{A, Y}(a=1)) = p(Y | M, U_{A,Y}) p(M | A=1) p(U_{A, Y})$

The oracle utility below performs this intervention internally; we only need to specify the treatment and the outcome.

[32]:

treatment_dict = {"A": 1}
outcome_dict = {"Y": 1}


In this model, the causal effect $p(Y(a=1)=1)$ is simply the marginal distribution evaluated at $Y=1$:

[33]:

truth = oracle.compute_effect_from_discrete_model(net, treatment_dict, outcome_dict)


We can also compute the causal effect directly from the observed data distribution.

[34]:

oid = OneLineID(graph, list(treatment_dict), list(outcome_dict))

[35]:

obs_dist = variable_elimination(net, ["A", "M", "Y"])

[36]:

result = empirical_plugin.estimate_effect_from_discrete_dist(
    oid, obs_dist, treatment_dict, outcome_dict
)

INFO:ananke.estimation.empirical_plugin:implied district is {'M'}
INFO:ananke.estimation.empirical_plugin:fixing by q(Y|M, A)
INFO:ananke.estimation.empirical_plugin:fixing by q(A|Y)
INFO:ananke.estimation.empirical_plugin:implied district is {'Y'}
INFO:ananke.estimation.empirical_plugin:fixing by q(M|A)
INFO:ananke.estimation.empirical_plugin:fixing by q(A|M, Y)

[37]:

print("Effect computed through intervention of model:  ", truth)
print("Effect computed through identification algorithm: ", result)

Effect computed through intervention of model:   0.5971619133337924
Effect computed through identification algorithm:  0.5971619133337924


The result agrees with the truth up to floating point precision.
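What the identification algorithm computes here is the front-door adjustment, $p(Y(a=1)) = \sum_M p(M | A=1) \sum_{A'} p(Y | M, A') p(A')$. The whole check above can also be carried out by hand on a tiny hand-specified model, with no dependencies beyond the standard library (the numbers below are arbitrary illustrative choices):

```python
from itertools import product

# Arbitrary illustrative CPDs for binary U, A, M, Y.
p_u1 = 0.4                                    # p(U=1)
p_a1_u = {0: 0.3, 1: 0.7}                     # p(A=1 | U=u)
p_m1_a = {0: 0.2, 1: 0.8}                     # p(M=1 | A=a)
p_y1_mu = {(0, 0): 0.1, (0, 1): 0.5,
           (1, 0): 0.6, (1, 1): 0.9}          # p(Y=1 | M=m, U=u)

def pu(u): return p_u1 if u else 1 - p_u1
def pa(a, u): return p_a1_u[u] if a else 1 - p_a1_u[u]
def pm(m, a): return p_m1_a[a] if m else 1 - p_m1_a[a]
def py(y, m, u): return p_y1_mu[m, u] if y else 1 - p_y1_mu[m, u]

# Ground truth by intervening in the full model: p(Y=1 | do(A=1)).
truth = sum(pu(u) * pm(m, 1) * py(1, m, u)
            for u, m in product((0, 1), repeat=2))

# Observed joint p(A, M, Y), with the hidden U marginalized out.
obs = {(a, m, y): sum(pu(u) * pa(a, u) * pm(m, a) * py(y, m, u)
                      for u in (0, 1))
       for a, m, y in product((0, 1), repeat=3)}

def p_a(a):
    return sum(obs[a, m, y] for m, y in product((0, 1), repeat=2))

def p_m_given_a(m, a):
    return sum(obs[a, m, y] for y in (0, 1)) / p_a(a)

def p_y_given_ma(y, m, a):
    return obs[a, m, y] / sum(obs[a, m, v] for v in (0, 1))

# Front-door adjustment, using the observed distribution only.
result = sum(p_m_given_a(m, 1)
             * sum(p_y_given_ma(1, m, ap) * p_a(ap) for ap in (0, 1))
             for m in (0, 1))

assert abs(truth - result) < 1e-9
```

Because $M$ depends on $U$ only through $A$ in this model, the front-door expression recovers the interventional marginal exactly, mirroring the agreement observed above.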

Using symbolic computation

We can also perform these computations using ananke.factors.SymCPD, which allows each conditional probability distribution to be specified using symbols rather than numerical values.

[38]:

cpds, all_vars = bayesian_network.create_symbolic_cpds(dag)
net = bayesian_network.BayesianNetwork(graph=dag, cpds=cpds)

[39]:

truth = oracle.compute_effect_from_discrete_model(net, treatment_dict, outcome_dict)

[40]:

truth

[40]:

$\displaystyle 0.5 q_{M 0 1} \cdot \left(1 - q_{Y 0 00}\right) + 0.5 q_{M 0 1} \cdot \left(1 - q_{Y 0 01}\right) + 0.5 \cdot \left(1 - q_{M 0 1}\right) \left(1 - q_{Y 0 10}\right) + 0.5 \cdot \left(1 - q_{M 0 1}\right) \left(1 - q_{Y 0 11}\right)$
[41]:

obs_dist = variable_elimination(net, ["A", "M", "Y"])

[42]:

result = empirical_plugin.estimate_effect_from_discrete_dist(
    oid, obs_dist, treatment_dict, outcome_dict
)

INFO:ananke.estimation.empirical_plugin:implied district is {'M'}
INFO:ananke.estimation.empirical_plugin:fixing by q(Y|M, A)
INFO:ananke.estimation.empirical_plugin:fixing by q(A|Y)
INFO:ananke.estimation.empirical_plugin:implied district is {'Y'}
INFO:ananke.estimation.empirical_plugin:fixing by q(M|A)
INFO:ananke.estimation.empirical_plugin:fixing by q(A|M, Y)

[43]:

result

[43]:

$\displaystyle \frac{\left(q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right)\right) \left(\frac{q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right) \left(q_{M 0 0} \left(\frac{q_{A 0 0} q_{Y 0 00}}{2} + \frac{q_{A 0 1} q_{Y 0 01}}{2}\right) + q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} q_{Y 0 10}}{2} + \frac{q_{A 0 1} q_{Y 0 11}}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{q_{M 0 0} \left(\frac{q_{A 0 0} q_{Y 0 00}}{2} + \frac{q_{A 0 1} q_{Y 0 01}}{2}\right) + q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right)} + \frac{q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) \left(q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{q_{M 0 1} 
\left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right)}\right) \left(q_{M 0 0} \left(\frac{q_{A 0 0} q_{Y 0 00}}{2} + \frac{q_{A 0 1} q_{Y 0 01}}{2}\right) + q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right) + q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} q_{Y 0 10}}{2} + \frac{q_{A 0 1} q_{Y 0 11}}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 
11}\right)}{2}\right)} + \frac{\left(\left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)\right) \left(\frac{\left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right) \left(q_{M 0 0} \left(\frac{q_{A 0 0} q_{Y 0 00}}{2} + \frac{q_{A 0 1} q_{Y 0 01}}{2}\right) + q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} q_{Y 0 10}}{2} + \frac{q_{A 0 1} q_{Y 0 11}}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{\left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} q_{Y 0 10}}{2} + \frac{q_{A 0 1} q_{Y 0 11}}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right)} + \frac{\left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right) \left(q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 
10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{\left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)}\right) \left(q_{M 0 0} \left(\frac{q_{A 0 0} q_{Y 0 00}}{2} + \frac{q_{A 0 1} q_{Y 0 01}}{2}\right) + q_{M 0 0} \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 00}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 01}\right)}{2}\right) + q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} q_{Y 0 10}}{2} + \frac{q_{A 0 1} q_{Y 0 11}}{2}\right) + \left(1 - q_{M 0 0}\right) \left(\frac{q_{A 0 0} \cdot \left(1 - q_{Y 0 10}\right)}{2} + \frac{q_{A 0 1} \cdot \left(1 - q_{Y 0 11}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)\right)}{q_{M 0 1} \left(\frac{q_{Y 0 00} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 01} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + q_{M 0 1} \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 00}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 01}\right)}{2}\right) + \left(1 - q_{M 0 1}\right) \left(\frac{q_{Y 0 10} \cdot \left(1 - q_{A 0 0}\right)}{2} + \frac{q_{Y 0 11} \cdot \left(1 - q_{A 0 1}\right)}{2}\right) + \left(1 - q_{M 
0 1}\right) \left(\frac{\left(1 - q_{A 0 0}\right) \left(1 - q_{Y 0 10}\right)}{2} + \frac{\left(1 - q_{A 0 1}\right) \left(1 - q_{Y 0 11}\right)}{2}\right)}$

To check that the truth and result agree, we take their difference and simplify it; if the difference reduces to zero, the two expressions are identical:

[44]:

sp.simplify(result - truth)

[44]:

$\displaystyle 0$
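This difference-and-simplify pattern is general: two symbolic expressions are proven equal whenever sympy reduces their difference to zero. A minimal standalone illustration, using the identity that marginalizing a binary variable out of a product of probabilities collapses it:

```python
import sympy as sp

p, q = sp.symbols("p q")
# sum_m p * q_m with q_1 = q and q_0 = 1 - q collapses back to p.
marginalized = p * q + p * (1 - q)
assert sp.simplify(marginalized - p) == 0
```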

Conditional Ignorability

This example demonstrates how a user might compute identification queries through a lower-level interface. This is useful when a proposed identification algorithm does not fit the format of the ID algorithm.

[45]:

di_edges = [("A", "Y"), ("C", "Y"), ("C", "A")]
bi_edges = []
c_graph = ADMG(vertices={"A": 2, "C": 2, "Y": 2}, di_edges=di_edges, bi_edges=bi_edges)
c_graph.draw()

[45]:

[46]:

c_dag = c_graph.canonical_dag(cardinality=2)

[47]:

cpds = bayesian_network.generate_random_cpds(graph=c_dag, dir_conc=10)
net = bayesian_network.BayesianNetwork(graph=c_dag, cpds=cpds)


We first set up an intervened version of the model, with the desired intervention $do(A=1)$.

[48]:

treatment_dict = {"A": 1}
outcome_dict = {"Y": 1}

int_net = net.copy()
int_net.fix(treatment_dict)

[48]:

<ananke.models.bayesian_network.BayesianNetwork at 0x287e44400>


Then we compute the marginal of $Y$ in this distribution, which is the causal parameter of interest $p(Y(a=1))$. This gives us the truth.

[49]:

truth = variable_elimination(int_net, ['Y']).get_value(**outcome_dict)


To compute the causal effect another way, we use only the observed data distribution and implement the g-formula result:

$p(Y(a=1)) = \sum_C p(Y | A=1, C) p(C)$

[51]:

p_YAC = variable_elimination(net, ['Y', 'A', 'C'])
p_Y_AC = p_YAC.divide(p_YAC.marginalize(["Y"], inplace=False), inplace=False)
p_Y_A1C = p_Y_AC.reduce([("A", 1)], inplace=False)
p_C = p_YAC.marginalize(["A", "Y"], inplace=False)
p_Y_do_A1 = p_Y_A1C.product(p_C, inplace=False).marginalize(["C"], inplace=False)

[52]:

result = p_Y_do_A1.get_value(**outcome_dict)

[53]:

print("Effect computed through intervention of model:  ", truth)
print("Effect computed through identification algorithm: ", result)

Effect computed through intervention of model:   0.5890045340192529
Effect computed through identification algorithm:  0.589004534019253


Again, the result agrees with the truth up to floating point precision.
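As a final remark, the floating-point tolerance can be removed entirely by encoding the CPDs as exact rationals via fractions.Fraction; the g-formula then matches the intervened marginal exactly. A standalone sketch with arbitrary illustrative numbers, independent of ananke:

```python
from fractions import Fraction as F
from itertools import product

# Exact rational CPDs for the model C -> A -> Y, C -> Y.
p_c1 = F(1, 3)                                # p(C=1)
p_a1_c = {0: F(1, 4), 1: F(3, 4)}             # p(A=1 | C=c)
p_y1_ac = {(0, 0): F(1, 5), (0, 1): F(2, 5),
           (1, 0): F(3, 5), (1, 1): F(4, 5)}  # p(Y=1 | A=a, C=c)

def pc(c): return p_c1 if c else 1 - p_c1
def pa(a, c): return p_a1_c[c] if a else 1 - p_a1_c[c]
def py(y, a, c): return p_y1_ac[a, c] if y else 1 - p_y1_ac[a, c]

# Truth by intervention: p(Y=1 | do(A=1)) = sum_c p(Y=1 | A=1, c) p(c).
truth = sum(py(1, 1, c) * pc(c) for c in (0, 1))

# The same quantity recovered from the observed joint via the g-formula.
joint = {(a, c, y): pc(c) * pa(a, c) * py(y, a, c)
         for a, c, y in product((0, 1), repeat=3)}

def p_marg_c(c):
    return sum(joint[a, c, y] for a, y in product((0, 1), repeat=2))

def p_y1_given_a1c(c):
    return joint[1, c, 1] / (joint[1, c, 0] + joint[1, c, 1])

result = sum(p_y1_given_a1c(c) * p_marg_c(c) for c in (0, 1))

assert truth == result            # exact equality, no tolerance needed
```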