# Source code for cascade.sim

r"""
Simulation functions
"""

from collections import Counter
from collections.abc import Callable, Mapping, Sequence
from enum import IntEnum

import anndata as ad
import networkx as nx
import numpy as np
import pandas as pd
from anndata import AnnData
from loguru import logger
from numpy.typing import ArrayLike
from scipy.sparse import csr_matrix, eye
from tqdm.auto import tqdm

from .data import Targets
from .typing import RandomState
from .utils import get_random_state

# Intervention scale specification: either a constant scaling factor, or a
# sampler called as ``sampler(n_samples, random_state)`` whose result is
# assigned to the target's column (one factor per sample).
Scale = float | Callable[[int, RandomState], Sequence[float]]


# Registry of activation functions, looked up by the name stored in the
# per-variable ``"act"`` annotation.
ACTIVATION = dict(
    ident=lambda x: x,
    tanh=np.tanh,
)


class DAGType(IntEnum):
    r"""
    Families of random directed acyclic graphs that can be generated

    Attributes
    ----------
    unif
        Uniform
    sf
        Scale-free
    """

    unif = 0
    sf = 1
def _simulator_core(adata: AnnData) -> AnnData:
    r"""
    Fill ``adata.X`` in place by propagating exogenous noise through the
    causal graph, one topological generation at a time

    Parameters
    ----------
    adata
        Dataset carrying the simulation inputs: ``varp["biadj"]``
        (adjacency matrix of the causal graph), ``var["snr"]`` and
        ``var["act"]`` (per-variable signal-to-noise ratio and activation
        name), ``layers["exo"]`` (exogenous noise) and ``layers["scale"]``
        (intervention scale). When ``varm["norm_in"]`` or
        ``varm["norm_out"]`` is already present, it is reused as
        normalization statistics instead of being re-estimated.

    Returns
    -------
    The same ``adata`` object, with ``X``, ``varm["norm_in"]`` and
    ``varm["norm_out"]`` written
    """
    weights = adata.varp["biadj"]
    keep_in = "norm_in" in adata.varm
    stats_in = adata.varm["norm_in"].T if keep_in else np.empty((2, adata.n_vars))
    keep_out = "norm_out" in adata.varm
    stats_out = adata.varm["norm_out"].T if keep_out else np.empty((2, adata.n_vars))
    snr = adata.var["snr"].to_numpy()
    act = adata.var["act"].to_numpy()
    noise = adata.layers["exo"]
    scale = adata.layers["scale"]

    graph = nx.from_scipy_sparse_array(weights, create_using=nx.DiGraph)
    generations = list(nx.topological_generations(graph))

    def _standardize(values, stats, cols, frozen):
        # Row 0 of ``stats`` holds the means, row 1 the standard deviations;
        # they are (re-)estimated from ``values`` unless ``frozen``.
        if not frozen:
            stats[:, cols] = values.mean(axis=0), values.std(axis=0)
        center, spread = stats[:, cols]
        spread[spread == 0] = 1  # Constant columns are centered but not rescaled
        values -= center
        values /= spread
        return values

    upstream = []
    for layer in tqdm(generations, leave=False):
        is_root = not upstream
        if is_root:
            signal = noise[:, layer]
        else:
            signal = adata.X[:, upstream] @ weights[upstream, :][:, layer]
        signal = _standardize(signal, stats_in, layer, keep_in)
        for col, var_idx in enumerate(layer):
            signal[:, col] = ACTIVATION[act[var_idx]](signal[:, col])
        signal = _standardize(signal, stats_out, layer, keep_out)
        if is_root:
            simulated = signal * np.sqrt(snr[layer] ** 2 + 1)  # Make same variance
        else:
            simulated = signal * snr[layer] + noise[:, layer]
        adata.X[:, layer] = scale[:, layer] * simulated
        upstream += layer

    adata.varm["norm_in"] = stats_in.T
    adata.varm["norm_out"] = stats_out.T
    return adata


