"""
Implementation of DiscreteFator.
API inspired by pgmpy (https://github.com/pgmpy)
"""
import copy
import math
import numpy as np
import sympy as sp
[docs]class BaseDiscreteFactor:
"""
The `inplace=False` argument is to maintain API compatibility with
`pgmpy.factors.discrete.DiscreteFactor`.
"""
def __init__(self, variables, cardinality, values):
pass
[docs] def get_cards_dict(self):
return {k: v for k, v in zip(self.variables, self.cardinality)}
[docs] def marginalize(self, other, inplace=False):
raise NotImplementedError
[docs] def product(self, other, inplace=False):
raise NotImplementedError
[docs] def reduce(self, values, inplace=False):
raise NotImplementedError
[docs] def divide(self, other, inplace=False):
raise NotImplementedError
[docs]class SymDiscreteFactor(BaseDiscreteFactor):
def __init__(self, variables, cardinality, values):
"""
Initializes a symbolic `DiscreteFactor`, which is backed by sympy arrays.
:param variables: Scope of the factor
:type variables: list
:param cardinality: Cardinality of each variable of the factor, in order supplied by
`variables`
:type cardinality: list
:param values: A sympy.Array of variables representing the factor, or sympy.One
:type values: sympy.Array, int
"""
self.variables = list(variables)
self.cardinality = cardinality
if hasattr(values, "reshape"):
self.values = values.reshape(*tuple(self.cardinality))
else:
self.values = values
[docs] def marginalize(self, variables, inplace=False):
if not variables:
return copy.deepcopy(self)
var_indexes = [self.variables.index(var) for var in variables]
index_to_keep = sorted(
set(range(len(self.variables))) - set(var_indexes)
)
new_values = sp.tensorcontraction(self.values, tuple(var_indexes))
return SymDiscreteFactor(
variables=[self.variables[i] for i in index_to_keep],
cardinality=[self.cardinality[i] for i in index_to_keep],
values=new_values,
)
[docs] def product(self, other, inplace=False):
common_vars = sorted(set(self.variables) & set(other.variables))
result = sp.tensorproduct(self.values, other.values)
named_indexes = list(self.variables) + list(other.variables)
if common_vars:
for c in common_vars:
active_indices = [
i for i, x in enumerate(named_indexes) if x == c
]
result = sp.tensordiagonal(result, active_indices)
named_indexes = [i for i in named_indexes if i != c] + [c]
cards = dict(
zip(
self.variables + other.variables,
self.cardinality + other.cardinality,
)
)
final_cardinality = [cards[x] for x in named_indexes]
return SymDiscreteFactor(
variables=named_indexes,
cardinality=final_cardinality,
values=result.reshape(math.prod(final_cardinality)),
)
[docs] def reduce(self, evidence, inplace=False):
slice_ = [slice(None)] * len(self.variables)
var_index_to_del = list()
for v, s in evidence:
var_index = self.variables.index(v)
slice_[var_index] = s
var_index_to_del.append(var_index)
var_index_to_keep = sorted(
set(range(len(self.variables))) - set(var_index_to_del)
)
new_variables = [self.variables[i] for i in var_index_to_keep]
new_cardinality = [self.cardinality[i] for i in var_index_to_keep]
new_values = self.values[tuple(slice_)]
return SymDiscreteFactor(
variables=new_variables,
cardinality=new_cardinality,
values=new_values,
)
[docs] def divide(self, other, inplace=False):
inverted_other = SymDiscreteFactor(
variables=other.variables,
cardinality=other.cardinality,
values=sp.Array(1 / np.array(other.values)),
)
return self.product(inverted_other)
[docs] def get_value(self, **kwargs):
assert set(kwargs.keys()) == set(
self.variables
), "Factor variables do not match specified variables"
ix = [kwargs[var] for var in self.variables]
return self.values[tuple(ix)]
def __eq__(self, other):
if not (
isinstance(self, SymDiscreteFactor)
and isinstance(other, SymDiscreteFactor)
):
return False
elif set(self.variables) != set(other.variables):
return False
else:
if self.variables != other.variables:
other_indexes = [
other.variables.index(var) for var in self.variables
]
other_values = sp.permutedims(other.values, other_indexes)
phi = SymDiscreteFactor(
variables=[other.variables[i] for i in other_indexes],
cardinality=[other.cardinality[i] for i in other_indexes],
values=other_values,
)
else:
phi = copy.deepcopy(other)
if hasattr(self.values, "shape") and hasattr(phi.values, "shape"):
if self.values.shape != phi.values.shape:
return False
else:
if self.values != phi.values:
return False
if self.cardinality != phi.cardinality:
return False
return True
[docs] def subs(self, vals):
new_values = self.values.subs(vals)
return SymDiscreteFactor(
variables=self.variables,
cardinality=self.cardinality,
values=new_values,
)
[docs] def to_pgmpy(self):
import pgmpy
try:
factor = pgmpy.factors.discrete.DiscreteFactor(
variables=self.variables,
cardinality=self.cardinality,
values=np.array(self.values).astype(float64),
)
except TypeError:
raise ValueError(
"There are unsubstituted Sympy variables - cannot convert to pgmpy"
)
return factor
[docs]class SymCPD(SymDiscreteFactor):
def __init__(
self,
variable: str,
variable_card: int,
values: sp.Array,
evidence=None,
evidence_card=None,
):
variables = [variable]
cardinality = [variable_card]
if evidence is not None:
variables.extend(evidence)
cardinality.extend(evidence_card)
self.variable = variable
self.variable_card = variable_card
super().__init__(variables, cardinality, values)
[docs] def get_values(self):
if len(self.variables) > 1:
return self.values.reshape(
self.variable_card, math.prod(self.cardinality[1:])
)
else:
return self.values.reshape(self.variable_card, 1)
[docs] def to_factor(self):
return SymDiscreteFactor(
self.variables.copy(),
self.cardinality.copy(),
copy.deepcopy(self.values),
)
[docs] def to_pgmpy(self):
from pgmpy.factors.discrete import TabularCPD
values = self.get_values()
np_values = np.array(values.tolist()).astype(float)
try:
if len(self.variables) == 1:
factor = TabularCPD(
variable=self.variables[0],
variable_card=self.cardinality[0],
values=np_values,
)
else:
factor = TabularCPD(
variable=self.variables[0],
variable_card=self.cardinality[0],
values=np_values,
evidence=self.variables[1:],
evidence_card=self.cardinality[1:],
)
except TypeError as e:
raise ValueError(
f"There are unsubstituted Sympy variables - cannot convert to pgmpy: {self.values}, {e}"
)
return factor