r"""
Utilities for manipulating causal graphs
"""
from collections.abc import Callable, Collection
from functools import lru_cache, reduce
from itertools import chain
from operator import or_
from typing import Any
import networkx as nx
import numpy as np
import pandas as pd
from anndata import AnnData
from loguru import logger
from scipy.sparse import eye
from tqdm.auto import tqdm
from .typing import SimpleGraph
[docs]
def multiplex(
    *graphs: SimpleGraph, edge_attr: str = "weight", na_fill: Any = 0.0
) -> SimpleGraph:
    r"""
    Combine multiple graphs into a single graph by multiplexing an edge
    attribute

    Parameters
    ----------
    *graphs
        Graphs to be multiplexed
    edge_attr
        Edge attribute to be multiplexed
    na_fill
        Value to fill in when an edge (or its ``edge_attr``) is not found in
        certain graphs

    Returns
    -------
    Multiplexed graph, where ``edge_attr`` of each edge is a list with one
    entry per input graph (in the order the graphs were given)

    Raises
    ------
    ValueError
        If no graph is given
    TypeError
        If the graphs are not all of the same type

    .. note::
        All attributes except ``edge_attr`` will be dropped, because it is
        generally not obvious how attributes of nodes and edges missing in
        certain multiplexed graphs should be filled.
    """
    if not graphs:
        raise ValueError("At least one graph is required")
    graph_types = {type(graph) for graph in graphs}
    if len(graph_types) > 1:
        raise TypeError("Graphs must be of the same type")
    multiplexed = graph_types.pop()()
    # Iterating ``graph.nodes`` / ``graph.edges`` yields bare keys only, which
    # is what drops every node and edge attribute of the inputs.
    multiplexed.add_nodes_from(chain.from_iterable(graph.nodes for graph in graphs))
    multiplexed.add_edges_from(chain.from_iterable(graph.edges for graph in graphs))
    for e, attr in tqdm(multiplexed.edges.items(), leave=False):
        # Look the edge up through each graph's own edge view rather than a
        # dict keyed by ``(u, v)`` tuples: in undirected graphs the same edge
        # may be reported as ``(v, u)`` by another graph, and a plain tuple
        # lookup would then wrongly fall back to ``na_fill``.
        attr[edge_attr] = [
            graph.edges[e].get(edge_attr, na_fill) if graph.has_edge(*e) else na_fill
            for graph in graphs
        ]
    return multiplexed
[docs]
@lru_cache
def multiplex_num(graph: SimpleGraph, edge_attr: str = "weight") -> int:
    r"""
    Get the number of multiplexed graphs according to an edge attribute

    Parameters
    ----------
    graph
        A multiplex graph
    edge_attr
        Multiplexed edge attribute

    Returns
    -------
    Multiplex number

    Raises
    ------
    ValueError
        If the attribute does not have the same size on every edge

    .. note::
        Singularly multiplexed graph has num = 1, while non-multiplexed graph
        has num = 0. A graph in which no edge carries ``edge_attr`` is also
        treated as non-multiplexed (num = 0).

    .. caution::
        The cache is **UNSAFE** from inplace graph manipulations.
    """
    attrs = nx.get_edge_attributes(graph, edge_attr)
    # Non-collection (scalar) values count as size 0, i.e. non-multiplexed.
    # NOTE(review): ``str`` values are Collections too and would be sized by
    # their length - confirm string-valued attributes are never passed here.
    sizes = {len(val) if isinstance(val, Collection) else 0 for val in attrs.values()}
    if len(sizes) > 1:
        raise ValueError("Edge attribute must have consistent size")
    # Empty graph is treated as not multiplexed, hence 0 rather than 1
    return sizes.pop() if sizes else 0
[docs]
@lru_cache
def demultiplex(graph: SimpleGraph, edge_attr: str = "weight") -> list[SimpleGraph]:
    r"""
    Split one single graph into multiple graphs by demultiplexing an edge attribute

    Parameters
    ----------
    graph
        Graph to be demultiplexed
    edge_attr
        Edge attribute to be demultiplexed

    Returns
    -------
    List of demultiplexed graphs. For a non-multiplexed graph this is a
    one-element list holding the input graph itself (not a copy).

    .. caution::
        The cache is **UNSAFE** from inplace graph manipulations.
    """
    num = multiplex_num(graph, edge_attr=edge_attr)
    if num == 0:  # Non-multiplexed
        # Always return a list so callers can iterate / index uniformly,
        # as promised by the return annotation.
        return [graph]
    attrs = nx.get_edge_attributes(graph, edge_attr)
    demultiplexed = [graph.copy() for _ in range(num)]
    for e, val in tqdm(attrs.items(), leave=False):
        # ``multiplex_num`` guarantees len(val) == num for every edge that
        # carries the attribute, so the i-th copy receives the i-th entry.
        for g, v in zip(demultiplexed, val):
            g.edges[e][edge_attr] = v
    return demultiplexed
