Source code for ananke.identification.opt_adjust

"""
Optimal adjustment sets finding
"""

import logging
import unittest

import numpy as np

import ananke
from ananke.graphs import ADMG, DAG

logger = logging.getLogger("opt_adjust")


[docs]def get_opt_set(G, vertex1, vertex2): """ function to get the optimal set from two vertices :param G: the graph :param vertex1: inference from vertex1, say treatment :param vertex2: inference to vertex2, say outcome """ causal_nodes = [ item for item in G.descendants([vertex1]) if item in G.ancestors([vertex2]) ] if len(causal_nodes) == 0: logger.info("No causal node!") return None causal_nodes.remove(vertex1) forb_set = [item for item in G.descendants(causal_nodes)] forb_set = forb_set + [ vertex1 ] # The forbidden set characterizes those covariates that may never be included into a valid adjustment set opt_set = [item for item in G.parents(causal_nodes) if item not in forb_set] return set(opt_set)
[docs]def get_opt_set_from_set(G, node_set1, node_set2): """ function to get the optimal set from two vertices sets :param G: the graph :param node_set1: set of inference from :param node_set2: set of inference to """ causal_nodes = [ item for item in G.descendants(node_set1) if item in G.ancestors(node_set2) ] causal_nodes.remove(node_set1) forb_set = [item for item in G.descendants(causal_nodes)] forb_set = forb_set + [node_set1] input_set = [item for item in G.vertices if item not in forb_set] for element in input_set: new_set = [item for item in input_set if item not in element] condition_set = new_set + [vertex1] flag_new_set = G.d_separated( element, node_set2, condition_set ) # check if d-separated if flag_new_set: input_set = new_set return set(input_set)
[docs]def get_min_set(G, input, vertex1, vertex2): """ function for minimal set :param G: DAG :param input: optimal set found using get_opt_set function :param vertex1: inference from vertex1, say treatment :param vertex2: inference to vertex2, say outcome """ if len(input) == 0: logger.info("Adjustment set is empty!") return input_set = [item for item in input] is_minimal_set = 0 # flag for the minimal set for element in G.children( [vertex1] ): # prepare for d-separation test, construct a SWIG G.delete_diedge(vertex1, element) for element in input_set: new_set = [item for item in input_set if item not in element] is_d_separated = G.d_separated( vertex1, vertex2, new_set ) # check if d-separated if is_d_separated: is_minimal_set = 1 final_set = get_min_set(G, new_set, vertex1, vertex2) if not is_minimal_set: # input set is minimal set return set(input_set) return set(final_set)
[docs]class OptAdjustment: def __init__(self, graph): self.graph = graph
[docs] def generate_opt_adjustment_set(self, vertex1, vertex2): """ Fits the binary nested model. Let N the number of observations, M the number of variables. :param self: DAG :param vertex1: inference from vertex1, say treatment :param vertex2: inference to vertex2, say outcome """ model = self.graph if type(model) == ADMG: logger.info("Found ADMG graph, not applicable") return None is_d_separated = model.d_separated(vertex1, vertex2, []) if is_d_separated: logger.info( "Tested vertices are independent so no set is needed to condition on" ) return [] opt_set = get_opt_set(model, vertex1, vertex2) if opt_set == None: logger.info("No adjustment set!") return None result = get_min_set(model, opt_set, vertex1, vertex2) return result