Source code for cascade.model

r"""
API entrypoint of the CASCADE model
"""

import os
import re
import shutil
import sys
from itertools import combinations
from logging import WARNING, getLogger
from pathlib import Path
from warnings import filterwarnings

import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn.functional as F
from anndata import AnnData
from loguru import logger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from rich.panel import Panel
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors

from . import __version__, name
from .core import (
    CausalNetwork,
    DiscoverScheduler,
    FitStage,
    LogAdj,
    ModelCheckpoint,
    PredictionWriter,
    PredictMode,
)
from .data import (
    DataModule,
    DynamicPairedDataModule,
    PairedDataModule,
    SimpleDataModule,
    _get_covariate,
    _get_regime,
    _get_size,
    _get_X,
    _set_covariate,
    _set_regime,
    configure_dataset,
    encode_regime,
)
from .nn import IntervDesign
from .typing import Kws, RandomState, SimpleGraph
from .utils import (
    autodevice,
    config,
    console,
    densify,
    get_random_state,
    gp_regression_with_ci,
    internal,
)

filterwarnings("ignore", ".*does not have many workers.*")
getLogger("pytorch_lightning.accelerators.cuda").setLevel(WARNING)
getLogger("pytorch_lightning.utilities.rank_zero").setLevel(WARNING)