[docs]
def map_edges(
    graph: SimpleGraph,
    edge_attr: str = "weight",
    fn: Callable = lambda x: x,
) -> SimpleGraph:
    r"""
    Apply a function to an edge attribute of every edge

    Parameters
    ----------
    graph
        Graph whose edges are to be mapped (left unmodified)
    edge_attr
        Name of the edge attribute to transform
    fn
        Function applied to each attribute value; identity by default

    Returns
    -------
    Copy of the graph with ``edge_attr`` replaced by ``fn(value)`` on each edge
    """
    result = graph.copy()
    # The edge view exposes the live attribute dicts of the copy, so writing
    # into them updates ``result`` in place without touching ``graph``.
    for _, data in result.edges.items():
        data[edge_attr] = fn(data[edge_attr])
    return result
[docs]
def filter_edges(
    graph: SimpleGraph,
    edge_attr: str = "weight",
    cutoff: float | None = None,
    n_top: int | None = None,
) -> SimpleGraph:
    r"""
    Filter graph by an edge attribute

    Parameters
    ----------
    graph
        Graph to be filtered
    edge_attr
        Edge attribute used to filter the graph
    cutoff
        Cutoff value for the edge attribute (exclusive: only edges with a
        strictly larger value are kept)
    n_top
        Number of top edges to be kept

    Returns
    -------
    Filtered graph, containing all nodes of the input graph and only the
    edges that pass the filter. Edges lacking ``edge_attr`` are dropped.

    Raises
    ------
    ValueError
        If not exactly one of ``cutoff`` and ``n_top`` is specified, or if
        ``n_top`` is negative

    .. note::
        Exactly one of ``cutoff`` and ``n_top`` should be specified.

    .. note::
        With ``n_top``, edges tied with the ``(n_top + 1)``-th largest value
        are dropped, so fewer than ``n_top`` edges may be kept. If the graph
        has no more than ``n_top`` attributed edges, all of them are kept.
    """
    if (cutoff is None) == (n_top is None):
        raise ValueError("Exactly one of cutoff and n_top should be specified")
    values = nx.get_edge_attributes(graph, edge_attr)
    if cutoff is None:
        if n_top < 0:
            raise ValueError("n_top must be non-negative")
        ranked = sorted(values.values(), reverse=True)
        # The (n_top + 1)-th largest value acts as the exclusive cutoff. When
        # there are not that many edges, fall back to -inf so that every
        # attributed edge is kept instead of raising IndexError.
        cutoff = ranked[n_top] if n_top < len(ranked) else float("-inf")
    filtered = type(graph)(**graph.graph)
    filtered.add_nodes_from(graph.nodes.items())
    filtered.add_edges_from(
        (*e, graph.edges[e]) for e, val in values.items() if val > cutoff
    )
    return filtered
[docs]
def acyclify(digraph: nx.DiGraph, edge_attr: str = "weight") -> nx.DiGraph:
    r"""
    Turn a directed graph into a DAG by repeatedly locating a cycle and
    deleting its lowest-weight edge

    Parameters
    ----------
    digraph
        Directed graph
    edge_attr
        Attribute key for edge weights

    Returns
    -------
    Acyclic directed graph. If the input is already acyclic, the input
    object itself is returned; otherwise a modified copy is returned.

    .. caution::
        This might not be reproducible due to the unstable order of identified
        cycles.
    """
    if nx.is_directed_acyclic_graph(digraph):
        return digraph

    dag = digraph.copy()
    weights = nx.get_edge_attributes(dag, edge_attr)
    # Each iteration removes exactly one edge, so the edge count bounds the
    # number of iterations; the range only drives the progress bar.
    progress = tqdm(range(dag.number_of_edges()), mininterval=1, leave=False)
    for _ in progress:
        try:
            cycle = nx.find_cycle(dag)
        except nx.NetworkXNoCycle:
            break
        weakest = min(cycle, key=weights.__getitem__)
        dag.remove_edge(*weakest)
    return dag