# ----------------------------- Public functions -------------------------------
def generate_dag(
    n: int,
    m: int | float,
    type: DAGType = DAGType.unif,
    random_state: RandomState = None,
) -> nx.DiGraph:
    r"""
    Randomly generate a directed acyclic graph

    Parameters
    ----------
    n
        Number of nodes
    m
        Target in-degree
    type
        Type of DAG to generate, see :class:`DAGType`
    random_state
        Random state

    Returns
    -------
    A directed acyclic graph

    .. note::

        - Integer ``m`` (Python or NumPy integer) is interpreted as a fixed
          in-degree
        - Floating ``m`` is interpreted as a fraction of upstream nodes
        - Nodes are named ``v0``, ``v1``, ..., ``v{n-1}``
        - Each edge carries a random weight of ``+1`` or ``-1``
    """
    rnd = get_random_state(random_state)
    # NumPy integers are not instances of ``int``; without this they would be
    # mistaken for a fraction and produce ``round(m * j)`` parents.
    fixed_degree = isinstance(m, (int, np.integer))
    # The diagonal is set temporarily so that every node has a non-zero
    # sampling weight in the scale-free case; it is removed after the loop.
    adj = eye(n, dtype=bool, format="lil")
    for j in range(n):
        if j and type is DAGType.sf:
            # Preferential attachment: weight by current row sums of upstream nodes
            prev_degs = adj[:j].sum(axis=1).A1
            p = prev_degs / prev_degs.sum()
        else:
            p = None
        m_ = int(m) if fixed_degree else round(m * j)
        # NOTE(review): ``choice`` is left at its default (sampling with
        # replacement for numpy random states), so duplicate draws can make
        # the realized in-degree smaller than ``m_`` -- confirm this is intended.
        i = rnd.choice(j, size=min(m_, j), p=p)
        adj[i, j] = True
    adj.setdiag(False)
    adj = adj.tocsr()
    adj.data = adj.data * np.sign(rnd.randn(adj.nnz))  # Random edge signs
    dag = nx.from_scipy_sparse_array(adj, create_using=nx.DiGraph)
    return nx.relabel_nodes(dag, {i: f"v{i}" for i in dag.nodes}, copy=False)
def simulate_regimes(
    dag: nx.DiGraph,
    design: Mapping[Targets, int],
    interv: Mapping[str, Scale],
    random_state: RandomState = None,
) -> AnnData:
    r"""
    Simulate interventional data based on a causal structure with multiple
    sets of intervention effect in parallel

    Parameters
    ----------
    dag
        A directed acyclic graph representing the causal structure
    design
        A mapping from intervention targets to sample numbers
    interv
        Intervention scaling factor :math:`\lambda` of each intervention
        target or sampler function of such (:math:`\lambda = 0` for knockout,
        :math:`0 \lt \lambda \lt 1` for knockdown, :math:`\lambda \gt 1` for
        knockup)
    random_state
        Random state

    Returns
    -------
    Simulated dataset

    Raises
    ------
    ValueError
        If ``dag`` is not a directed acyclic graph, or if ``design`` requests
        no samples at all

    .. note::

        - The signal-to-noise ratio for each simulated variable should be
          provided as a node attribute in ``dag`` called ``"snr"``.
        - The activation function for each simulated variable should be
          provided as a node attribute in ``dag`` called ``"act"``.
        - When the design contains unperturbed samples, normalization
          statistics are estimated on them and reused for the perturbed
          samples.
    """
    if not nx.is_directed_acyclic_graph(dag):
        raise ValueError("Causal structure is not a directed acyclic graph.")
    # Fail early with a clear message instead of an opaque concatenation error
    if not any(design.values()):
        raise ValueError("Experimental design contains no samples.")
    rnd = get_random_state(random_state)
    nodes = pd.Index(dag.nodes)
    snr = nx.get_node_attributes(dag, "snr")
    act = nx.get_node_attributes(dag, "act")
    biadj = csr_matrix(
        nx.bipartite.biadjacency_matrix(dag, nodes, nodes)
    )  # anndata does not fully support csr_array yet

    # One block of rows per regime; untouched variables keep a scale of 1
    scale_blocks = []
    for targets, num in design.items():
        block = np.ones((num, len(nodes)))
        for target in targets:
            factor = interv[target]
            block[:, nodes.get_loc(target)] = (
                factor(num, rnd) if callable(factor) else factor
            )
        scale_blocks.append(block)
    scale = np.concatenate(scale_blocks)
    exo = rnd.normal(size=scale.shape)

    obs = pd.DataFrame(
        {
            "knockout": [",".join(nodes[row]) for row in scale == 0],
            "knockdown": [",".join(nodes[row]) for row in (scale > 0) & (scale < 1)],
            "knockup": [",".join(nodes[row]) for row in scale > 1],
        },
        index=pd.RangeIndex(scale.shape[0]).astype(str),
    )
    var = pd.DataFrame(
        {
            "snr": [snr[node] for node in nodes],
            "act": [act[node] for node in nodes],
        },
        index=nodes,
    )

    observ_mask = np.all(scale == 1, axis=1)
    interv_mask = ~observ_mask
    parts = []
    if observ_mask.any():
        observ_data = AnnData(
            X=np.empty((observ_mask.sum(), scale.shape[1])),
            obs=obs.loc[observ_mask],
            var=var,
            varp={"biadj": biadj},
            layers={"scale": scale[observ_mask], "exo": exo[observ_mask]},
        )
        _simulator_core(observ_data)
        parts.append(observ_data)
        logger.info("Variables will be normalized by observational samples.")
        varm = {
            "norm_in": observ_data.varm["norm_in"],
            "norm_out": observ_data.varm["norm_out"],
        }
    else:
        logger.warning("Variables will be normalized by interventional samples.")
        varm = {}
    if interv_mask.any():
        # Named distinctly so the ``interv`` parameter is not shadowed
        interv_data = AnnData(
            X=np.empty((interv_mask.sum(), scale.shape[1])),
            obs=obs.loc[interv_mask],
            var=var,
            varm=varm,
            varp={"biadj": biadj},
            layers={"scale": scale[interv_mask], "exo": exo[interv_mask]},
        )
        _simulator_core(interv_data)
        parts.append(interv_data)
    return ad.concat(parts, merge="same")
