Source code for ananke.graphs.admg

"""
Class for acyclic directed mixed graphs (ADMGs) and conditional ADMGs (CADMGs).
"""
import copy
import itertools
import logging
import warnings
from typing import Union

from ananke.utils import powerset

from .ig import IG
from .sg import SG
from .ug import UG

logger = logging.getLogger(__name__)


[docs]class UndefinedADMGOperation(Exception): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
[docs]def latent_project_single_vertex(vertex, graph): """ Latent project one vertex from graph :param vertex: Name of vertex to be projected :param graph: ADMG :returns: """ di_edges = graph.di_edges bi_edges = graph.bi_edges retained_vertices = set(graph.vertices) - {vertex} retained_di_edges = [ x for x in di_edges if x[0] in retained_vertices and x[1] in retained_vertices ] retained_bi_edges = [ x for x in bi_edges if x[0] in retained_vertices and x[1] in retained_vertices ] # Construct all directed edge projections new_di_edges = [ (p.name, c.name) for p in graph.vertices[vertex].parents for c in graph.vertices[vertex].children ] # Construct all bidirected edge projections new_bi_edges = [ tuple(sorted([s.name, c.name])) for s in graph.vertices[vertex].siblings for c in graph.vertices[vertex].children ] + [ tuple(sorted([a.name, b.name])) for a, b in itertools.combinations(graph.vertices[vertex].children, r=2) ] final_bi_edges = retained_bi_edges + new_bi_edges G = ADMG( vertices=set(retained_vertices), di_edges=set(retained_di_edges + new_di_edges), bi_edges=set(final_bi_edges), ) return G
[docs]class ADMG(SG): """ Class for creating and manipulating (conditional) acyclic directed mixed graphs (ADMGs/CADMGs). """ def __init__(self, vertices=[], di_edges=set(), bi_edges=set(), **kwargs): """ Constructor. :param vertices: iterable of names of vertices. :param di_edges: iterable of tuples of directed edges i.e. (X, Y) = X -> Y. :param bi_edges: iterable of tuples of bidirected edges i.e. (X, Y) = X <-> Y. """ # initialize vertices in ADMG super().__init__( vertices=vertices, di_edges=di_edges, bi_edges=bi_edges, **kwargs )
[docs] def markov_pillow(self, vertices, top_order): """ Get the Markov pillow of a set of vertices. That is, the Markov blanket of the vertices given a valid topological order on the graph. :param vertices: iterable of vertex names. :param top_order: a valid topological order. :return: set corresponding to Markov pillow. """ # get the subgraph corresponding to the vertices and nodes prior to them pre = self.pre(vertices, top_order) Gsub = self.subgraph(pre + list(vertices)) # Markov pillow is the Markov blanket (dis(v) union pa(dis(v)) setminus v) # in this subgraph pillow = set() for v in vertices: pillow = pillow.union(Gsub.district(v)) pillow = pillow.union(Gsub.parents(pillow)) return pillow - set(vertices)
[docs] def markov_blanket(self, vertices): """ Get the Markov blanket of a set of vertices. :param vertices: iterable of vertex names. :return: set corresponding to Markov blanket. """ blanket = set() for v in vertices: blanket = blanket.union(self.district(v)) blanket = blanket.union(self.parents(blanket)) return blanket - set(vertices)
@property def fixed(self): """ Returns all fixed nodes in the graph. :return: """ fixed_vertices = [] for v in self.vertices: if self.vertices[v].fixed: fixed_vertices.append(v) return fixed_vertices
[docs] def is_subgraph(self, other): """ Check that this graph is a subgraph of other, meaning it has a subset of edges and nodes of the other. :param other: an object of the ADMG class. :return: boolean indicating whether the statement is True or not. """ if ( set(self.vertices).issubset(set(other.vertices)) and set(self.di_edges).issubset(set(other.di_edges)) and set(self.bi_edges).issubset(set(other.bi_edges)) ): return True return False
[docs] def is_ancestral_subgraph(self, other): """ Check that this graph is an ancestral subgraph of the other. An ancestral subgraph over variables S and intervention b G(S(b)) of a larger graph G(V(b)) is defined as a subgraph, such that ancestors of each node s in S with respect to the graph G(V(b_i)) are contained in S. :param other: an object of the ADMG class. :return: boolean indicating whether the statement is True or not. """ if not self.is_subgraph(other): return False for v in self.vertices: self_parents = set([item.name for item in self.vertices[v].parents]) other_parents = set( [item.name for item in other.vertices[v].parents] ) if self_parents != other_parents: return False return True
[docs] def reachable_closure(self, vertices): """ Obtain reachable closure for a set of vertices. :param vertices: set of vertices to get reachable closure for. :return: set corresponding to the reachable closure, the fixing order for vertices outside of the closure, and the CADMG corresponding to the closure. """ # initialize set of vertices that must still be fixed remaining_vertices = ( set(self.vertices) - set(vertices) - set(v for v in self.vertices if self.vertices[v].fixed) ) fixing_order = [] # keep track of the valid fixing order fixed = True # flag to track that a vertex was successfully fixed in a given pass G = copy.deepcopy(self) # keep iterating over remaining vertices until there are no more or we failed to fix while remaining_vertices and fixed: fixed = False # check if any remaining vertex van be fixed for v in remaining_vertices: # fixability check if len(G.descendants([v]).intersection(G.district(v))) == 1: G.fix([v]) remaining_vertices.remove(v) fixing_order.append(v) fixed = True # flag that we succeeded break # stop the current pass over vertices # compute final reachable closure based on vertices successfully fixed reachable_closure = set(G.vertices) - set( v for v in G.vertices if G.vertices[v].fixed ) # return the reachable closure, the valid order, and the resulting CADMG return reachable_closure, fixing_order, G
[docs] def fixable(self, vertices): """ Check if there exists a valid fixing order and return such an order in the form of a list, else returns an empty list. :param vertices: set of vertices to check fixability for. :return: a boolean indicating whether the set was fixable and a valid fixing order as a stack. """ # keep track of vertices still left to fix (ignoring fixed vertices) # and initialize a fixing order G = copy.deepcopy(self) remaining_vertices = set(vertices) - set(self.fixed) fixing_order = [] fixed = True # flag to check if we fixed a variable on each pass # while we have more vertices to fix, and were able to perform a fix while remaining_vertices and fixed: fixed = False for v in remaining_vertices: # Check if any nodes are reachable via -> AND <-> # by looking at intersection of district and descendants if len(G.descendants([v]).intersection(G.district(v))) == 1: G.fix([v]) remaining_vertices.remove(v) fixing_order.append(v) fixed = True break # if unsuccessful, return failure and # fixing order up until point of failure if not fixed: return False, fixing_order # if fixing vertices was successful, return success # and the fixing order return True, fixing_order
[docs] def subgraph(self, vertices): """ Computes subgraph given a set of vertices. Recomputes districts, since these may change when vertices are removed. :param vertices: iterable of vertices :return: subgraph """ subgraph = super().subgraph(vertices) subgraph._calculate_districts() return subgraph
[docs] def get_intrinsic_sets(self): """ Computes intrinsic sets (and returns the fixing order for each intrinsic set). :returns: list of intrinsic sets and fixing orders used to reach each one """ # create an intrinsic set graph and obtain the intrinsic sets + valid fixing orders leading to them ig = IG(copy.deepcopy(self)) intrinsic_sets = ig.get_intrinsic_sets() fixing_orders = ig.iset_fixing_order_map return intrinsic_sets, fixing_orders
[docs] def get_intrinsic_sets_and_heads(self): """ Computes intrinsic sets mapped to a tuple of heads and tails of that intrinsic set, and fixing orders for each one. :returns: tuple of dict of intrinsic sets to heads and tails, and fixing orders for each intrinsic set """ ig = IG(copy.deepcopy(self)) intrinsic_sets = ig.get_intrinsic_sets() fixing_orders = ig.iset_fixing_order_map heads = [] tails = [] for intrinsic_set in intrinsic_sets: G_sub = self.subgraph(intrinsic_set) head = frozenset( { s for s in intrinsic_set if len(G_sub.vertices[s].children) == 0 } ) tail = frozenset(self.parents(intrinsic_set)) heads.append(head) tails.append(tail) return dict(zip(intrinsic_sets, zip(heads, tails))), fixing_orders
[docs] def maximal_arid_projection(self): """ Get the maximal arid projection that encodes the same conditional independences and Vermas as the original ADMG. This operation is described in Acyclic Linear SEMs obey the Nested Markov property by Shpitser et al 2018. :return: An ADMG corresponding to the maximal arid projection. """ vertices, di_edges, bi_edges = self.vertices, [], [] # keep a cached dictionary of reachable closures and ancestors # for efficiency purposes reachable_closures = {} ancestors = {v: self.ancestors([v]) for v in vertices} # iterate through all vertex pairs for a, b in itertools.combinations(vertices, 2): # decide which reachable closure needs to be computed # and compute it if one vertex is an ancestor of another u, v, rc = None, None, None if a in ancestors[b]: u, v = a, b elif b in ancestors[a]: u, v = b, a # check parent condition and add directed edge if u is a parent of the reachable closure added_edge = False if u: if v not in reachable_closures: reachable_closures[v] = self.reachable_closure([v])[0] rc = reachable_closures[v] if u in self.parents(rc): di_edges.append((u, v)) added_edge = True # if neither are ancestors of each other we need to compute # the reachable closure of set {a, b} and check if it is # bidirected connected if not added_edge: rc, _, cadmg = self.reachable_closure([a, b]) for district in cadmg.districts: if rc <= district: bi_edges.append((a, b)) return ADMG(vertices=vertices, di_edges=di_edges, bi_edges=bi_edges)
[docs] def mb_shielded(self): """ Check if the ADMG is a Markov blanket shielded ADMG. That is, check if two vertices are non-adjacent only when they are absent from each others' Markov blankets. :return: boolean indicating if it is mb-shielded or not. """ # iterate over all pairs of vertices for Vi, Vj in itertools.combinations(self.vertices, 2): # check if the pair is not adjacent if not ( Vi in self.siblings([Vj]) or (Vi, Vj) in self.di_edges or (Vj, Vi) in self.di_edges ): # if one is in the Markov blanket of the other, then it is not mb-shielded if Vi in self.markov_blanket([Vj]) or Vj in self.markov_blanket( [Vi] ): return False return True
[docs] def nonparametric_saturated(self): """ Check if the nested Markov model implied by the ADMG is nonparametric saturated. The following is an implementation of Algorithm 1 in Semiparametric Inference for Causal Effects in Graphical Models with Hidden Variables (Bhattacharya, Nabi & Shpitser 2020) which was shown to be sound and complete for this task. :return: boolean indicating if it is nonparametric saturated or not. """ # iterate over all pairs of vertices for Vi, Vj in itertools.combinations(self.vertices, 2): # check if there is no dense inducing path between Vi and Vj # and return not NPS if either of the checks fail if not ( Vi in self.parents(self.reachable_closure([Vj])[0]) or Vj in self.parents(self.reachable_closure([Vi])[0]) or Vi in self.reachable_closure([Vi, Vj])[2].district(Vj) ): return False return True
[docs] def m_separated(self, X, Y, separating_set=list()): """ Computes m-separation for set `X` and set `Y` given `separating_set` :param X: first vertex set :param Y: second vertex set :param separating_set: separating set :return: boolean result of m-separation """ if type(X) is str: X = [X] if type(Y) is str: Y = [Y] for x, y in itertools.product(X, Y): if not self._m_separated(x, y, separating_set): return False return True
def _m_separated(self, X, Y, separating_set): """ Determine whether `X` and `Y` vertices are m-separated given `separating_set`. Also works on ADMGs which have fixed vertices. :param X: first vertex :param Y: second vertex :param separating_set: list of given vertices :return: boolean result of m-separation """ ancestral_vars = [X, Y] + list(separating_set) # create a new subgraph of the ancestors of vertex1, vertex2, and given vertices ancestral_subgraph = self.subgraph(self.ancestors(ancestral_vars)) augmented_graph_vertices = list(ancestral_subgraph.vertices) augmented_graph_di_edges = list(ancestral_subgraph.di_edges) for V in ancestral_subgraph.vertices.values(): if V.fixed and V.name != X and V.name != Y: separating_set.append(V.name) # if both vertices are fixed, result is undefined if ( ancestral_subgraph.vertices[X].fixed and ancestral_subgraph.vertices[Y].fixed ): raise UndefinedADMGOperation( "{0} and {1} are fixed, so m-separation is undefined.".format( X, Y ) ) for Vi, Vj in itertools.combinations(augmented_graph_vertices, 2): # fixed vertices are connected if ( ancestral_subgraph.vertices[Vi].fixed and ancestral_subgraph.vertices[Vj].fixed ): augmented_graph_di_edges.append((Vi, Vj)) ancestral_subgraph.vertices[Vi].fixed = False ancestral_subgraph.vertices[Vj].fixed = False # ancestral_subgraph._calculate_districts() for Vi, Vj in itertools.combinations(augmented_graph_vertices, 2): # checks for common children between any pairs of vertices children_i = set(ancestral_subgraph.children([Vi])) markov_blanket_Vi = ancestral_subgraph.markov_blanket([Vi]) # connects vertices if one is in the markov blanket of the other vertex, or markov blanket of child of other vertex if ( Vj in markov_blanket_Vi or Vj in ancestral_subgraph.markov_blanket(children_i) ): augmented_graph_di_edges.append((Vi, Vj)) # removes given vertices from the graph for vertex in augmented_graph_vertices[:]: if vertex in separating_set: augmented_graph_vertices.remove(vertex) # removes any edges from the graph that include any of the given vertices for edge in augmented_graph_di_edges[:]: if edge[0] in separating_set or edge[1] in separating_set: augmented_graph_di_edges.remove(edge) # creates a new undirected graph from the updated vertices and edges augmented_graph = UG(augmented_graph_vertices, augmented_graph_di_edges) # checks if vertex 2 is in the block of vertex 1 Y_block = augmented_graph.block(X) return Y not in Y_block
[docs] def latent_projection(self, retained_vertices): """ Computes latent projection. :param retained_vertices: list of vertices to retain after latent projection :returns: Latent projection containing retained vertices """ vertices = self.vertices di_edges = self.di_edges bi_edges = self.bi_edges projected_vertices = list(set(vertices) - set(retained_vertices)) G = self.copy() for vertex in projected_vertices: G = latent_project_single_vertex(vertex=vertex, graph=G) return G
[docs] def canonical_dag(self, cardinality: Union[int, str] = None): """ Computes the canonical DAG from the ADMG by inserting a variable in place of each bidirected edge. For each bidirected edge 'X <-> Y', the inserted variable will take names of the form 'U_XY', the original bidirected edge is removed, and new directed edges 'U_X_Y -> Y' and 'U_X_Y -> X' are inserted. The variable names are in lexicographic order. :params cardinality: The cardinality of the inserted variables :type cardinality: Union[int, str] :returns: A directed acyclic graph :rtype: DAG """ from .dag import DAG if cardinality is None: warnings.warn( "Warning: cardinality of latent variables not set. Please set the cardinality if intending to use ananke.models.discrete functionality." ) G = DAG(vertices=list(self.vertices), di_edges=self.di_edges) for _x, _y in self.bi_edges: x, y = tuple(sorted([_x, _y])) new_var = "U_{}_{}".format(x, y) G.add_vertex(name=new_var, cardinality=cardinality) G.add_diedge(new_var, x) G.add_diedge(new_var, y) for v in self.vertices: if self.vertices[v].fixed: G.vertices[v].fixed = True return G