[docs]
def marginalize(digraph: nx.DiGraph, margin: Collection, max_steps: int) -> nx.DiGraph:
    r"""
    Marginalize a directed graph by keeping only a subset of observed nodes,
    optionally inferring indirect connections up to a maximal number of steps
    mediated by latent nodes.

    Parameters
    ----------
    digraph
        Directed graph
    margin
        A list of marginal nodes
    max_steps
        The maximal number of steps to infer indirect edges; 0 disables
        inference so that only direct edges between marginal nodes are kept

    Returns
    -------
    Marginalized graph

    Raises
    ------
    TypeError
        If the input graph is not directed
    ValueError
        If ``max_steps`` is negative

    .. note::
        A new edge attribute "marginalize" is added that indicates whether
        the edge is direct or indirect.
    """
    if not nx.is_directed(digraph):
        raise TypeError("Input graph must be directed")
    if max_steps < 0:
        raise ValueError("max_steps must be non-negative")
    margin = set(margin)
    # Requested nodes that do not exist in the graph are only reported, not
    # treated as an error.
    if missing := margin - digraph.nodes:
        logger.warning(
            f"{len(missing)} nodes are missing from the input graph "
            "and will be ignored."
        )
    # Start from the marginal nodes (with their attributes) and the edges
    # whose both endpoints are marginal; graph-level attributes are carried over.
    marginalized = nx.DiGraph(**digraph.graph)
    marginalized.add_nodes_from((u, d) for u, d in digraph.nodes.items() if u in margin)
    marginalized.add_edges_from(
        [
            (u, v, d)
            for (u, v), d in digraph.edges.items()
            if u in margin and v in margin
        ],
        marginalize="direct",
    )
    if max_steps == 0:
        return marginalized
    nodelist = pd.Index(digraph.nodes)
    # The right-hand side is evaluated before rebinding, so the set difference
    # still uses the ``margin`` set; afterwards ``margin`` is the ordered list
    # of marginal nodes actually present in the graph.
    margin, latent = list(marginalized.nodes), list(digraph.nodes - margin)
    # ``weight=None`` requests an adjacency matrix that ignores edge attributes
    adj = nx.to_scipy_sparse_array(digraph, nodelist=nodelist, weight=None)
    margin_idx = nodelist.get_indexer(margin)
    latent_idx = nodelist.get_indexer(latent)
    # Partition the adjacency matrix into margin/latent blocks
    adj_margin_margin = adj[margin_idx, :][:, margin_idx]
    adj_margin_latent = adj[margin_idx, :][:, latent_idx]
    adj_latent_margin = adj[latent_idx, :][:, margin_idx]
    adj_latent_latent = adj[latent_idx, :][:, latent_idx]
    # Collect I, A, A^2, ..., A^(max_steps - 1) where A is the latent-to-latent
    # block: the k-th power counts walks of k hops among latent nodes.
    accumulator = eye(latent_idx.size)
    latent_steps = []
    for step in range(max_steps):
        accumulator = accumulator @ adj_latent_latent if step else accumulator
        latent_steps.append(accumulator)
    inf_latent_latent = sum(latent_steps)
    # margin -> latent -> (latent walk) -> latent -> margin, i.e. walks that
    # pass through between 1 and ``max_steps`` latent nodes.
    inf_margin_margin = adj_margin_latent @ inf_latent_latent @ adj_latent_margin
    # Binarize and keep only inferred pairs that lack a direct edge
    inf_margin_margin = (
        inf_margin_margin.astype(bool).astype(int)
        - adj_margin_margin.astype(bool).astype(int)
    ) > 0
    inf = nx.from_scipy_sparse_array(inf_margin_margin, create_using=nx.DiGraph)
    # Map integer matrix positions back to the original node labels
    nx.relabel_nodes(inf, dict(enumerate(margin)), copy=False)
    marginalized.add_edges_from(inf.edges, marginalize="indirect")
    return marginalized