def simulate_random_regimes(
    dag: nx.DiGraph,
    n_obs: int,
    rate: float,
    interv: Mapping[str, Scale],
    random_state: RandomState = None,
) -> AnnData:
    r"""
    Simulate interventional data based on a causal structure with random
    interventions

    Parameters
    ----------
    dag
        A directed acyclic graph representing the causal structure
    n_obs
        Number of samples
    rate
        Interventional rate per node
    interv
        Intervention scaling factor :math:`\lambda` of each intervention
        target or sampler function of such (:math:`\lambda = 0` for knockout,
        :math:`0 \lt \lambda \lt 1` for knockdown, :math:`\lambda \gt 1` for
        knockup)
    random_state
        Random state

    Returns
    -------
    Simulated dataset
    """
    rnd = get_random_state(random_state)
    draws_per_node = round(n_obs * rate)
    # targets_by_sample[k] collects the nodes perturbed in sample k
    targets_by_sample = [[] for _ in range(n_obs)]
    for node in dag.nodes:
        for sample_idx in rnd.choice(n_obs, draws_per_node):
            targets_by_sample[sample_idx].append(node)
    # Collapse identical target sets into regimes with their sample counts
    regimes = Counter(Targets(members) for members in targets_by_sample)
    return simulate_regimes(dag, regimes, interv, random_state=rnd)
def simulate_counterfactual(adata: AnnData, scale: ArrayLike) -> AnnData:
    r"""
    Simulate counterfactual outcome of alternative interventions based on an
    existing simulated dataset

    Parameters
    ----------
    adata
        An existing simulated dataset
    scale
        Counterfactual interventional scale matrix (same shape as ``adata``)

    Returns
    -------
    Counterfactual dataset

    Raises
    ------
    ValueError
        If ``scale`` does not have the same shape as ``adata``

    .. note::

        The input is copied, so its exogenous noise layer and any
        normalization statistics already stored in ``adata.varm`` are reused
        when the simulation is re-run with the new scale.
    """
    scale = np.asarray(scale)
    if scale.shape != adata.shape:
        raise ValueError(
            f"Scale matrix has shape {scale.shape}, expected {adata.shape}."
        )
    adata = adata.copy()
    adata.layers["scale"] = scale
    nodes = adata.var_names
    adata.obs["knockout"] = [",".join(nodes[row]) for row in scale == 0]
    adata.obs["knockdown"] = [",".join(nodes[row]) for row in (scale > 0) & (scale < 1)]
    # Knock-up means a scale strictly above 1; ``> 0`` would also label
    # unperturbed (scale 1) and knocked-down entries.
    adata.obs["knockup"] = [",".join(nodes[row]) for row in scale > 1]
    return _simulator_core(adata)