[docs] class CASCADE: r""" **C**\ ausality-**A**\ ware **S**\ ingle-**C**\ ell **A**\ daptive **D**\ iscover/**D**\ eduction/**D**\ esign Engine Parameters ---------- vars List of variables to model n_particles Number of SVGD particles n_covariates Dimension of covariates n_layers Number of MLP layers in the structural equations hidden_dim MLP hidden layer dimension in the structural equations latent_dim Dimension of the latent variable (see notes below on how to specify ``latent_data`` depending on ``latent_mod``) dropout Dropout rate beta KL weight of the latent variable scaffold_mod Scaffold graph module, must be one of {"Edgewise", "Bilinear"} sparse_mod Sparse prior module, must be one of {"L1", "ScaleFree"} acyc_mod Acyclic prior module, must be one of {"TrExp", "SpecNorm", "LogDet"} latent_mod Latent module, must be one of {"NilLatent", "EmbLatent", "GCNLatent"} lik_mod Causal likelihood module, must be one of {"Normal", "NegBin"} kernel_mod SVGD kernel module, must be one of {"KroneckerDelta", "RBF"} scaffold_graph Optional scaffold graph latent_data Optional latent data (see notes below on how to specify ``latent_data`` depending on ``latent_mod``) scaffold_kws Keyword arguments to the scaffold graph module, see :class:`~cascade.nn.Edgewise` or :class:`~cascade.nn.Bilinear` for details sparse_kws Keyword arguments to the sparse prior module, see :class:`~cascade.nn.L1` or :class:`~cascade.nn.ScaleFree` for details acyc_kws Keyword arguments to the acyclic prior module, see :class:`~cascade.nn.TrExp`, :class:`~cascade.nn.SpecNorm`, or :class:`~cascade.nn.LogDet` for details latent_kws Keyword arguments to the latent module, see :class:`~cascade.nn.NilLatent`, :class:`~cascade.nn.EmbLatent`, or :class:`~cascade.nn.GCNLatent` for details lik_kws Keyword arguments to the causal likelihood module, see :class:`~cascade.nn.Normal` or :class:`~cascade.nn.NegBin` for details kernel_kws Keyword arguments to the SVGD kernel module, see :class:`~cascade.nn.KroneckerDelta` or :class:`~cascade.nn.RBF` for details random_state Random state log_dir Directory to store tensorboard logs _net **Internal use ONLY** .. note:: The setting for ``latent_dim`` and ``latent_data`` follows rules below: - When ``latent_mod="NilLatent"``, ``latent_data`` must be ``None``. The latent variable is always the standard normal distribution with dimension of ``latent_dim``. - When ``latent_mod="EmbLatent"``, ``latent_data`` must be a :class:`~pandas.DataFrame`, where the index is the variable names and the columns are the embedding dimensions. ``latent_dim`` but must be larger than 0, but does not need to equal the dimension of ``latent_data``, as the latent variable is encoded from the provided embedding with a linear transformation. - When ``latent_mod="GCNLatent"``, ``latent_data`` must be a :class:`~networkx.Graph` or :class:`~networkx.DiGraph`, where the nodes are the variable names and the edges are latent connections. ``latent_dim`` must be larger than 0. The latent variable is encoded from the provided graph with a graph convolutional network. """ def __init__( self, vars: pd.Index | list[str], n_particles: int = 4, n_covariates: int = 0, n_layers: int = 1, hidden_dim: int = 16, latent_dim: int = 16, dropout: float = 0.2, beta: float = 0.1, scaffold_mod: str = "Edgewise", sparse_mod: str = "L1", acyc_mod: str = "SpecNorm", latent_mod: str = "EmbLatent", lik_mod: str = "NegBin", kernel_mod: str = "RBF", scaffold_graph: SimpleGraph | None = None, latent_data: pd.DataFrame | SimpleGraph | None = None, scaffold_kws: Kws = None, sparse_kws: Kws = None, acyc_kws: Kws = None, latent_kws: Kws = None, lik_kws: Kws = None, kernel_kws: Kws = None, random_state: RandomState = 0, log_dir: os.PathLike = ".", _net: CausalNetwork | None = None, ) -> None: self.vars = pd.Index(vars) self.rnd = get_random_state(random_state) self.log_dir = Path(log_dir) self.interv_seen = set() if _net is not None: self.net = _net return scaffold_kws = scaffold_kws or {} if scaffold_graph is None: scaffold_graph = nx.complete_graph(self.vars, create_using=nx.DiGraph) else: scaffold_graph = scaffold_graph.subgraph(self.vars) if not nx.is_directed(scaffold_graph): scaffold_graph = scaffold_graph.to_directed() edgelist = nx.to_pandas_edgelist(scaffold_graph) scaffold_kws["eidx"] = torch.as_tensor( np.stack( [ self.vars.get_indexer(edgelist["source"]), self.vars.get_indexer(edgelist["target"]), ] ) ) latent_kws = latent_kws or {} latent_vars = pd.Index([]) if latent_dim: if latent_mod == "EmbLatent": if not isinstance(latent_data, pd.DataFrame): raise ValueError( f"Latent embedding must be provided for {latent_mod}" ) latent_data = latent_data.reindex(self.vars).dropna() latent_vars = latent_data.index latent_kws["emb"] = torch.as_tensor(latent_data.to_numpy()) elif latent_mod == "GCNLatent": if not isinstance(latent_data, nx.Graph): raise ValueError(f"Latent graph must be provided for {latent_mod}") latent_data = latent_data.subgraph( v for v, deg in latent_data.degree() if deg > 0 ) if not nx.is_directed(latent_data): latent_data = latent_data.to_directed() latent_vars = pd.Index(latent_data.nodes) edgelist = nx.to_pandas_edgelist(latent_data) latent_kws["eidx"] = torch.as_tensor( np.stack( [ latent_vars.get_indexer(edgelist["source"]), latent_vars.get_indexer(edgelist["target"]), ] ) ) latent_kws["ewt"] = ( torch.as_tensor(edgelist["weight"], dtype=torch.get_default_dtype()) if latent_kws["eidx"].size(1) else torch.zeros(0) ) elif latent_mod != "NilLatent": raise ValueError(f"Latent dimension must be non-zero for {latent_mod}") if latent_mod == "NilLatent" and latent_data is not None: raise ValueError("Latent data not accepted for NilLatent") common_vars = self.vars.intersection(latent_vars) latent_kws["vmap"] = torch.as_tensor( np.stack( [ self.vars.get_indexer(common_vars), latent_vars.get_indexer(common_vars), ] ) ) self.manual_seed() self.net = CausalNetwork( n_vars=self.vars.size, n_particles=n_particles, n_covariates=n_covariates, n_layers=n_layers, hidden_dim=hidden_dim, latent_dim=latent_dim, dropout=dropout, beta=beta, scaffold_mod=scaffold_mod, sparse_mod=sparse_mod, acyc_mod=acyc_mod, latent_mod=latent_mod, lik_mod=lik_mod, kernel_mod=kernel_mod, scaffold_kws=scaffold_kws, sparse_kws=sparse_kws, acyc_kws=acyc_kws, latent_kws=latent_kws, lik_kws=lik_kws, kernel_kws=kernel_kws, ) @internal def manual_seed(self) -> None: torch.manual_seed(self.rnd.randint(0, 2**64 - 1, dtype=np.uint64)) @internal def align_vars(self, input: AnnData) -> AnnData: input_vars = input.var_names excess_vars = set(input_vars) - set(self.vars) if excess_vars: logger.warning( f"{len(excess_vars)} variables are not in the " f"`scaffold` and will thus be ignored." ) return input[:, self.vars]
[docs] def export_causal_graph(self, edge_attr: str = "weight") -> nx.DiGraph: r""" Export learned causal graph Parameters ---------- edge_attr Edge attribute name to store edge probabilities Returns ------- Learned causal graph """ digraph = self.net.scaffold.export_graph(edge_attr=edge_attr) return nx.relabel_nodes(digraph, dict(enumerate(self.vars)), copy=False)
[docs] def import_causal_graph( self, digraph: nx.DiGraph, edge_attr: str = "weight" ) -> None: r""" Import pruned causal graph Parameters ---------- digraph Pruned causal graph edge_attr Edge attribute name to read edge probabilities """ digraph = nx.relabel_nodes( digraph, {v: i for i, v in enumerate(self.vars)}, copy=True ) self.net.scaffold.import_graph(digraph, edge_attr=edge_attr)
[docs] def export_causal_map(self) -> pd.DataFrame: r""" Export the reshaped causal map indicating which input gene is in each reshaped position for each output gene, useful for interpreting the result of :meth:`~CASCADE.explain`. Returns ------- Causal map of shape (n_vars, max_indegree) .. note:: Padding positions are labeled as "<pad>" """ return pd.DataFrame( np.append(self.vars.to_numpy(), "<pad>")[ self.net.scaffold.mask_map.numpy(force=True) ], index=self.vars, )
@internal @rank_zero_only def report_banner(self, datamodule: DataModule) -> None: console.print( Panel( f"Training on [hl]{self.vars.size}[/hl] variables " f"with [hl]{self.net.scaffold.n_edges}[/hl] scaffold edges " f"and [hl]{len(datamodule)}[/hl] samples", expand=False, padding=(1, 2), title=name, subtitle=f"v{__version__}", ) ) def _fit( self, datamodule: DataModule, fit_stage: FitStage, accelerator: str, devices: list[int] | str, log_subdir: os.PathLike, opt: str, lr: float, weight_decay: float, accumulate_grad_batches: int, log_adj: LogAdj, **kwargs, ) -> None: self.report_banner(datamodule) tensorboard_logger = TensorBoardLogger( save_dir=self.log_dir / log_subdir, default_hp_metric=False ) trainer = Trainer( accelerator=accelerator, devices=devices, precision=config.PRECISION, logger=tensorboard_logger, log_every_n_steps=config.LOG_STEP_INTERVAL, deterministic=config.DETERMINISTIC, default_root_dir=self.log_dir / log_subdir, **kwargs, ) self.net.opt = opt self.net.lr = lr self.net.weight_decay = weight_decay self.net.accumulate_grad_batches = accumulate_grad_batches self.net.fit_stage = fit_stage self.net.log_adj = log_adj trainer.fit(self.net, datamodule=datamodule) if isinstance(devices, list) and len(devices) > 1: trainer.strategy.barrier() dist.destroy_process_group() if trainer.global_rank > 0: logger.debug("Exiting rank-non-zero process.") sys.exit() logger.debug("Continuing rank-zero process.") def _predict( self, datamodule: DataModule, predict_mode: PredictMode, accelerator: str, devices: list[int] | str, **kwargs, ) -> list[torch.Tensor]: if isinstance(devices, list) and len(devices) > 1: pred_dir = self.log_dir / f"pred-{config.RUN_ID}" callbacks = [PredictionWriter(output_dir=pred_dir)] else: callbacks = None trainer = Trainer( accelerator=accelerator, devices=devices, deterministic=config.DETERMINISTIC, precision=config.PRECISION, callbacks=callbacks, logger=False, enable_checkpointing=False, enable_model_summary=False, **kwargs, ) self.net.predict_mode = predict_mode pred = trainer.predict(self.net, datamodule=datamodule) if isinstance(devices, list) and len(devices) > 1: trainer.strategy.barrier() dist.destroy_process_group() if trainer.global_rank > 0: logger.debug("Exiting rank-non-zero process.") sys.exit() logger.debug("Continuing rank-zero process.") pred = { int(re.search(r"pred(\d+)\.pt", item.name).group(1)): torch.load( item, weights_only=True ) for item in pred_dir.glob("pred*.pt") } ind = { int(re.search(r"ind(\d+)\.pt", item.name).group(1)): torch.load( item, weights_only=True )[0] for item in pred_dir.glob("ind*.pt") } pred = [item for k in sorted(ind) for item in pred[k]] ind = [torch.as_tensor(item) for k in sorted(ind) for item in ind[k]] pred = tuple(torch.cat(items, dim=1) for items in zip(*pred)) argsort = torch.cat(ind).argsort(stable=True) pred = tuple(item[:, argsort] for item in pred) shutil.rmtree(pred_dir) else: pred = tuple(torch.cat(items, dim=1) for items in zip(*pred)) return pred @rank_zero_only def _load_checkpoint(self, path: str) -> None: if os.path.exists(path): console.print(f"Restoring best model: {path}.") self.net = type(self.net).load_from_checkpoint( path, map_location="cpu", design=self.net.design ) else: logger.warning( "No best checkpoint found! Exiting as is." ) # pragma: no cover @rank_zero_only def _update_interv_seen(self, adata: AnnData) -> None: regime = _get_regime(adata) regime_count = np.asarray(regime.sum(axis=0)).ravel() for var, count in zip(self.vars, regime_count): if count: self.interv_seen.add(var) @rank_zero_only def _extrapolate_interv(self) -> None: unseen = set(self.vars) - self.interv_seen if not unseen: logger.info("Skipping extrapolation because all variables intervened.") return if not self.interv_seen: logger.warning("Skipping extrapolation because no variable intervened.") return logger.info( f"Extrapolating scale and bias of {len(unseen)} non-intervened variables " f"from {len(self.interv_seen)} intervened variables." ) seen_mask = self.vars.isin(self.interv_seen) unseen_mask = self.vars.isin(unseen) for param in (self.net.interv_scale, self.net.interv_bias): extrapolate = param.data[:, seen_mask].quantile(0.5, dim=-1, keepdim=True) param.data[:, unseen_mask] = extrapolate.expand(-1, len(unseen))
[docs] def discover( self, adata: AnnData, lam: float = 0.1, alpha: float = 0.5, gamma: float = 1.0, cyc_tol: float = 1e-4, prefit: bool = False, opt: str = "AdamW", lr: float = 5e-3, weight_decay: float = 0.01, accumulate_grad_batches: int = 1, log_adj: LogAdj = LogAdj.mean, batch_size: int = 128, val_check_interval: int = 300, val_frac: float = 0.1, max_epochs: int = 1000, n_devices: int = 1, log_subdir: os.PathLike = "discover", verbose: bool = False, **kwargs, ) -> None: r""" Causal discovery Parameters ---------- adata Input dataset lam Sparse penalty rate (:math:`\eta_\lambda` in paper) alpha Acyclic penalty rate (:math:`\eta_\alpha` in paper) gamma Kernel gradient rate (:math:`\eta_\gamma`) cyc_tol Acyclic violation tolerance prefit Whether to prefit the model on covariates opt Optimizer lr Learning rate weight_decay Weight decay accumulate_grad_batches Number of batches to accumulate before optimizer step log_adj Adjacency matrix logging mode (see :class:`~cascade.core.LogAdj`) batch_size Batch size val_check_interval Validation check interval val_frac Validation fraction max_epochs Maximum number of epochs n_devices Number of GPU devices to use log_subdir Tensorboard log subdirectory (under model-wise ``log_dir``) verbose Whether to print verbose logs **kwargs Additional keyword arguments are passed to :class:`~lightning.pytorch.trainer.trainer.Trainer` """ adata = self.align_vars(adata) self.net.reset_parameters() self.net.prefit = prefit accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = SimpleDataModule( adata=adata, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=val_frac, random_state=self.rnd, ) progress_bar = TQDMProgressBar(refresh_rate=config.PBAR_REFRESH) model_checkpoint = ModelCheckpoint( monitor="val/post_enrg", mode="min", save_top_k=config.CKPT_SAVE_K, verbose=verbose, ) discover_scheduler = DiscoverScheduler( monitor="val/post_enrg", constraint="val/acyc_enrg", tolerance=cyc_tol, lam=lam, alpha=alpha, gamma=gamma, mode="min", min_delta=config.MIN_DELTA, patience=config.PATIENCE, verbose=verbose, ) self.net.lik.set_empirical(adata) self._fit( datamodule=datamodule, fit_stage=FitStage.discover, accelerator=accelerator, devices=granted, log_subdir=log_subdir, opt=opt, lr=lr, weight_decay=weight_decay, accumulate_grad_batches=accumulate_grad_batches, log_adj=log_adj, check_val_every_n_epoch=None, val_check_interval=val_check_interval, max_epochs=max_epochs, callbacks=[progress_bar, discover_scheduler, model_checkpoint], **kwargs, ) self._load_checkpoint(model_checkpoint.best_model_path) self._update_interv_seen(adata) self._extrapolate_interv()
[docs] def tune( self, adata: AnnData, tune_ctfact: bool = False, stratify: str | None = None, opt: str = "AdamW", lr: float = 5e-3, weight_decay: float = 0.01, accumulate_grad_batches: int = 1, log_adj: LogAdj = LogAdj.mean, batch_size: int = 128, val_check_interval: int = 300, val_frac: float = 0.1, max_epochs: int = 1000, n_devices: int = 1, log_subdir: os.PathLike = "tune", verbose: bool = False, **kwargs, ) -> CausalNetwork: r""" Fine-tune structural equations with fixed causal structure Parameters ---------- adata Input dataset tune_ctfact Whether to tune in counterfactual mode, i.e., to use randomly paired samples for counterfactual pairs for tuning. stratify Column name in :attr:`~anndata.AnnData.obs` for stratified random pairing (only relevant when using ``tune_ctfact=True``) opt Optimizer lr Learning rate weight_decay Weight decay accumulate_grad_batches Number of batches to accumulate before optimizer step log_adj Adjacency matrix logging mode (see :class:`~cascade.core.LogAdj`) batch_size Batch size val_check_interval Validation check interval val_frac Validation fraction max_epochs Maximum number of epochs n_devices Number of GPU devices to use log_subdir Tensorboard log subdirectory (under model-wise ``log_dir``) verbose Whether to print verbose logs **kwargs Additional keyword arguments are passed to :class:`~lightning.pytorch.trainer.trainer.Trainer` """ adata = self.align_vars(adata) if not self.net.scaffold.frozen: raise RuntimeError( "Scaffold is not frozen! " "Did you forget to import an acyclified graph?" ) logger.info("Pruning model...") self.net.prune() accelerator, granted = autodevice(n_devices) self.manual_seed() if tune_ctfact: datamodule = DynamicPairedDataModule( pri=adata, sec=adata, stratify=stratify, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=val_frac, random_state=self.rnd, ) else: datamodule = SimpleDataModule( adata=adata, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=val_frac, random_state=self.rnd, ) progress_bar = TQDMProgressBar(refresh_rate=config.PBAR_REFRESH) model_checkpoint = ModelCheckpoint( monitor="val/lik_enrg", mode="min", save_top_k=config.CKPT_SAVE_K, verbose=verbose, ) earlystopping = EarlyStopping( monitor="val/lik_enrg", mode="min", min_delta=config.MIN_DELTA, patience=config.PATIENCE, verbose=verbose, ) self._fit( datamodule=datamodule, fit_stage=FitStage.ctfact if tune_ctfact else FitStage.tune, accelerator=accelerator, devices=granted, log_subdir=log_subdir, opt=opt, lr=lr, weight_decay=weight_decay, accumulate_grad_batches=accumulate_grad_batches, log_adj=log_adj, check_val_every_n_epoch=None, val_check_interval=val_check_interval, max_epochs=max_epochs, callbacks=[progress_bar, earlystopping, model_checkpoint], **kwargs, ) self._load_checkpoint(model_checkpoint.best_model_path) self._update_interv_seen(adata) self._extrapolate_interv()
[docs] def design( self, source: AnnData, target: AnnData, pool: list[str] | None = None, init: list[str] | None = None, design_size: int = 1, design_scale_bias: bool = False, target_weight: str | None = None, stratify: str | None = None, opt: str = "AdamW", lr: float = 5e-2, weight_decay: float = 0.01, accumulate_grad_batches: int = 1, batch_size: int = 32, val_check_interval: int = 300, val_frac: float = 0.1, max_epochs: int = 1000, n_devices: int = 1, log_subdir: os.PathLike = "design", verbose: bool = False, **kwargs, ) -> tuple[pd.DataFrame, IntervDesign]: r""" Targeted intervention design with continuous optimization Parameters ---------- source Source dataset target Target dataset representing desired outcome pool Optional list of variables as candidate pool init Optional list of variables to initialize the designed interventions design_size Maximal combinatorial order to consider design_scale_bias Whether to optimize the intervention scale and bias target_weight Optional column name in ``target.var`` to weight target variables when computing target deviation stratify Column name in :attr:`~anndata.AnnData.obs` for stratified random pairing opt Optimizer lr Learning rate weight_decay Weight decay accumulate_grad_batches Number of batches to accumulate before optimizer step batch_size Batch size val_check_interval Validation check interval val_frac Validation fraction max_epochs Maximum number of epochs n_devices Number of GPU devices to use log_subdir Tensorboard log subdirectory (under model-wise ``log_dir``) verbose Whether to print verbose logs **kwargs Additional keyword arguments are passed to :class:`~lightning.pytorch.trainer.trainer.Trainer` Returns ------- DataFrame of design scores containing the following column: - "score": Design score Indexed by intervention and sorted by descending scores Intervention design module """ source = self.align_vars(source) target = self.align_vars(target) mask = torch.as_tensor(self.vars.isin(pool or self.vars)) target_weight = ( torch.as_tensor(target.var[target_weight].to_numpy()) if target_weight else torch.ones(target.n_vars) ) self.net.set_design( mask=mask, k=design_size, design_scale_bias=design_scale_bias, target_weight=target_weight, ) init = torch.as_tensor(self.vars.get_indexer(init or [])) if (init < 0).any(): raise ValueError("Invalid init variables") self.net.design.logits.data[ torch.isin(self.net.design.comb, init).any(dim=1) ] = 10.0 self.net.vars = self.vars accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = DynamicPairedDataModule( pri=source, sec=target, stratify=stratify, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=val_frac, random_state=self.rnd, ) progress_bar = TQDMProgressBar(refresh_rate=config.PBAR_REFRESH) model_checkpoint = ModelCheckpoint( monitor="val/nll", mode="min", save_top_k=config.CKPT_SAVE_K, verbose=verbose, ) early_stopping = EarlyStopping( monitor="val/nll", mode="min", min_delta=0.01, patience=config.PATIENCE, verbose=verbose, ) self._fit( datamodule=datamodule, fit_stage=FitStage.design, accelerator=accelerator, devices=granted, log_subdir=log_subdir, opt=opt, lr=lr, weight_decay=weight_decay, accumulate_grad_batches=accumulate_grad_batches, log_adj=LogAdj.none, check_val_every_n_epoch=None, val_check_interval=val_check_interval, max_epochs=max_epochs, callbacks=[progress_bar, early_stopping, model_checkpoint], **kwargs, ) self._load_checkpoint(model_checkpoint.best_model_path) scores = pd.DataFrame( {"score": self.net.design.logits.numpy(force=True)}, index=[",".join(sorted(self.vars[c])) for c in self.net.design.comb_lists], ).sort_values("score", ascending=False, kind="stable") design, self.net.design = self.net.design.cpu(), None return scores, design
[docs] def design_error_curve( self, source: AnnData, target: AnnData, design: IntervDesign, n_steps: int = 500, n_cells: int = 100, confidence_level: float = 0.95, stratify: str | None = None, batch_size: int = 128, n_devices: int = 1, ) -> tuple[pd.DataFrame, float]: r""" Fit an error curve against design scores Parameters ---------- source Source dataset target Target dataset representing desired outcome design Intervention design module from :meth:`~CASCADE.design` n_steps Number of equidistant score steps n_cells Number of cells per design confidence_level Confidence level stratify Column name in :attr:`~anndata.AnnData.obs` for stratified random pairing batch_size Batch size n_devices Number of GPU devices to use Returns ------- DataFrame of design error curve containing the following columns: - "score": Design score - "mse_est": Weighted MSE estimate at equidistant steps - "mse_est_mean": Smoothed weighted MSE estimate - "mse_est_lower": Lower bound of the confidence interval - "mse_est_upper": Upper bound of the confidence interval Indexed by intervention and sorted by descending scores Design score cutoff that covers minimal MSE in the confidence interval """ source = self.align_vars(source) target = self.align_vars(target) logits, comb_lists = design.logits.detach(), design.comb_lists argsort = torch.argsort(logits, stable=True) # Ascending min_logit, max_logit = logits.min(), logits.max() step_size = (max_logit - min_logit) / (n_steps - 1) step_logits = min_logit + step_size * torch.arange(n_steps) step_locs = torch.searchsorted(logits, step_logits, sorter=argsort) step_locs = step_locs.clamp(min=0, max=logits.size(0) - 1).unique() step_locs = argsort[step_locs] n_steps = step_locs.size(0) step_regime = csr_matrix( design.simplex2regime( F.one_hot(step_locs, num_classes=logits.size(0)) ).numpy(force=True) ) # (n_steps, n_vars) source = source[ self.rnd.choice(source.n_obs, n_steps * n_cells, replace=True) ].copy() source.obs_names_make_unique() repeat_idx = np.arange(n_steps).repeat(n_cells) _set_regime(source, step_regime[repeat_idx]) accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = DynamicPairedDataModule( pri=source, sec=target, stratify=stratify, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=0.0, random_state=self.rnd, ) self.net.design = design pred = self._predict( datamodule=datamodule, predict_mode=PredictMode.dsgnerr, accelerator=accelerator, devices=granted, ) self.net.design = None error = ( pd.DataFrame( { "mse_est": pred[0].mean(dim=0).numpy(force=True), "regime": [ ",".join(sorted(self.vars[comb_lists[i]])) for i in step_locs[repeat_idx] ], }, ) .groupby("regime") .mean() .reset_index() ) score = pd.DataFrame( { "score": design.logits.numpy(force=True), "regime": [",".join(sorted(self.vars[c])) for c in design.comb_lists], }, ) curve = ( pd.merge(score, error, how="outer") .set_index("regime") .sort_values("score", ascending=False, kind="stable") ) curve, cutoff = gp_regression_with_ci( curve, x="score", y="mse_est", alpha=confidence_level ) curve["qualified"] = curve["score"] > cutoff return curve, cutoff
[docs] def design_brute_force( self, source: AnnData, target: AnnData, pool: list[str] | None = None, design_size: int = 1, k: int = 30, counterfactual_kws: Kws = None, neighbor_kws: Kws = None, ) -> tuple[pd.DataFrame, AnnData]: r""" Intervention design with brute-force exhaustion Parameters ---------- source Source dataset target Target dataset representing desired outcome pool Optional list of variables as candidate pool design_size Maximal combinatorial order to consider k Number of samples to generate for each design counterfactual_kws Additional keyword arguments passed to :meth:`~CASCADE.counterfactual` neighbor_kws Additional keyword arguments passed to :class:`~sklearn.neighbors.NearestNeighbors` Returns ------- DataFrame of intervention designs, sorted by descending vote counts AnnData object with counterfactual predictions for all designs """ source = self.align_vars(source) target = self.align_vars(target) pool = pool or self.vars search_space = [ ",".join(sorted(c)) for s in range(design_size + 1) for c in combinations(pool, s) ] source_idx = self.rnd.choice(source.n_obs, len(search_space) * k) target_idx = self.rnd.choice(target.n_obs, len(search_space) * k) source = source[source_idx].copy() source.obs["design"] = np.repeat(search_space, k) encode_regime(source, "design", key="design") configure_dataset(source, use_regime="design") # Others kept untouched try: _set_covariate(source, _get_covariate(target)[target_idx]) logger.info("Using target covariates") except ValueError: logger.info("No covariates set") ctfact = self.counterfactual(source, **(counterfactual_kws or {})) self.net.eval() # Otherwise it restores to training mode dtype = torch.get_default_dtype() device = self.net.interv_scale.device ref = self.net.lik.tone( torch.as_tensor(ctfact.X, dtype=dtype, device=device), torch.as_tensor(_get_size(ctfact), dtype=dtype, device=device), ).numpy(force=True) query = self.net.lik.tone( torch.as_tensor(densify(_get_X(target)), dtype=dtype, device=device), torch.as_tensor(_get_size(target), dtype=dtype, device=device), ).numpy(force=True) neighbor = NearestNeighbors(**(neighbor_kws or {})).fit(ref) nni = neighbor.kneighbors(query, return_distance=False) votes = ctfact.obs["design"].iloc[nni.ravel()].value_counts() outcast = [item for item in search_space if item not in votes.index] outcast = pd.Series(0, index=outcast, name="count") design = pd.concat([votes, outcast]) design = design.to_frame().rename(columns={"count": "votes"}) return design, ctfact
[docs] def counterfactual( self, adata: AnnData, batch_size: int = 128, n_devices: int = 1, design: IntervDesign | None = None, fixed_genes: list[str] | None = None, sample: bool = False, ablate_latent: bool = False, ablate_interv: bool = False, ablate_graph: bool = False, ) -> AnnData: r""" Counterfactual deduction for the outcome of alternative interventions for an observed dataset Parameters ---------- adata Input dataset batch_size Batch size n_devices Number of GPU devices to use design Optional intervention design module from :meth:`~CASCADE.design` fixed_genes Optional list of genes to keep their values fixed sample Whether to sample from the counterfactual distribution (True) or use the mean (False) ablate_latent If True, removes the effect of latent variables ablate_interv If True, removes the effect of interventions ablate_graph If True, removes the effect of the causal graph Returns ------- Counterfactual dataset with: - :attr:`~anndata.AnnData.layers`\ ``["X_ctfact"]``: Counterfactual predictions with shape (n_obs, n_vars, n_particles) - :attr:`~anndata.AnnData.X`: Mean values across SVGD particles """ adata = self.align_vars(adata).copy() accelerator, granted = autodevice(n_devices) if fixed_genes is not None: fixed_genes = self.vars.get_indexer(fixed_genes) self.net.fixed_vars = torch.as_tensor(fixed_genes) self.manual_seed() datamodule = SimpleDataModule( adata=adata, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=0.0, random_state=self.rnd, ) self.net.design = design self.net.ablate_latent = ablate_latent self.net.ablate_interv = ablate_interv self.net.ablate_graph = ablate_graph pred = self._predict( datamodule=datamodule, predict_mode=PredictMode.ctsamp if sample else PredictMode.ctmean, accelerator=accelerator, devices=granted, ) self.net.design = None adata.layers["X_ctfact"] = pred[2].movedim(0, -1).numpy(force=True) if fixed_genes is not None: fixed_X = densify(_get_X(adata)[:, fixed_genes]) adata.layers["X_ctfact"][:, fixed_genes] = np.atleast_3d(fixed_X) adata.X = adata.layers["X_ctfact"].mean(axis=-1) return adata
[docs] def explain( self, adata: AnnData, ctfact: AnnData, batch_size: int = 128, n_devices: int = 1, design: IntervDesign | None = None, ) -> AnnData: r""" Explain counterfactual outcome with individual components Parameters ---------- adata Factual dataset ctfact Counterfactual prediction from :meth:`~CASCADE.counterfactual` batch_size Batch size n_devices Number of GPU devices to use design Optional intervention design module from :meth:`~CASCADE.design` Returns ------- Dataset with the following explanation components: - :attr:`~anndata.AnnData.layers`\ ``["X_nil"]``: Baseline expression without any effect - :attr:`~anndata.AnnData.layers`\ ``["X_ctrb_i"]``: Contribution from intervention - :attr:`~anndata.AnnData.layers`\ ``["X_ctrb_s"]``: Contribution from covariates - :attr:`~anndata.AnnData.layers`\ ``["X_ctrb_z"]``: Contribution from latent - :attr:`~anndata.AnnData.layers`\ ``["X_ctrb_ptr"]``: Contribution from parents - :attr:`~anndata.AnnData.layers`\ ``["X_tot"]``: Total counterfactual prediction All having shape (n_obs, n_vars, n_particles) """ adata = self.align_vars(adata) ctfact = self.align_vars(ctfact).copy() accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = PairedDataModule( pri=adata, sec=ctfact, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=0.0, random_state=self.rnd, ) self.net.design = design nil, ctrb_i, ctrb_s, ctrb_z, ctrb_ptr, tot = self._predict( datamodule=datamodule, predict_mode=PredictMode.explain, accelerator=accelerator, devices=granted, ) self.net.design = None ctfact.layers["X_nil"] = nil.movedim(0, -1).numpy(force=True) ctfact.layers["X_ctrb_i"] = ctrb_i.movedim(0, -1).numpy(force=True) ctfact.layers["X_ctrb_s"] = ctrb_s.movedim(0, -1).numpy(force=True) ctfact.layers["X_ctrb_z"] = ctrb_z.movedim(0, -1).numpy(force=True) ctfact.layers["X_ctrb_ptr"] = ctrb_ptr.movedim(0, -1).numpy(force=True) ctfact.layers["X_tot"] = tot.movedim(0, -1).numpy(force=True) return ctfact
[docs] def diagnose( self, adata: AnnData, batch_size: int = 128, n_devices: int = 1 ) -> AnnData: r""" Model diagnosis Parameters ---------- adata Input dataset batch_size Batch size n_devices Number of GPU devices to use Returns ------- Dataset with the following diagnostic information: - :attr:`~anndata.AnnData.obsm`\ ``["Z_mean_diag"]``: Latent mean - :attr:`~anndata.AnnData.obsm`\ ``["Z_std_diag"]``: Latent standard deviation - :attr:`~anndata.AnnData.layers`\ ``["X_mean_diag"]``: Reconstructed mean - :attr:`~anndata.AnnData.layers`\ ``["X_std_diag"]``: Reconstructed standard deviation - :attr:`~anndata.AnnData.layers`\ ``["X_disp_diag"]``: Dispersion parameter """ adata = self.align_vars(adata).copy() accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = SimpleDataModule( adata=adata, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=0.0, random_state=self.rnd, ) pred = self._predict( datamodule=datamodule, predict_mode=PredictMode.recon, accelerator=accelerator, devices=granted, ) ( adata.obsm["Z_mean_diag"], adata.obsm["Z_std_diag"], adata.layers["X_mean_diag"], adata.layers["X_std_diag"], adata.layers["X_disp_diag"], ) = (item.movedim(0, -1).numpy(force=True) for item in pred) return adata
[docs] def jacobian( self, adata: AnnData, batch_size: int = 128, n_devices: int = 1 ) -> AnnData: r""" Compute the Jacobian matrix of the model Parameters ---------- adata Input dataset batch_size Batch size n_devices Number of GPU devices to use Returns ------- Dataset with - :attr:`~anndata.AnnData.layers`\ ``["X_jac"]``: The Jacobian matrix with shape (n_obs, n_vars, n_parents, n_particles) """ adata = self.align_vars(adata).copy() accelerator, granted = autodevice(n_devices) self.manual_seed() datamodule = SimpleDataModule( adata=adata, batch_size=batch_size, pin_memory=accelerator == "gpu", val_frac=0.0, random_state=self.rnd, ) pred = self._predict( datamodule=datamodule, predict_mode=PredictMode.jac, accelerator=accelerator, devices=granted, inference_mode=False, # This is necessary because we need to make lightning use # torch.no_grad() rather than torch.inference_mode(), so that we can # re-enable gradient within predict_step() with torch.enable_grad(). ) adata.layers["X_jac"] = ( pred[0].movedim(0, -1).numpy(force=True) ) # (bs, n_vars_out, n_vars_in, n_particles) return adata
[docs] @rank_zero_only def save(self, fname: os.PathLike) -> None: r""" Save model to file Parameters ---------- Path to save the model file (.pt) """ fname = Path(fname) fname.parent.mkdir(parents=True, exist_ok=True) if self.net.design is not None: raise ValueError("Design module should not be present.") torch.save( { "__version__": __version__, "vars": self.vars.to_list(), "interv_seen": list(self.interv_seen), "rnd": tuple( item.tolist() if isinstance(item, np.ndarray) else item for item in self.rnd.get_state() ), "log_dir": self.log_dir.as_posix(), "hparams": dict(self.net.hparams), "state_dict": self.net.state_dict(), }, fname, )
[docs] @classmethod def load(cls, fname: os.PathLike) -> "CASCADE": r""" Load model from file Parameters ---------- fname Path to the saved model file (.pt) Returns ------- Loaded CASCADE model instance """ loaded = torch.load(fname, weights_only=True) version = loaded.pop("__version__", "unknown") if version != __version__: logger.warning( # pragma: no cover "Loaded model version {} differs from current version {}.", version, __version__, ) net = CausalNetwork(**loaded.pop("hparams")) net.load_state_dict(loaded.pop("state_dict")) rnd = get_random_state() rnd.set_state(loaded.pop("rnd")) model = cls( loaded.pop("vars"), random_state=rnd, log_dir=loaded.pop("log_dir"), _net=net, ) model.interv_seen = set(loaded.pop("interv_seen")) return model
[docs] def upgrade_saved_model(fname: os.PathLike) -> None: # pragma: no cover r""" Update the saved model format to be compatible with the latest version of CASCADE Parameters ---------- fname Path to the saved model file (.pt) that needs upgrading """ content = torch.load(fname) rewrite = False hparams = content["hparams"] for k in list(hparams.keys()): if "skeleton" in k: hparams[k.replace("skeleton", "scaffold")] = hparams.pop(k) rewrite = True state_dict = content["state_dict"] for k in list(state_dict.keys()): if "skeleton" in k: state_dict[k.replace("skeleton", "scaffold")] = state_dict.pop(k) rewrite = True if "_fixed_vars" not in state_dict["_extra_state"]: state_dict["_extra_state"]["_fixed_vars"] = torch.empty(0, dtype=torch.long) rewrite = True for k in ("interv_scale", "interv_bias"): if f"_{k}" in state_dict: state_dict[k] = state_dict.pop(f"_{k}") rewrite = True for k in ( "r_design", "_interv_scale_design", "_interv_bias_design", "design_pool", "design_pick", "target_weight", ): if k in state_dict: state_dict.pop(k) rewrite = True if rewrite: torch.save(content, fname)