[docs]
def assemble_scaffolds(
    *graphs: SimpleGraph, nodes: list[str] | None = None
) -> nx.DiGraph:
    r"""
    Assemble multiple scaffold graphs into a heterogeneous one given a specific
    node list

    Parameters
    ----------
    *graphs
        Scaffold graphs to assemble. The graph-level attributes
        ``"marginalize_steps"`` (default 0), ``"data_source"`` and
        ``"evidence_type"`` (both default ``"unknown"``) of each graph are
        read, the latter two being copied onto every edge of that graph.
    nodes
        Node list (required, despite the ``None`` default)

    Returns
    -------
    Assembled scaffold graph, containing every node in ``nodes`` and no
    self-loops. With no input graph, the result has nodes but no edges.

    Raises
    ------
    TypeError
        If ``nodes`` is not specified
    """
    if nodes is None:
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not iterable" from inside ``marginalize``.
        raise TypeError("nodes must be specified")
    # Materialize once: ``nodes`` is consumed once per graph and again below,
    # which would silently exhaust a one-shot iterable.
    nodes = list(nodes)
    marginalized = [
        marginalize(
            graph.to_directed(),
            nodes,
            max_steps=graph.graph.get("marginalize_steps", 0),
        )
        for graph in graphs
    ]
    for graph in marginalized:
        # ``marginalize`` carries graph-level attributes over, so provenance
        # can be read from the marginalized graph and stamped on each edge.
        for key in ("data_source", "evidence_type"):
            nx.set_edge_attributes(graph, graph.graph.get(key, "unknown"), name=key)
    # ``compose_all`` rejects an empty list; an empty scaffold is still valid
    assembled = nx.compose_all(marginalized) if marginalized else nx.DiGraph()
    assembled.add_nodes_from(nodes)
    # Drop self-loops (removing a non-existent edge is silently ignored)
    assembled.remove_edges_from([(v, v) for v in assembled.nodes])
    return assembled
[docs]
def node_stats(graph: nx.DiGraph) -> pd.DataFrame:
    r"""
    Get node statistics of a graph

    Parameters
    ----------
    graph
        Graph

    Returns
    -------
    Node statistics, indexed by node, with columns ``in_degree``,
    ``out_degree``, ``n_ancestors``, ``n_descendants`` and ``topo_gen``
    (topological generation; NaN for every node if the graph is not a DAG).
    An empty graph gives an empty data frame with the same columns.
    """
    columns = [
        "node",
        "in_degree",
        "out_degree",
        "n_ancestors",
        "n_descendants",
        "topo_gen",
    ]
    # Topological generations are only defined for DAGs
    topo_gens = (
        {g: i for i, gen in enumerate(nx.topological_generations(graph)) for g in gen}
        if nx.is_directed_acyclic_graph(graph)
        else {}
    )
    rows = [
        {
            "node": n,
            "in_degree": graph.in_degree(n),
            "out_degree": graph.out_degree(n),
            "n_ancestors": len(nx.ancestors(graph, n)),
            "n_descendants": len(nx.descendants(graph, n)),
            "topo_gen": topo_gens.get(n, np.nan),
        }
        for n in graph.nodes
    ]
    # Passing the columns explicitly keeps ``set_index`` working when the
    # graph has no nodes (otherwise the frame has no "node" column at all).
    return pd.DataFrame(rows, columns=columns).set_index("node")
[docs]
def annotate_explanation(
    digraph: nx.DiGraph, ctfact: AnnData, causal_map: pd.DataFrame, cutoff: float = 0.1
) -> nx.DiGraph:
    r"""
    Annotate counterfactual explanation to the causal graph

    Parameters
    ----------
    digraph
        Causal graph (from :meth:`~cascade.model.CASCADE.export_causal_graph`)
    ctfact
        Dataset with counterfactual explanation (from
        :meth:`~cascade.model.CASCADE.explain`)
    causal_map
        Causal map (from :meth:`~cascade.model.CASCADE.export_causal_map`)
    cutoff
        Minimal cutoff of absolute total change for a gene to be annotated with
        contributions (small changes cannot be reliably annotated)

    Returns
    -------
    Annotated causal graph (a copy of ``digraph``) with:

    - graph attribute ``cutoff``
    - node attributes ``sign_tot`` and ``diff_tot`` for every variable in
      ``ctfact``
    - node attributes ``sign_i``, ``sign_s``, ``sign_z``, ``frac_i``,
      ``frac_s``, ``frac_z`` and ``frac_ptr`` for variables whose absolute
      total change exceeds ``cutoff``
    - edge attributes ``sign`` and ``frac`` for parent edges of those variables

    .. tip::
        It is strongly recommended to limit cells to a single perturbation for
        use as ``ctfact``.
    """
    # Average each layer over its first and last axes.
    # NOTE(review): presumably (cells, genes, ..., samples) so that the result
    # is per gene (and per gene x parent slot for "X_ctrb_ptr") - confirm.
    nil = ctfact.layers["X_nil"].mean(axis=(0, -1)).astype(float)
    ctrb_i = ctfact.layers["X_ctrb_i"].mean(axis=(0, -1)).astype(float)
    ctrb_s = ctfact.layers["X_ctrb_s"].mean(axis=(0, -1)).astype(float)
    ctrb_z = ctfact.layers["X_ctrb_z"].mean(axis=(0, -1)).astype(float)
    ctrb_ptr = ctfact.layers["X_ctrb_ptr"].mean(axis=(0, -1)).astype(float)
    tot = ctfact.layers["X_tot"].mean(axis=(0, -1)).astype(float)
    # Label everything by variable name for index-aligned arithmetic
    nil = pd.Series(nil, index=ctfact.var_names)
    ctrb_i = pd.Series(ctrb_i, index=ctfact.var_names)
    ctrb_s = pd.Series(ctrb_s, index=ctfact.var_names)
    ctrb_z = pd.Series(ctrb_z, index=ctfact.var_names)
    ctrb_ptr = pd.DataFrame(ctrb_ptr, index=ctfact.var_names)
    tot = pd.Series(tot, index=ctfact.var_names)
    # Changes relative to the "nil" baseline
    diff_tot = tot - nil
    diff_i = ctrb_i - nil
    diff_s = ctrb_s - nil
    diff_z = ctrb_z - nil
    diff_ptr = ctrb_ptr.sub(nil, axis="index")
    # Total change is annotated for every variable, regardless of the cutoff
    ann_tot = pd.DataFrame(
        {
            "sign_tot": np.sign(diff_tot),
            "diff_tot": diff_tot,
        }
    )
    mask = diff_tot.abs() > cutoff  # Only annotate contribution for large effects
    diff_tot = diff_tot.loc[mask]
    diff_i = diff_i.loc[mask]
    diff_s = diff_s.loc[mask]
    diff_z = diff_z.loc[mask]
    diff_ptr = diff_ptr.loc[mask]
    # NOTE(review): boolean-indexing with ``mask`` assumes ``causal_map`` is
    # indexed by the same variable names as ``ctfact.var_names`` - confirm.
    causal_map = causal_map.loc[mask]
    # Fractions of the total change, clipped into [0, 1]
    ann_sz = pd.DataFrame(
        {
            "sign_i": np.sign(diff_i),
            "sign_s": np.sign(diff_s),
            "sign_z": np.sign(diff_z),
            "frac_i": np.clip(diff_i / diff_tot, 0, 1),
            "frac_s": np.clip(diff_s / diff_tot, 0, 1),
            "frac_z": np.clip(diff_z / diff_tot, 0, 1),
        }
    )
    # Reshape per-parent quantities to long format (one row per variable and
    # column) so that they can be joined with the causal map.
    sign_ptr: pd.DataFrame = np.sign(diff_ptr)
    sign_ptr = sign_ptr.reset_index().melt(id_vars="index", value_name="sign")
    frac_ptr = np.clip(diff_ptr.div(diff_tot, axis="index"), 0, 1)
    frac_ptr = frac_ptr.reset_index().melt(id_vars="index", value_name="frac")
    causal_map = causal_map.reset_index().melt(id_vars="index", value_name="parent")
    # Successive merges on the shared columns; rows whose parent is the
    # '<pad>' placeholder are excluded first.
    # NOTE(review): relies on ``causal_map`` and the "X_ctrb_ptr" frame having
    # matching column labels so the melted "variable" columns agree - confirm.
    ann_ptr = reduce(
        pd.merge, [causal_map.query("parent != '<pad>'"), frac_ptr, sign_ptr]
    )
    # Total fraction explained by all parents of each variable
    ann_ptr_sum = ann_ptr.groupby("index")["frac"].sum().to_frame()
    digraph = digraph.copy()
    digraph.graph["cutoff"] = cutoff
    # Node lookups below require every annotated variable to be a graph node.
    # Only a sign of exactly -1 maps to "down"; zero and positive map to "up".
    for idx, row in ann_tot.iterrows():
        attr = digraph.nodes[idx]
        attr["sign_tot"] = "down" if row["sign_tot"] == -1 else "up"
        attr["diff_tot"] = row["diff_tot"]
    for idx, row in ann_sz.iterrows():
        attr = digraph.nodes[idx]
        attr["sign_i"] = "down" if row["sign_i"] == -1 else "up"
        attr["sign_s"] = "down" if row["sign_s"] == -1 else "up"
        attr["sign_z"] = "down" if row["sign_z"] == -1 else "up"
        attr["frac_i"] = row["frac_i"]
        attr["frac_s"] = row["frac_s"]
        attr["frac_z"] = row["frac_z"]
    for idx, row in ann_ptr_sum.iterrows():
        attr = digraph.nodes[idx]
        attr["frac_ptr"] = row["frac"]
    for _, row in ann_ptr.iterrows():
        u, v = row["parent"], row["index"]
        attr = digraph.edges[u, v]
        # The edge "represses" when the parent's contribution to the child and
        # the parent's own total change have opposite signs.
        attr["sign"] = (
            "represses"
            if row["sign"] * ann_tot.loc[u, "sign_tot"] == -1
            else "activates"
        )
        attr["frac"] = row["frac"]
    return digraph
[docs]
def core_explanation_graph(
    annotated: nx.DiGraph,
    leaves: list[str],
    min_frac_ptr: float = 0.05,
    depth_limit: int | None = None,
) -> nx.DiGraph:
    r"""
    Extract the core explanation graph from an annotated causal graph that
    explains the predicted change in a list of leaf nodes

    Parameters
    ----------
    annotated
        Annotated causal graph (from :func:`annotate_explanation`)
    leaves
        List of leaf nodes
    min_frac_ptr
        Minimal explained fraction for an edge to be considered
    depth_limit
        Depth limit for the breadth-first search

    Returns
    -------
    Core explanation graph. It is empty if none of the leaves survives the
    node and edge filtering.
    """
    cutoff = annotated.graph["cutoff"]
    # Keep only nodes whose absolute total change exceeds the annotation cutoff
    annotated = annotated.subgraph(
        u for u, d in annotated.nodes.items() if abs(d.get("diff_tot", 0)) > cutoff
    )
    annotated = filter_edges(annotated, edge_attr="frac", cutoff=min_frac_ptr)
    # Walk upstream (against edge direction) from every leaf still present
    bfs_trees = [
        nx.bfs_tree(annotated, leaf, reverse=True, depth_limit=depth_limit)
        for leaf in leaves
        if leaf in annotated.nodes
    ]
    # ``set().union`` tolerates an empty list of trees, whereas reducing an
    # empty sequence without an initial value raises TypeError.
    bfs_nodes = set().union(*(tree.nodes for tree in bfs_trees))
    return annotated.subgraph(bfs_nodes)
[docs]
def prep_cytoscape(
    annotated: nx.DiGraph, scaffold: nx.DiGraph, perts: list[str], leaves: list[str]
) -> nx.DiGraph:
    r"""
    Produce a copy of an annotated graph that is ready to be loaded into
    `Cytoscape <https://cytoscape.org>`_

    Parameters
    ----------
    annotated
        Annotated causal graph (from :func:`annotate_explanation`)
    scaffold
        Scaffold graph (from :meth:`~cascade.model.CASCADE.export_causal_graph`);
        must contain every edge of ``annotated``
    perts
        List of perturbed nodes
    leaves
        List of leaf nodes to explain

    Returns
    -------
    Cytoscape-ready graph, in which every node has ``frac_i``, ``frac_s``,
    ``frac_z``, ``frac_ptr`` (0 when missing) and ``role`` (``"interv"``,
    ``"leaf"`` or ``"med"``), and every edge has ``frac`` (clipped to
    [0, 1]) and ``evidence_type`` (``"unknown"`` when the scaffold edge has
    none)

    .. tip::
        Please visit TODO to obtain a template Cytoscape file containing
        corresponding styles.
    """
    prepared = annotated.copy()
    for node, data in prepared.nodes.items():
        # Fill in contribution fractions absent from the annotation
        for key in ("frac_i", "frac_s", "frac_z", "frac_ptr"):
            data[key] = data.get(key, 0)
        # Being perturbed takes precedence over being a leaf
        data["role"] = (
            "interv" if node in perts else "leaf" if node in leaves else "med"
        )
    for edge, data in prepared.edges.items():
        scaffold_data = scaffold.edges[edge]
        data["frac"] = np.clip(data.get("frac", 0), 0, 1)
        data["evidence_type"] = scaffold_data.get("evidence_type", "unknown")
    return prepared