r"""
Neural network utilities
"""
import os
from collections.abc import Generator
from functools import cached_property
from itertools import product
from math import log, log1p, sqrt
from pathlib import Path
from typing import Any
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.distributions as D
import torch.nn.functional as F
from anndata import AnnData
from loguru import logger
from scipy.sparse import issparse
from scipy.sparse.linalg import eigs
from sklearn.linear_model import LinearRegression
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Parameter, init
from . import __version__
from .data import EPS, _get_size, _get_X
from .utils import count_occurrence, index_len, internal, non_unitary_index
[docs]
def copy_like(source: torch.Tensor, target: torch.Tensor) -> None:
    r"""
    Copy tensor device, dtype and data from a source tensor to a target tensor
    in place

    Parameters
    ----------
    source
        Tensor providing the device, dtype and values
    target
        Tensor (e.g. a :class:`~torch.nn.Parameter`) whose underlying data is
        replaced; the Python object itself is kept
    """
    # First move/cast the target storage so the subsequent copy is a plain
    # same-device, same-dtype transfer
    relocated = target.data.to(dtype=source.dtype, device=source.device)
    target.data = relocated
    target.data.copy_(source.data)
[docs]
def mean_squared_error(
x: torch.Tensor,
y: torch.Tensor,
dim: int,
keepdim: bool = False,
weight: torch.Tensor | None = None,
) -> torch.Tensor:
r"""
Compute the mean squared error along a specified dimension
Parameters
----------
x
Input tensor x
y
Input tensor y
dim
Dimension along which to compute the error
keepdim
Whether to keep the dimension after reduction
weight
Optional weight tensor
Returns
-------
Mean squared error tensor
"""
if x.size() != y.size():
raise ValueError("Incompatible input sizes")
if weight is None:
weight = x.new_ones(x.size(dim))
elif weight.ndim != 1 or weight.size(0) != x.size(dim):
raise ValueError("Incompatible weight")
else:
weight = weight.numel() * weight / weight.sum()
if dim < -1:
weight = weight.view(-1, *((1,) * (-dim - 1)))
elif dim > -1:
weight = weight.view(-1, *((1,) * (x.ndim - dim - 1)))
return ((x - y).square() * weight).mean(dim, keepdim=keepdim)
[docs]
def gumbel_sigmoid(x: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    r"""
    Straight-through Gumbel sigmoid sampler

    Parameters
    ----------
    x
        Logit tensor
    tau
        Temperature parameter (larger values give smoother soft samples)

    Returns
    -------
    Hard 0/1 samples of the same shape and dtype as ``x`` whose gradient is
    that of the underlying soft (relaxed) samples
    """
    # Logistic noise: the logit of a uniform draw kept away from {0, 1}
    uniform = torch.empty_like(x).uniform_(EPS, 1 - EPS)
    soft = torch.sigmoid((x + uniform.logit()) / tau)
    hard = (soft > 0.5).type_as(soft)
    # Straight-through estimator: the forward value equals ``hard`` while the
    # gradient flows only through ``soft``
    return (hard - soft).detach() + soft
[docs]
def multi_trace(m: torch.Tensor) -> torch.Tensor:
    r"""
    Compute matrix trace with support for multiplex dims

    Parameters
    ----------
    m
        Matrix of shape (\*m, n_vars, n_vars)

    Returns
    -------
    Matrix trace of shape (\*m,)
    """
    # Main diagonal of each matrix in the batch, then sum it
    diag = torch.diagonal(m, dim1=-2, dim2=-1)
    return diag.sum(dim=-1)
[docs]
def multi_rbf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""
    RBF kernel with support for multiplex dims

    The bandwidth follows the median heuristic: the squared distances are
    scaled by ``log(bs) / median(distance)^2`` per multiplex slice.

    Parameters
    ----------
    x
        Input x of shape (\*m, bs, n_vars)
    y
        Input y of shape (\*m, bs, n_vars)

    Returns
    -------
    RBF kernel of shape (\*m, bs, bs)
    """
    dist = torch.cdist(x, y)  # (*m, bs, bs)
    # Median pairwise distance per slice; detached so the bandwidth is
    # treated as a constant by autograd
    median = dist.detach().flatten(start_dim=-2).quantile(0.5, dim=-1)  # (*m,)
    gamma = log(x.size(-2)) / (median.square() + EPS)
    return torch.exp(-(dist.square() * gamma[..., None, None]))
[docs]
class Module(torch.nn.Module):
    r"""
    Abstract module class supporting parameter freezing, decayed / non-decayed
    parameter iteration, and cached property clearing

    Note
    ----
    Only child modules that are themselves :class:`Module` instances are
    recursed into. Native torch submodules (e.g. :class:`torch.nn.Linear`)
    are skipped with a debug message, so their parameters are neither
    frozen / unfrozen nor yielded by :meth:`decay_params` /
    :meth:`regular_params`.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._frozen = False

    @property
    def frozen(self) -> bool:
        r"""
        Whether :meth:`freeze_params` was called more recently than
        :meth:`unfreeze_params`
        """
        return self._frozen

    def _module_children(self) -> Generator["Module", None, None]:
        r"""
        Yield direct children that are :class:`Module` instances, logging
        every native torch submodule that is skipped
        """
        for name, module in self.named_children():
            if isinstance(module, Module):
                yield module
            else:
                logger.debug(f"Skipping native torch submodule {name}.")

    def freeze_params(self) -> None:
        r"""
        Freeze parameters and turn on evaluation mode
        """
        for param in self.parameters(recurse=False):
            param.requires_grad_(False)
        self.eval()
        for module in self._module_children():
            module.freeze_params()
        self.clear_cached_properties()
        self._frozen = True

    def unfreeze_params(self) -> None:
        r"""
        Unfreeze parameters and turn on training mode
        """
        for param in self.parameters(recurse=False):
            param.requires_grad_(True)
        self.train()
        for module in self._module_children():
            module.unfreeze_params()
        self.clear_cached_properties()
        self._frozen = False

    def _decay_params(self) -> Generator[Parameter, None, None]:
        r"""
        Direct parameters subject to weight decay; subclasses override this
        (the base implementation yields nothing)
        """
        return
        yield

    def decay_params(self) -> Generator[Parameter, None, None]:
        r"""
        Iterate through weight decayed parameters (non-empty, requiring grad),
        recursing into :class:`Module` children
        """
        for param in self._decay_params():
            if param.numel() and param.requires_grad:
                yield param
        for module in self._module_children():
            yield from module.decay_params()

    def regular_params(self) -> Generator[Parameter, None, None]:
        r"""
        Iterate through non-decayed parameters (non-empty, requiring grad),
        recursing into :class:`Module` children
        """
        decay_params = set(self._decay_params())
        for param in self.parameters(recurse=False):
            if param in decay_params:
                continue
            if param.numel() and param.requires_grad:
                yield param
        for module in self._module_children():
            yield from module.regular_params()

    @internal
    def get_extra_state(self) -> dict[str, Any]:
        return {"_frozen": self._frozen}

    @internal
    def set_extra_state(self, state: dict[str, Any]) -> None:
        # Read without popping: the dict passed in may be shared with a
        # caller-held state dict, and mutating it would make that state dict
        # fail (KeyError) when loaded a second time
        self._frozen = state["_frozen"]

    def clear_cached_properties(self):
        r"""
        Clear cached properties to allow re-calculation
        """
        cls = type(self)
        # A computed ``functools.cached_property`` lives in the instance
        # ``__dict__`` under the property's name
        for key in list(self.__dict__.keys()):
            if isinstance(getattr(cls, key, None), cached_property):
                delattr(self, key)
[docs]
class ModuleList(torch.nn.ModuleList, Module):
    r"""
    :class:`torch.nn.ModuleList` extended with the freezing, parameter
    iteration and cached property clearing capabilities of :class:`Module`
    """
[docs]
class MultiLinear(Module):
r"""
Linear layer with support for multi-dims
Parameters
----------
in_features
Input dimensionality
out_features
Output dimensionality
multi_dims
Multiplex dims at the front of input samples
"""
def __init__(
self, in_features: int, out_features: int, multi_dims: tuple[int, ...]
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.multi_dims = multi_dims
self.weight = Parameter(torch.empty(*multi_dims, out_features, in_features))
self.bias = Parameter(torch.empty(*multi_dims, 1, out_features))
self.reset_parameters()
@internal
def reset_parameters(self) -> None:
for s in product(*(range(dim) for dim in self.multi_dims)):
init.kaiming_uniform_(self.weight[s], a=sqrt(5))
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[s])
bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias[s], -bound, bound)
[docs]
def forward(
self, x: torch.Tensor, *multi_idx: slice | torch.LongTensor
) -> torch.Tensor:
if multi_idx:
weight, bias = self.weight[*multi_idx], self.bias[*multi_idx]
else:
weight, bias = self.weight, self.bias
return x.matmul(weight.transpose(-1, -2)) + bias
def _decay_params(self) -> Generator[Parameter, None, None]:
yield self.weight
yield from super()._decay_params()
[docs]
class Func(Module):
    r"""
    Structural equation with covariates

    An MLP of :class:`MultiLinear` layers with LeakyReLU activations and
    optional dropout; the covariates are concatenated to the input of every
    linear layer.

    Parameters
    ----------
    in_features
        Input dimensionality
    cov_features
        Covariate dimensionality
    out_features
        Output dimensionality
    hidden_dim
        Hidden layer dimensionality
    n_layers
        Number of hidden layers
    multi_dims
        Multiplex dims at the front of input samples
    dropout
        Dropout rate (no dropout layer is added when falsy)
    """

    def __init__(
        self,
        in_features: int,
        cov_features: int,
        out_features: int,
        hidden_dim: int,
        n_layers: int,
        multi_dims: tuple[int, ...],
        dropout: float,
    ) -> None:
        super().__init__()
        self.layers = ModuleList()
        width = in_features
        for _ in range(n_layers):
            self.layers.append(MultiLinear(width + cov_features, hidden_dim, multi_dims))
            self.layers.append(LeakyReLU(negative_slope=0.2))
            if dropout:
                self.layers.append(Dropout(p=dropout))
            width = hidden_dim
        self.layers.append(MultiLinear(width + cov_features, out_features, multi_dims))

    def forward(
        self, x: torch.Tensor, cov: torch.Tensor, *multi_idx: slice | torch.LongTensor
    ) -> torch.Tensor:
        hidden = x
        for layer in self.layers:
            if not isinstance(layer, MultiLinear):
                hidden = layer(hidden)
                continue
            # Re-inject covariates before every linear layer
            hidden = layer(torch.cat([hidden, cov], dim=-1), *multi_idx)
        return hidden
[docs]
class AttnPool(Module):
    r"""
    Attention-based pooling layer to combine multiple intervention embeddings

    The query is projected from the sum of all embeddings, and each embedding
    is weighted by a sigmoid-gated scaled dot product before summation.

    Parameters
    ----------
    emb_dim
        Embedding dimensionality
    """

    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        self.norm = 1 / sqrt(emb_dim)  # Scaled dot-product factor
        self.query_proj = Linear(emb_dim, emb_dim)
        self.key_proj = Linear(emb_dim, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: ([n_particles,] bs, n_vars, emb_dim)
        summary = x.sum(dim=-2, keepdim=True)  # ([n_particles,] bs, 1, emb_dim)
        query = self.query_proj(summary)
        key = self.key_proj(x)  # ([n_particles,] bs, n_vars, emb_dim)
        score = (query * key).sum(dim=-1, keepdim=True) * self.norm
        gate = torch.sigmoid(score)  # ([n_particles,] bs, n_vars, 1)
        return (gate * x).sum(dim=-2)  # ([n_particles,] bs, emb_dim)
# ------------------------------ Graph scaffold --------------------------------
[docs]
class Scaffold(Module):
    r"""
    Abstract graph scaffold

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    eidx
        Scaffold edge indices of shape (2, n_edges)
    tau
        Gumbel-sigmoid temperature
    **kwargs
        Extra subclass constructor arguments, stored in ``self.kwargs`` so
        that :meth:`__getitem__` can re-create a scaffold of the same type
    """

    def __init__(
        self,
        n_vars: int,
        n_particles: int,
        eidx: torch.LongTensor,
        tau: float = 10.0,
        **kwargs,
    ) -> None:
        super().__init__()
        self.n_vars = n_vars
        self.n_particles = n_particles
        if (eidx < 0).any() or (eidx >= n_vars).any():
            raise ValueError("Edge index out of bounds")  # pragma: no cover
        if (eidx[0] == eidx[1]).any():
            raise ValueError("Self-loops are not allowed")
        # Stable sort by source index so the edge order is deterministic
        eidx = eidx[:, eidx[0].argsort(stable=True)]
        i, j = eidx
        # k: slot of each edge in the per-target reshaped layout (see
        # ``mask_map``); presumably the 0-based running count of each target
        # in ``j`` -- TODO confirm against ``count_occurrence``
        k = self.make_k(j)
        self.max_indegree = self._indegree_from_k(k)
        self.register_buffer("idx", torch.stack([i, j, k]))  # (3, n_edges)
        self.register_buffer("_logit", torch.zeros(self.n_particles, self.n_edges))
        self.tau = tau
        self.grad_backup = {}
        self.kwargs = kwargs

    @property
    def n_edges(self) -> int:
        r"""
        Number of scaffold edges
        """
        return self.idx.size(1)

    @staticmethod
    def make_k(j: torch.LongTensor) -> torch.LongTensor:
        r"""
        Compute the per-target slot index ``k`` for target indices ``j`` via
        ``count_occurrence``, on the same device and dtype as ``j``
        """
        return torch.as_tensor(
            count_occurrence(j.tolist()), dtype=j.dtype, device=j.device
        )

    @staticmethod
    def _indegree_from_k(k: torch.LongTensor) -> int:
        r"""
        Number of reshaped input slots per target (``k.max() + 1``, or 0 when
        there are no edges), always as a Python int
        """
        return int(k.max()) + 1 if k.numel() else 0

    def construct_sparse_tensor(self, value: torch.Tensor) -> torch.Tensor:
        r"""
        Build a sparse COO tensor of shape (n_vars, n_vars, \*value.shape[1:])
        holding ``value`` (of shape (n_edges, ...)) at the scaffold edges
        """
        return torch.sparse_coo_tensor(
            self.idx[:2],
            value,
            size=(self.n_vars, self.n_vars, *value.size()[1:]),
        )

    def compute_logit(self) -> torch.Tensor:
        r"""
        Compute edge logits of shape (n_particles, n_edges) from the
        subclass parameterization
        """
        raise NotImplementedError  # pragma: no cover

    @cached_property
    def logit(self) -> torch.Tensor:
        r"""
        Edge logit of shape (n_particles, n_edges)

        Uses the stored ``_logit`` buffer when frozen and
        :meth:`compute_logit` otherwise. Gradients are retained so they can
        be inspected after backward.
        """
        logit = self._logit if self.frozen else self.compute_logit()
        if logit.requires_grad:
            logit.retain_grad()
        return logit

    @cached_property
    def prob(self) -> torch.Tensor:
        r"""
        Edge prob of shape (n_particles, n_edges)
        """
        return self.logit.sigmoid()

    @cached_property
    def adj(self) -> torch.Tensor:
        r"""
        Sparse adjacency matrix of all particles, of shape
        (n_vars, n_vars, n_particles)
        """
        return self.construct_sparse_tensor(self.prob.t())

    @cached_property
    def mean_adj(self) -> torch.Tensor:
        r"""
        Mean sparse adjacency matrix (edge probabilities averaged over
        particles)
        """
        return self.construct_sparse_tensor(self.prob.mean(dim=0))

    @cached_property
    def complete_adj(self) -> torch.Tensor:
        r"""
        Complete sparse adjacency matrix with value one on every scaffold
        edge, of shape (n_vars, n_vars, 1)
        """
        return self.construct_sparse_tensor(self._logit.new_ones(self.n_edges, 1))

    @property
    def mask_map(self) -> torch.LongTensor:
        r"""
        A reshaped index map of shape (n_vars, max_indegree) where entry (j, k)
        has value i, indicating which input gene is in each reshaped position
        for each output gene. Unused positions hold -1.
        """
        mask_map = self.idx.new_zeros(self.n_vars, self.max_indegree) - 1
        i, j, k = self.idx
        mask_map[j, k] = i
        return mask_map

    def mask_data(
        self, x: torch.Tensor, oidx: torch.LongTensor | None = None
    ) -> torch.Tensor:
        r"""
        Gather each target's parent values into the reshaped layout and mask
        them by sampled edges

        Parameters
        ----------
        x
            Data of shape (bs, n_vars) or (p, bs, n_vars), where p is
            presumably 1 or n_particles -- TODO confirm
        oidx
            Optional subset of target variables; only their in-edges are used
            and the output is restricted to them (in ``oidx`` order)

        Returns
        -------
        Masked data of shape (n_particles, n_targets, bs, max_indegree)
        """
        if oidx is not None:
            mask = torch.isin(self.idx[1], oidx)
            i, j, k = self.idx[:, mask]
            # Map global target indices to positions within ``oidx``
            remap = j.new_empty(self.n_vars)
            remap[oidx] = torch.arange(oidx.numel(), device=remap.device)
            j = remap[j]
            logit = self.logit[:, mask]
            n_vars = oidx.numel()
        else:
            i, j, k = self.idx
            logit = self.logit
            n_vars = self.n_vars
        if x.dim() == 2:  # (bs, n_vars)
            x = x.unsqueeze(0)
        bs = x.size(-2)
        if self.training and not self.frozen:
            # Independent relaxed edge samples for every data point
            samp = gumbel_sigmoid(logit.unsqueeze(-1).expand(-1, -1, bs), tau=self.tau)
        else:
            samp = (logit > 0).unsqueeze(-1).expand(-1, -1, bs)
        # Advanced indices (j, k) go first: assigned values are (n_edges, ...)
        samp_reshape = samp.new_zeros((self.n_particles, n_vars, bs, self.max_indegree))
        samp_reshape[:, j, :, k] = samp.transpose(0, 1)
        x_reshape = x.new_zeros((x.size(0), n_vars, bs, self.max_indegree))
        x_reshape[:, j, :, k] = x[:, :, i].moveaxis(-1, 0)
        return x_reshape * samp_reshape

    def zero_grad(self, set_to_none: bool = True, backup: bool = False) -> None:
        r"""
        Reset gradients, optionally backing up the current ones first

        Parameters
        ----------
        set_to_none
            Passed to :meth:`torch.nn.Module.zero_grad`
        backup
            Whether to store detached copies of the current parameter
            gradients in ``grad_backup`` (added back by
            :meth:`accumulate_grad`)
        """
        if backup:
            for name, param in self.named_parameters():
                self.grad_backup[name] = (
                    None if param.grad is None else param.grad.detach().clone()
                )
        super().zero_grad(set_to_none=set_to_none)
        # Only touch an already-cached logit: ``hasattr(self, "logit")`` would
        # evaluate the cached property and run a full logit computation as a
        # side effect
        logit = self.__dict__.get("logit")
        if logit is not None and logit.retains_grad:
            logit.grad = None

    def accumulate_grad(self) -> None:
        r"""
        Add gradients backed up by ``zero_grad(backup=True)`` to the current
        gradients, consuming the backup
        """
        for name, param in self.named_parameters():
            grad = self.grad_backup.pop(name, None)
            if grad is None:
                continue
            if param.grad is None:
                # No new gradient since the backup: restore it as-is
                param.grad = grad
            else:
                param.grad.add_(grad)

    def export_graph(self, edge_attr: str = "weight") -> nx.DiGraph:
        r"""
        Export the scaffold as a directed graph over nodes ``0..n_vars-1``

        Each edge ``i -> j`` stores, under ``edge_attr``, the list of its
        per-particle edge probabilities.
        """
        self.clear_cached_properties()
        i, j, _ = self.idx
        i = i.numpy(force=True)
        j = j.numpy(force=True)
        prob = self.prob.numpy(force=True).T.tolist()  # One list per edge
        graph = nx.DiGraph()
        graph.add_nodes_from(range(self.n_vars))  # Ensure proper node order
        graph.add_weighted_edges_from(zip(i, j, prob), weight=edge_attr)
        return graph

    def import_graph(self, graph: nx.DiGraph, edge_attr: str = "weight") -> None:
        r"""
        Load per-particle edge probabilities from ``graph`` into the stored
        logits and freeze the scaffold

        Scaffold edges missing from ``graph`` get probability zero.
        """
        i, j, _ = self.idx
        i = i.numpy(force=True)
        j = j.numpy(force=True)
        attrs = nx.get_edge_attributes(graph, edge_attr)
        zeros = [0.0] * self.n_particles
        prob = torch.as_tensor(
            [attrs.get((u, v), zeros) for u, v in zip(i, j)],
            dtype=self.logit.dtype,
            device=self.logit.device,
        )
        self._logit.copy_(prob.logit().T)
        self.freeze_params()

    def __getitem__(self, index) -> "Scaffold":
        r"""
        Create a new scaffold of the same type holding the particles selected
        by ``index``
        """
        result = type(self)(
            self.n_vars,
            index_len(index, self.n_particles),
            self.idx[:2],
            self.tau,
            **self.kwargs,
        )
        index = non_unitary_index(index)
        copy_like(self._logit[index], result._logit)
        result.set_extra_state(self.get_extra_state())
        return result

    def prune(self) -> torch.BoolTensor:
        r"""
        Drop edges whose logit is non-positive in every particle, in place

        Returns
        -------
        Boolean mask over the original edges, True for edges that are kept
        """
        self.clear_cached_properties()
        mask = (self.logit > 0).any(dim=0)
        i, j = self.idx[:2, mask]
        k = self.make_k(j)
        self.max_indegree = self._indegree_from_k(k)
        self.idx = torch.stack([i, j, k])
        self._logit = self._logit[:, mask]
        # The logit cached above has the pre-pruning edge count; drop it
        self.clear_cached_properties()
        return mask

    def topo_gens(self) -> list[list[torch.LongTensor]]:
        r"""
        Topological generations of the pruned graph of each particle
        """
        gens_list = []
        for i in range(self.n_particles):
            particle = self[i]
            particle.prune()
            graph = particle.export_graph()
            gens = [torch.as_tensor(gen) for gen in nx.topological_generations(graph)]
            gens_list.append(gens)
        return gens_list
[docs]
class Edgewise(Scaffold):
    r"""
    Edgewise parameterized edge logits

    Every (particle, edge) logit is an independent free parameter.

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    eidx
        Scaffold edge indices of shape (2, n_edges)
    tau
        Gumbel-sigmoid temperature
    """

    INIT_STD: float = 0.1

    def __init__(
        self,
        n_vars: int,
        n_particles: int,
        eidx: torch.LongTensor,
        tau: float = 10.0,
    ) -> None:
        super().__init__(n_vars, n_particles, eidx, tau=tau)
        self.edgewise = Parameter(torch.empty(self.n_particles, self.n_edges))
        self.reset_parameters()

    @internal
    def reset_parameters(self) -> None:
        init.normal_(self.edgewise, std=self.INIT_STD)

    def compute_logit(self) -> torch.Tensor:
        # Return a view so the logit is a non-leaf tensor, distinct from the
        # parameter itself
        return self.edgewise.view_as(self.edgewise)

    def __getitem__(self, index) -> "Edgewise":
        result = super().__getitem__(index)
        copy_like(self.edgewise[non_unitary_index(index)], result.edgewise)
        return result

    def prune(self) -> torch.BoolTensor:
        r"""
        Prune edges (see :meth:`Scaffold.prune`) and shrink the edgewise
        parameter to the kept edges; the parameter object is replaced
        """
        mask = super().prune()
        kept = self.edgewise.data[:, mask]
        self.edgewise = Parameter(kept.clone())
        return mask

    def _decay_params(self) -> Generator[Parameter, None, None]:
        yield self.edgewise
        yield from super()._decay_params()
[docs]
class Bilinear(Scaffold):
    r"""
    Bilinearly parameterized edge logits

    The logit of edge ``i -> j`` is the dot product of a source embedding
    ``u[:, i]`` and a target embedding ``v[:, j]``.

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    eidx
        Scaffold edge indices of shape (2, n_edges)
    tau
        Gumbel-sigmoid temperature
    emb_dim
        Dimension of the bilinear parameterization (defaults to
        ``round(sqrt(n_vars))``)
    """

    INIT_STD: float = 0.1

    def __init__(
        self,
        n_vars: int,
        n_particles: int,
        eidx: torch.LongTensor,
        tau: float = 10.0,
        emb_dim: int | None = None,
    ) -> None:
        super().__init__(n_vars, n_particles, eidx, tau=tau, emb_dim=emb_dim)
        self.emb_dim = emb_dim or round(sqrt(self.n_vars))
        shape = (self.n_particles, self.n_vars, self.emb_dim)
        self.u = Parameter(torch.empty(shape))
        self.v = Parameter(torch.empty(shape))
        self.reset_parameters()

    @internal
    def reset_parameters(self) -> None:
        init.normal_(self.u, std=self.INIT_STD)
        init.normal_(self.v, std=self.INIT_STD)

    def compute_logit(self) -> torch.Tensor:
        src, tgt = self.idx[:2]
        if not self.frozen:
            # Clip every node embedding of every particle to unit L2 norm,
            # in place on the raw data (outside autograd)
            for table in (self.u.data, self.v.data):
                for particle in table:
                    particle.renorm_(2, 0, 1)
        logit = (self.u[:, src] * self.v[:, tgt]).sum(dim=-1)
        return logit.type_as(self.u)  # Revert dtype autocast

    def __getitem__(self, index) -> "Bilinear":
        result = super().__getitem__(index)
        sel = non_unitary_index(index)
        copy_like(self.u[sel], result.u)
        copy_like(self.v[sel], result.v)
        return result

    def _decay_params(self) -> Generator[Parameter, None, None]:
        yield self.u
        yield self.v
        yield from super()._decay_params()
# -------------------------------- Graph prior ---------------------------------
[docs]
class Prior(Module):
    r"""
    Compute unnormalized negative log prior probability of a scaffold graph

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    **kwargs
        Ignored; accepted so subclasses can forward their own constructor
        arguments (e.g. ``n_iter``)
    """

    def __init__(self, n_vars: int, n_particles: int, **kwargs) -> None:
        super().__init__()
        self.n_vars, self.n_particles = n_vars, n_particles

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        Energy function (negative log probability) of a graph scaffold

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Energy of shape (n_particles,)
        """
        raise NotImplementedError  # pragma: no cover
[docs]
class SparsePrior(Prior):
    r"""
    Base class for priors that encourage sparsity

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """
[docs]
class AcycPrior(Prior):
    r"""
    Base class for priors that enforce an acyclicity constraint

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """
[docs]
class L1(SparsePrior):
    r"""
    L1 penalized log prior probability

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        L1 sparse energy function (negative log probability) of a graph
        scaffold: the summed edge probabilities divided by ``n_vars ** 2``

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Sparse energy of shape (n_particles,)
        """
        expected_edges = scaffold.prob.sum(dim=1)
        return expected_edges / scaffold.n_vars**2
[docs]
class ScaleFree(SparsePrior):
    r"""
    Scale-free penalized log prior probability

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        Scale-free energy function (negative log probability) of a graph
        scaffold: sum of ``log1p`` soft out-degrees, normalized by
        ``n_vars * log1p(n_vars)``

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Scale-free energy of shape (n_particles,)
        """
        n_part, n_vars = scaffold.n_particles, scaffold.n_vars
        prob = scaffold.prob  # (n_particles, n_edges)
        src = scaffold.idx[0].unsqueeze(0).expand(n_part, -1)
        # Soft out-degree: summed probabilities of each variable's out-edges
        out_degree = prob.new_zeros((n_part, n_vars)).scatter_add_(1, src, prob)
        return out_degree.log1p().sum(dim=1) / (n_vars * log1p(n_vars))
[docs]
class TrExp(AcycPrior):
    r"""
    Tr-Exp penalized log prior probability

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        Tr-Exp acyclic energy function (negative log probability) of a graph
        scaffold

        The energy ``tr(exp(A / n)) - n`` of the weighted adjacency is divided
        by the same quantity on the complete scaffold (all edges weight one).

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Tr-Exp acyclic energy of shape (n_particles,)
        """
        n = scaffold.n_vars
        # Scale by 1 / n_vars to bound Tr-exp - n to e - 1
        x = scaffold.adj.to_dense().permute(2, 0, 1) / n  # (n_particles, n, n)
        c = scaffold.complete_adj.to_dense().permute(2, 0, 1) / n  # (1, n, n)
        enrg = multi_trace(x.matrix_exp()) - n  # (n_particles,)
        ceil = multi_trace(c.matrix_exp()) - n
        # An acyclic (or empty) scaffold has a nilpotent adjacency, so
        # ceil == 0; clamp to avoid 0 / 0 = NaN (as in SpecNorm)
        return enrg / ceil.clamp(min=EPS)
[docs]
class SpecNorm(AcycPrior):
    r"""
    Spectral norm penalized log prior probability

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    n_iter
        Number of power iterations
    """

    def __init__(self, n_vars: int, n_particles: int, n_iter: int = 5) -> None:
        super().__init__(n_vars, n_particles, n_iter=n_iter)
        self.n_iter = n_iter
        # (u1, v1): per-particle left / right vectors for the scaffold graph
        # (u2, v2): vectors for the complete (all-ones) scaffold
        self.register_buffer("u1", torch.empty(self.n_vars, self.n_particles))
        self.register_buffer("v1", torch.empty(self.n_vars, self.n_particles))
        self.register_buffer("u2", torch.empty(self.n_vars, 1))
        self.register_buffer("v2", torch.empty(self.n_vars, 1))
        self.reset_parameters()

    @internal
    def reset_parameters(self) -> None:
        init.normal_(self.u1)
        init.normal_(self.v1)
        init.normal_(self.u2)
        init.normal_(self.v2)
        F.normalize(self.u1, dim=0, out=self.u1)
        F.normalize(self.v1, dim=0, out=self.v1)
        F.normalize(self.u2, dim=0, out=self.u2)
        F.normalize(self.v2, dim=0, out=self.v2)
        self.fresh = True  # First training call runs 5x more iterations
        self.limit = None  # Lazily computed by ``compute_limit``

    @staticmethod
    def mv(
        idx: torch.Tensor,
        val: torch.Tensor,
        vec: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Sparse matrix-vector product with particles in the last dimension

        Parameters
        ----------
        idx
            Index of the sparse matrix (2, n_edges)
        val
            Values of the sparse matrix (n_edges, n_particles)
        vec
            Vector of shape (n_vars, n_particles)

        Returns
        -------
        Matrix-vector product of shape (n_vars, n_particles)
        """
        i, j = idx
        res = torch.zeros_like(vec)
        return res.scatter_add_(0, i.unsqueeze(1).expand_as(val), val * vec[j])

    @staticmethod
    def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""
        Vector dot product with particles in the last dimension

        Parameters
        ----------
        x
            Vector of shape (n_vars, n_particles)
        y
            Vector of shape (n_vars, n_particles)

        Returns
        -------
        Dot product of shape (n_particles,)
        """
        return (x * y).sum(dim=0)

    @torch.no_grad()
    def power_iteration(
        self,
        idx: torch.Tensor,
        val: torch.Tensor,
        u: torch.Tensor,
        v: torch.Tensor,
        n_iter: int,
    ) -> None:
        r"""
        Run ``n_iter`` power iteration steps, updating ``u`` (with the
        transposed matrix) and ``v`` (with the matrix) in place

        Each step adds ``EPS * sum(vec)`` to every entry, i.e. it iterates on
        the matrix perturbed by ``EPS`` times the all-ones matrix.
        """
        val = val.detach()
        for _ in range(n_iter):
            F.normalize(self.mv(idx.flip(0), val, u) + EPS * u.sum(dim=0), dim=0, out=u)
            F.normalize(self.mv(idx, val, v) + EPS * v.sum(dim=0), dim=0, out=v)
        self.fresh = False

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        Spectral norm acyclic energy function (negative log probability) of a
        graph scaffold

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Spectral norm acyclic energy of shape (n_particles,)
        """
        idx = scaffold.idx[:2]
        val = scaffold.prob.t()  # (n_edges, n_particles)
        one = val.new_ones(val.size(0), 1)
        if self.training:
            n_iter = self.n_iter * 5 if self.fresh else self.n_iter
            self.power_iteration(idx, val, self.u1, self.v1, n_iter)
            self.power_iteration(idx, one, self.u2, self.v2, n_iter)
        # Clone so the in-place buffer updates of later power iterations do
        # not affect the autograd graph built below
        u1, v1 = self.u1.clone(), self.v1.clone()
        u2, v2 = self.u2.clone(), self.v2.clone()
        self.limit = (
            self.compute_limit(scaffold) if self.limit is None else self.limit
        )  # NOTE: Cached limit would be incorrect if scaffold changes
        # u^T (A + EPS * 11^T) v / (u^T v), normalized by n_vars
        enrg = (
            self.dot(u1, self.mv(idx, val, v1)) + EPS * u1.sum(dim=0) * v1.sum(dim=0)
        ) / (self.dot(u1, v1) * self.n_vars) - self.limit
        ceil = (
            self.dot(u2, self.mv(idx, one, v2)) + EPS * u2.sum(dim=0) * v2.sum(dim=0)
        ) / (self.dot(u2, v2) * self.n_vars) - self.limit
        return enrg.clamp(min=0.0) / ceil.clamp(min=EPS)  # (n_particles,)

    def compute_limit(self, scaffold: Scaffold) -> float:
        r"""
        Detection limit on the complete DAG

        The scaffold skeleton is oriented from lower to higher variable index;
        the largest real eigenvalue of that DAG, divided by ``n_vars``, is
        returned (0.0 when there are no edges).
        """
        complete = scaffold.complete_adj.to_dense().squeeze(-1).numpy(force=True)
        complete_dag = np.triu(complete + complete.T, k=1).clip(max=1.0)
        if complete_dag.sum() == 0:
            return 0.0
        if complete_dag.shape[0] < 3:
            # ARPACK (``eigs``) requires k < N - 1, impossible for N <= 2;
            # use a dense eigen-solver for such tiny graphs instead
            return float(np.linalg.eigvals(complete_dag).real.max()) / self.n_vars
        max_eig = eigs(
            complete_dag,
            k=1,
            which="LR",
            v0=np.ones(self.n_vars),
            return_eigenvectors=False,
        )[0]
        return max_eig.real.item() / self.n_vars

    @internal
    def get_extra_state(self) -> dict[str, Any]:
        return {"fresh": self.fresh, **super().get_extra_state()}

    @internal
    def set_extra_state(self, state: dict[str, Any]) -> None:
        # Read without popping so a caller-held state dict is not mutated
        self.fresh = state["fresh"]
        super().set_extra_state(state)
[docs]
class LogDet(AcycPrior):
    r"""
    Log-determinant penalized log prior probability

    Parameters
    ----------
    n_vars
        Number of variables in the graph
    n_particles
        Number of SVGD particles
    """

    def energy(self, scaffold: Scaffold) -> torch.Tensor:
        r"""
        Log-determinant acyclic energy function (negative log probability) of a
        graph scaffold

        The energy ``-log|det(I - A / n)|`` of the weighted adjacency is
        divided by the same quantity on the complete scaffold.

        Parameters
        ----------
        scaffold
            Graph scaffold

        Returns
        -------
        Log-determinant acyclic energy of shape (n_particles,)
        """
        n = scaffold.n_vars
        x = scaffold.adj.to_dense().permute(2, 0, 1) / n  # (n_particles, n, n)
        c = scaffold.complete_adj.to_dense().permute(2, 0, 1) / n  # (1, n, n)
        eye = torch.eye(n, dtype=x.dtype, device=x.device)
        enrg = -(eye - x).slogdet()[1]  # (n_particles,)
        ceil = -(eye - c).slogdet()[1]
        # An acyclic (or empty) scaffold gives det(I - c) == 1, i.e.
        # ceil == 0; clamp to avoid 0 / 0 = NaN (as in SpecNorm)
        return enrg / ceil.clamp(min=EPS)
# ----------------------------- Latent inference -------------------------------
[docs]
class Latent(Module):
    r"""
    Interventional latent module

    Parameters
    ----------
    n_particles
        Number of SVGD particles
    latent_dim
        Dimensionality of the latent variable
    vmap
        Variable index mapping with the parent module
        :class:`~cascade.core.CausalNetwork`
    **kwargs
        Ignored; accepted so subclasses can forward their own arguments
    """

    def __init__(
        self,
        n_particles: int,
        latent_dim: int,
        vmap: torch.LongTensor,
        **kwargs,
    ) -> None:
        super().__init__()
        self.n_particles = n_particles
        self.latent_dim = latent_dim
        self.register_buffer("vmap", vmap)
        # Standard normal prior parameters, kept as buffers so they follow
        # the module across devices
        self.register_buffer("prior_loc", torch.as_tensor(0.0))
        self.register_buffer("prior_scale", torch.as_tensor(1.0))

    def prior(self) -> D.Normal:
        r"""
        Prior distribution of the latent variable
        """
        return D.Normal(loc=self.prior_loc, scale=self.prior_scale)

    def forward(self, r: torch.Tensor) -> D.Normal:
        raise NotImplementedError  # pragma: no cover
[docs]
class NilLatent(Latent):
    r"""
    Nil interventional latent module that always outputs the standard normal

    Parameters
    ----------
    n_particles
        Number of SVGD particles
    latent_dim
        Dimensionality of the latent variable
    vmap
        Variable index mapping with the parent module
        :class:`~cascade.core.CausalNetwork`
    """

    def forward(self, r: torch.Tensor) -> D.Normal:
        # Only the batch size of ``r`` is used
        shape = (self.n_particles, r.size(0), self.latent_dim)
        return D.Normal(r.new_zeros(shape), r.new_ones(shape))
[docs]
class EmbLatent(Latent):
    r"""
    Intervention latent module encoding from fixed embeddings

    Parameters
    ----------
    n_particles
        Number of SVGD particles
    latent_dim
        Dimensionality of the latent variable
    vmap
        Variable index mapping with the parent module
        :class:`~cascade.core.CausalNetwork`
    emb
        Fixed embedding tensor of shape (n_embeddings, emb_dim); required
    """

    def __init__(
        self,
        n_particles: int,
        latent_dim: int,
        vmap: torch.LongTensor,
        emb: torch.Tensor | None = None,
    ) -> None:
        if emb is None:
            raise ValueError("Embedding tensor must be specified")
        super().__init__(n_particles, latent_dim, vmap, emb=emb)
        self.register_buffer("emb", emb.to(torch.get_default_dtype()))
        self.emb_dim = self.emb.size(1)
        self.pool = AttnPool(self.emb_dim)
        # Outputs location and raw scale, hence twice the latent size
        self.linear = MultiLinear(
            in_features=self.emb_dim,
            out_features=2 * self.latent_dim,
            multi_dims=(self.n_particles,),
        )

    def forward(self, r: torch.Tensor) -> D.Normal:
        vi, vj = self.vmap
        # Scale each mapped embedding by the matching entry of ``r``
        hidden = r[..., vi].unsqueeze(-1) * self.emb[vj]  # ([n_particles,] bs, n_vars, emb_dim)
        hidden = self.pool(hidden)  # ([n_particles,] bs, emb_dim)
        hidden = self.linear(hidden)  # (n_particles, bs, latent_dim * 2)
        loc = hidden[..., : self.latent_dim]
        scale = F.softplus(hidden[..., -self.latent_dim :]) + EPS
        return D.Normal(loc, scale)
[docs]
class GCNLatent(Latent):
    r"""
    Intervention latent module encoding from a graph

    Parameters
    ----------
    n_particles
        Number of SVGD particles
    latent_dim
        Dimensionality of the latent variable
    vmap
        Variable index mapping with the parent module
        :class:`~cascade.core.CausalNetwork`
    eidx
        Graph edge index of shape (2, n_edges); required
    ewt
        Graph edge weight of shape (n_edges,); required
    emb_dim
        Dimensionality of the learnable node embedding (defaults to
        ``round(sqrt(n_vars))``)
    n_layers
        Number of graph convolution layers
    """

    INIT_STD: float = 0.01

    def __init__(
        self,
        n_particles: int,
        latent_dim: int,
        vmap: torch.LongTensor,
        eidx: torch.LongTensor = None,
        ewt: torch.FloatTensor = None,
        emb_dim: int = None,
        n_layers: int = 1,
    ) -> None:
        if eidx is None:
            raise ValueError("Edge index tensor must be specified")
        if ewt is None:
            raise ValueError("Edge weight tensor must be specified")
        super().__init__(
            n_particles,
            latent_dim,
            vmap,
            eidx=eidx,
            ewt=ewt,
            emb_dim=emb_dim,
            n_layers=n_layers,
        )
        if (eidx < 0).any():
            raise ValueError("Edge index out of bounds")  # pragma: no cover
        self.register_buffer("eidx", eidx)
        self.register_buffer("ewt", ewt.clone())  # Will normalize in-place
        # Graph nodes cover every edge endpoint and every mapped variable;
        # stored as a Python int
        max_edge = int(self.eidx.max()) if self.eidx.numel() else -1
        max_vmap = int(self.vmap[1].max()) if self.vmap.numel() else -1
        self.n_vars = max(max_edge, max_vmap) + 1
        self.emb_dim = emb_dim or round(sqrt(self.n_vars))
        self.emb = Parameter(torch.empty(self.n_vars, self.emb_dim))
        self.n_layers = n_layers
        self.pool = AttnPool(self.emb_dim)
        self.linear = MultiLinear(
            in_features=self.emb_dim,
            out_features=self.latent_dim * 2,
            multi_dims=(self.n_particles,),
        )
        self.normalize_edges()
        self.reset_parameters()

    def vertex_degrees(self, direction: str) -> torch.Tensor:
        r"""
        Weighted vertex degrees of shape (n_vars,)

        Parameters
        ----------
        direction
            ``"in"``, ``"out"`` or ``"both"`` (self-loops counted once)
        """
        if direction not in ("in", "out", "both"):
            raise ValueError("Unrecognized direction")
        degree = self.ewt.new_zeros(self.n_vars)
        if direction in ("in", "both"):
            degree.scatter_add_(0, self.eidx[1], self.ewt)
        if direction in ("out", "both"):
            degree.scatter_add_(0, self.eidx[0], self.ewt)
        if direction == "both":
            # Self-loops were added as both in- and out-edges; count them once
            loop_mask = self.eidx[0] == self.eidx[1]
            degree.scatter_add_(0, self.eidx[0, loop_mask], -self.ewt[loop_mask])
        return degree

    def normalize_edges(self, method: str = "keepvar") -> None:
        r"""
        Normalize edge weights in place by vertex degrees

        Parameters
        ----------
        method
            ``"in"``: divide by target in-degree; ``"out"``: divide by source
            out-degree; ``"sym"``: divide by the square roots of both;
            ``"keepvar"``: divide by the square root of target in-degree.
            Zero-degree vertices contribute a zero factor.
        """
        if method not in ("in", "out", "sym", "keepvar"):
            raise ValueError("Unrecognized method")
        enorm = self.ewt
        if method in ("in", "keepvar", "sym"):
            in_degrees = self.vertex_degrees("in")
            in_norm = in_degrees.pow(-1 if method == "in" else -0.5)
            in_norm[in_norm.isinf()] = 0
            enorm = enorm * in_norm[self.eidx[1]]
        if method in ("out", "sym"):
            out_degrees = self.vertex_degrees("out")
            out_norm = out_degrees.pow(-1 if method == "out" else -0.5)
            out_norm[out_norm.isinf()] = 0
            enorm = enorm * out_norm[self.eidx[0]]
        self.ewt.copy_(enorm)

    @internal
    def reset_parameters(self) -> None:
        init.normal_(self.emb, std=self.INIT_STD)

    def forward(self, r: torch.Tensor) -> D.Normal:
        sidx, tidx = self.eidx
        emb = self.emb
        if not self.frozen:
            emb.data.renorm_(2, 0, 1)
        # Expand the target index once, outside the loop: re-expanding an
        # already expanded index on every layer adds a dimension and made
        # ``n_layers > 1`` fail
        tidx_expanded = tidx.unsqueeze(1).expand(-1, emb.size(1))  # (n_edges, emb_dim)
        for _ in range(self.n_layers):
            message = emb[sidx] * self.ewt.unsqueeze(1)  # (n_edges, emb_dim)
            emb = torch.zeros_like(emb).scatter_add_(0, tidx_expanded, message)
            emb = emb.renorm(2, 0, 1)
        vi, vj = self.vmap
        ptr = r[:, vi].unsqueeze(-1) * emb[vj]  # (bs, n_vars, emb_dim)
        ptr = self.pool(ptr)  # (bs, emb_dim)
        ptr = self.linear(ptr)  # (n_particles, bs, latent_dim * 2)
        mu = ptr[..., : self.latent_dim]
        sigma = F.softplus(ptr[..., -self.latent_dim :]) + EPS
        return D.Normal(mu, sigma)

    def _decay_params(self) -> Generator[Parameter, None, None]:
        yield self.emb
        yield from super()._decay_params()
# ----------------------------- Causal likelihood ------------------------------
[docs]
class Likelihood(Module):
    r"""
    Abstract class for causal distributions

    Parameters
    ----------
    n_vars
        Number of variables
    """

    def __init__(self, n_vars: int) -> None:
        super().__init__()
        self.n_vars = n_vars

    def set_empirical(self, AnnData: AnnData) -> None:
        r"""
        Set data-dependent state from an AnnData object (abstract)
        """
        # NOTE: the parameter keeps its historical name ``AnnData`` for
        # interface compatibility; concrete subclasses name it ``adata``
        raise NotImplementedError  # pragma: no cover

    def tone(self, x: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        r"""
        Transform values ``x`` given ``l`` (abstract)
        """
        raise NotImplementedError  # pragma: no cover

    def forward(
        self,
        mean: torch.Tensor,
        disp: torch.Tensor,
        l: torch.Tensor,
        oidx: torch.LongTensor | None = None,
    ) -> D.Distribution:
        r"""
        Build the output distribution from mean and dispersion (abstract)
        """
        raise NotImplementedError  # pragma: no cover

    @staticmethod
    def get_mean(est: D.Distribution) -> torch.Tensor:
        r"""
        Extract the mean parameter of a distribution from :meth:`forward`
        """
        raise NotImplementedError  # pragma: no cover

    @staticmethod
    def get_disp(est: D.Distribution) -> torch.Tensor:
        r"""
        Extract the dispersion parameter of a distribution from :meth:`forward`
        """
        raise NotImplementedError  # pragma: no cover

    def log_prior(self, est: D.Distribution) -> torch.Tensor:
        r"""
        Log prior probability of an estimated distribution (abstract)
        """
        raise NotImplementedError  # pragma: no cover
[docs]
class Normal(Likelihood):
    r"""
    Normal causal distribution

    The scale of the output distribution gets a Gamma prior with rate
    ``PRIOR_RATE``. Its shape is initialized so that the prior mode is 1 and
    is moved to the per-variable empirical standard deviation by
    :meth:`set_empirical`.

    Parameters
    ----------
    n_vars
        Number of variables
    """

    PRIOR_RATE: float = 0.01

    def __init__(self, n_vars: int) -> None:
        super().__init__(n_vars)
        # Gamma(shape, rate) has mode (shape - 1) / rate, i.e. 1 here
        self.register_buffer(
            "prior_shape", torch.ones(self.n_vars) * self.PRIOR_RATE + 1
        )
        self.register_buffer("prior_rate", torch.ones(self.n_vars) * self.PRIOR_RATE)
        self.batch_norm = BatchNorm1d(self.n_vars, eps=1, affine=False)

    def set_empirical(self, adata: AnnData) -> None:
        r"""
        Move the scale prior mode to each variable's empirical standard
        deviation (population std, ``ddof=0``)

        Parameters
        ----------
        adata
            Dataset to compute standard deviations from
        """
        X = _get_X(adata)
        if issparse(X):
            mean = X.mean(axis=0).A1
            # E[X^2] - E[X]^2 can dip slightly below zero through floating
            # point cancellation (e.g. constant columns); clip to avoid NaN
            var = np.clip(X.power(2).mean(axis=0).A1 - np.square(mean), 0, None)
            std = np.sqrt(var)
        else:
            std = X.std(axis=0)
        std = torch.as_tensor(std)
        self.prior_shape.copy_(std * self.prior_rate + 1)  # Mode at std

    def tone(self, x: torch.Tensor, l: torch.Tensor | None = None) -> torch.Tensor:
        r"""
        Batch-normalize ``x`` per variable (``l`` is ignored)

        Parameters
        ----------
        x
            Either ``(bs, n_vars)`` or ``(n_particles, bs, n_vars)``
        l
            Unused; accepted for interface compatibility

        Returns
        -------
        Normalized tensor with the same shape as ``x``
        """
        if x.dim() == 2:  # (bs, n_vars)
            return self.batch_norm(x)
        # (n_particles, bs, n_vars) -> (bs, n_vars, n_particles) and back
        return self.batch_norm(x.permute(1, 2, 0)).permute(2, 0, 1)

    def forward(
        self,
        mean: torch.Tensor,
        disp: torch.Tensor,
        l: torch.Tensor,
        oidx: torch.LongTensor | None = None,
    ) -> D.Normal:
        r"""
        Normal distribution with location ``mean`` and scale
        ``softplus(disp) + EPS`` (``l`` and ``oidx`` are unused)
        """
        return D.Normal(mean, F.softplus(disp) + EPS)

    @staticmethod
    def get_mean(est: D.Normal) -> torch.Tensor:
        r"""Location of ``est``"""
        return est.loc

    @staticmethod
    def get_disp(est: D.Normal) -> torch.Tensor:
        r"""Scale of ``est``"""
        return est.scale

    def log_prior(self, est: D.Normal) -> torch.Tensor:
        r"""Gamma log prior density of the scale of ``est``"""
        return D.Gamma(self.prior_shape, self.prior_rate).log_prob(self.get_disp(est))
[docs]
class NegBin(Likelihood):
    r"""
    Negative binomial causal distribution

    The NB mean is ``l * sigmoid(mean) * exp(log_cap)`` and the inverse
    dispersion is ``theta = exp(disp) + 1``. ``theta`` gets a Gamma prior
    whose mode follows a linear fit of ``theta`` against ``log1p(mean)``
    obtained in :meth:`set_empirical`.

    Parameters
    ----------
    n_vars
        Number of variables
    """

    NORM_TARGET: float = 1e4
    CAP_RATE: float = 0.75
    PRIOR_RATE: float = 0.05

    def __init__(self, n_vars: int) -> None:
        super().__init__(n_vars)
        # Linear model of prior theta vs log1p(mean); set by set_empirical
        # or restored by set_extra_state
        self.theta_coef: float | None = None
        self.theta_intercept: float | None = None
        self.register_buffer("log_cap", torch.zeros(self.n_vars))
        self.register_buffer("prior_rate", torch.as_tensor(self.PRIOR_RATE))

    def set_empirical(self, adata: AnnData) -> None:
        r"""
        Fit the per-variable mean cap and the theta prior regression

        Parameters
        ----------
        adata
            Dataset with configured sizes

        Raises
        ------
        ValueError
            If sizes are not configured
        """
        X = _get_X(adata)
        size = _get_size(adata)
        if size.size == 0:
            raise ValueError("Size not configured")
        cap = (X / size).max(axis=0)
        cap = torch.as_tensor(cap.toarray().ravel() if issparse(cap) else cap)
        self.log_cap.copy_(self.CAP_RATE * cap.clamp(min=EPS, max=1.0).log())

        # Group cells by size; duplicates="drop" avoids "Bin edges must be
        # unique" when sizes tie across quantiles
        bins = pd.qcut(size.ravel(), 5, duplicates="drop")
        mean_list, theta_list = [], []
        for i in bins.categories:
            X_ = X[bins == i]
            if issparse(X_):
                mean = X_.mean(axis=0).A1
                var = X_.power(2).mean(axis=0).A1 - np.square(mean)
            else:
                mean = X_.mean(axis=0)
                var = X_.var(axis=0)
            mean_list.append(mean)
            # Method-of-moments theta = mean^2 / (var - mean).
            # NOTE(review): under-dispersed variables (var < mean) give a
            # negative ratio that is clipped to 0 rather than the upper
            # bound — confirm this is intended
            theta_list.append(
                np.clip(np.square(mean) / (var - mean + EPS), a_min=0, a_max=200)
            )
        mean = np.concatenate(mean_list)
        theta = np.concatenate(theta_list)
        df = pd.DataFrame({"mean": mean, "mar_theta": theta})
        df["log1p_mean"] = np.log1p(df["mean"])
        # Group genes by expression; many zero-mean genes can tie quantile
        # edges, hence duplicates="drop"
        df["bin"] = pd.qcut(df["log1p_mean"], 5, duplicates="drop")
        df["res_theta"] = df.groupby("bin", observed=True)[["mar_theta"]].transform(
            lambda x: x.quantile(0.9)
        )  # Use smaller marginal variances as residual variance estimates
        lm = LinearRegression().fit(df[["log1p_mean"]], df["res_theta"])
        self.theta_coef = lm.coef_[0].item()
        self.theta_intercept = lm.intercept_.item()
        logger.info(f"Using theta coefficient = {self.theta_coef:.3f}")
        logger.info(f"Using theta intercept = {self.theta_intercept:.3f}")

    def tone(self, x: torch.Tensor, l: torch.Tensor | None = None) -> torch.Tensor:
        r"""
        Size-normalize counts and apply ``log1p``

        In training mode, counts are first resampled to the median size with
        a multinomial draw and then rescaled to the normalization target.

        Parameters
        ----------
        x
            Count tensor with variables along the last dimension
        l
            Per-observation size. If omitted, the row sums of ``x`` are
            used, and the target is scaled by ``n_vars / 2e4``.

        Returns
        -------
        ``log1p`` of normalized counts, same shape as ``x``
        """
        if l is None:
            l = x.sum(dim=-1, keepdim=True)
            norm_target = round(
                self.NORM_TARGET * x.size(-1) / 2e4
            )  # Roughly 20k genes -> NORM_TARGET
        else:
            norm_target = self.NORM_TARGET
        if self.training:
            sample_target = round(l.quantile(0.5).item())
            # The pad column holds the size not covered by ``x``. Clamp so a
            # tiny negative from rounding cannot turn into NaN logits.
            pad = (l - x.sum(dim=-1, keepdim=True)).clamp(min=0)
            x_pad = torch.cat([x, pad], dim=-1)
            logits = x_pad.log() - l.log()
            x_samp = D.Multinomial(
                sample_target, logits=logits, validate_args=False
            ).sample()[
                ..., :-1
            ]  # Suppress occasional rounding errors
            x_norm = x_samp * (norm_target / sample_target)
        else:
            x_norm = x * (norm_target / l)
        return x_norm.log1p()

    def forward(
        self,
        mean: torch.Tensor,
        disp: torch.Tensor,
        l: torch.Tensor,
        oidx: torch.LongTensor | None = None,
    ) -> D.NegativeBinomial:
        r"""
        Negative binomial with mean ``l * sigmoid(mean) * exp(log_cap)`` and
        total count ``theta = exp(disp) + 1``

        Parameters
        ----------
        mean
            Unconstrained mean parameter
        disp
            Unconstrained dispersion parameter
        l
            Per-observation size
        oidx
            Optional index selecting the variables of ``log_cap`` to use

        Returns
        -------
        Negative binomial distribution
        """
        log_cap = self.log_cap if oidx is None else self.log_cap[oidx]
        # log(sigmoid(mean)) == -softplus(-mean)
        log_mu = l.log() - F.softplus(-mean) + log_cap
        theta = disp.exp() + 1
        return D.NegativeBinomial(theta, logits=log_mu - theta.log())

    @staticmethod
    def get_mean(est: D.NegativeBinomial) -> torch.Tensor:
        r"""Mean of ``est``: ``total_count * exp(logits)``"""
        return est.logits.exp() * est.total_count

    @staticmethod
    def get_disp(est: D.NegativeBinomial) -> torch.Tensor:
        r"""Inverse dispersion (``total_count``) of ``est``"""
        return est.total_count

    def log_prior(self, est: D.NegativeBinomial) -> torch.Tensor:
        r"""
        Gamma log prior density of ``theta``, with mode given by the fitted
        linear model of ``log1p(mean)`` (clamped at 1)

        Requires ``theta_coef`` / ``theta_intercept`` to have been set by
        :meth:`set_empirical` or by loading a state dict.
        """
        mean = self.get_mean(est).detach()
        prior_theta = (self.theta_coef * mean.log1p() + self.theta_intercept).clamp(
            min=1.0
        )
        return D.Gamma(
            self.prior_rate * prior_theta + 1,
            self.prior_rate,
        ).log_prob(self.get_disp(est))

    @internal
    def get_extra_state(self) -> dict[str, Any]:
        return {
            "theta_coef": self.theta_coef,
            "theta_intercept": self.theta_intercept,
            **super().get_extra_state(),
        }

    @internal
    def set_extra_state(self, state: dict[str, Any]) -> None:
        # Work on a copy so the caller's state dict is not consumed, which
        # would make loading the same state dict a second time fail
        state = dict(state)
        self.theta_coef = state.pop("theta_coef")
        self.theta_intercept = state.pop("theta_intercept")
        super().set_extra_state(state)
# -------------------------------- SVGD kernel ---------------------------------
[docs]
class Kernel(Module):
    r"""
    Abstract class for kernels

    Concrete kernels must override :meth:`forward`, which evaluates the
    kernel between ``x`` and ``y``.
    """

    def __init__(self) -> None:
        super().__init__()  # No state of its own

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Evaluate the kernel between ``x`` and ``y``"""
        raise NotImplementedError  # pragma: no cover
[docs]
class KroneckerDelta(Kernel):
    r"""
    Kronecker delta kernel

    Returns an ``(x.size(0), y.size(0))`` identity-like matrix: 1 where the
    row indices coincide, 0 elsewhere.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        n_rows = x.size(0)
        n_cols = y.size(0)
        # requires_grad=True, presumably so downstream autograd calls on the
        # (constant) kernel do not error — TODO confirm
        return torch.eye(
            n_rows,
            n_cols,
            dtype=x.dtype,
            device=x.device,
            requires_grad=True,
        )
[docs]
class RBF(Kernel):
    r"""
    Radial basis function kernel

    Both inputs are flattened to 2-D (keeping the leading dimension) and
    passed to ``multi_rbf``.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # OK: Verify with the original SVGD paper
        x_flat, y_flat = (t.flatten(start_dim=1) for t in (x, y))
        return multi_rbf(x_flat, y_flat)
# ---------------------------- Intervention design -----------------------------
[docs]
class IntervDesign(Module):
    r"""
    Intervention design module

    Holds a learnable categorical distribution (``logits``) over all
    combinations of up to ``k`` distinct candidate variables.

    Parameters
    ----------
    n_vars
        Number of variables
    k
        Maximal combinatorial order to consider
    design_scale_bias
        Whether to optimize the intervention scale and bias
    mask
        Boolean mask that marks variables in the design candidate pool
    interv_scale
        Intervention scale tensor trained in the discover phase
    interv_bias
        Intervention bias tensor trained in the discover phase
    target_weight
        Variable weight when computing target deviation

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1
    """

    def __init__(
        self,
        n_vars: int,
        k: int,
        design_scale_bias: bool,
        mask: torch.BoolTensor,
        interv_scale: torch.Tensor,
        interv_bias: torch.Tensor,
        target_weight: torch.Tensor,
    ) -> None:
        super().__init__()
        if k < 1:
            raise ValueError("k must be a positive integer")
        self.n_vars = n_vars
        self.k = k
        self.design_scale_bias = design_scale_bias
        elem = torch.cat(
            [
                torch.arange(n_vars)[mask],
                torch.as_tensor([n_vars]),
            ]
        )  # The last element indicates no intervention
        comb = torch.combinations(elem, k, with_replacement=True)
        # All elements except for the last one must be unique. ``elem`` is
        # strictly increasing and combinations are emitted with
        # non-decreasing positions (as in itertools), so a repeated real
        # variable always shows up as an adjacent equal pair.
        repeated = (comb[:, 1:] == comb[:, :-1]) & (comb[:, :-1] < n_vars)
        comb = comb[~repeated.any(dim=1)]
        self.register_buffer("mask", mask)
        self.register_buffer("comb", comb)
        self.register_buffer("interv_scale", interv_scale.detach())
        self.register_buffer("interv_bias", interv_bias.detach())
        self.register_buffer("target_weight", target_weight)
        self.logits = Parameter(torch.empty(self.comb.size(0)))
        self.design_scale = Parameter(torch.empty_like(self.interv_scale))
        self.design_bias = Parameter(torch.empty_like(self.interv_bias))
        self.reset_parameters()

    @internal
    def reset_parameters(self) -> None:
        init.zeros_(self.logits)
        self.design_scale.data.copy_(self.interv_scale)
        self.design_bias.data.copy_(self.interv_bias)

    def simplex2regime(self, simplex: torch.Tensor) -> torch.Tensor:
        r"""
        Convert weights over combinations into per-variable weights

        Parameters
        ----------
        simplex
            Weights of shape ``(bs, n_comb)``

        Returns
        -------
        Tensor of shape ``(bs, n_vars)`` where each entry is the total weight
        of the combinations containing that variable (the "no intervention"
        column is dropped)
        """
        bs = simplex.size(0)
        comb = self.comb.expand(bs, -1, -1)  # (bs, n_comb, k)
        simplex = simplex.unsqueeze(-1).expand_as(comb)  # (bs, n_comb, k)
        return (
            simplex.new_zeros(bs, self.n_vars + 1, self.k)
            .scatter_add_(1, comb, simplex)
            .sum(dim=-1)
        )[:, :-1]

    def rsample(self, bs: int) -> torch.Tensor:
        r"""
        Draw ``bs`` hard (straight-through) Gumbel-softmax samples over
        combinations and convert them with :meth:`simplex2regime`
        """
        simplex = F.gumbel_softmax(self.logits.expand(bs, -1), hard=True)
        return self.simplex2regime(simplex)

    def loss(self, x_est: torch.Tensor, x_tgt: torch.Tensor) -> torch.Tensor:
        r"""
        ``target_weight``-weighted mean squared error along the last
        dimension, with ``x_tgt`` broadcast to ``x_est``
        """
        return mean_squared_error(
            x_est, x_tgt.expand_as(x_est), dim=-1, weight=self.target_weight
        )  # (n_particles, bs)

    @property
    def scale(self) -> torch.Tensor:
        r"""Design scale if optimized, else the fixed intervention scale"""
        return self.design_scale if self.design_scale_bias else self.interv_scale

    @property
    def bias(self) -> torch.Tensor:
        r"""Design bias if optimized, else the fixed intervention bias"""
        return self.design_bias if self.design_scale_bias else self.interv_bias

    @cached_property
    def comb_lists(self) -> list[list[int]]:
        r"""
        Combinations as lists of variable indices with the "no intervention"
        element removed; computed once on first access
        """
        return [row[row < self.n_vars].tolist() for row in self.comb.unbind()]

    @internal
    def get_extra_state(self) -> dict[str, Any]:
        return {
            "n_vars": self.n_vars,
            "k": self.k,
            "design_scale_bias": self.design_scale_bias,
            **super().get_extra_state(),
        }

    @internal
    def set_extra_state(self, state: dict[str, Any]) -> None:
        # Copy so the caller's state dict is not consumed by the pops
        state = dict(state)
        self.n_vars = state.pop("n_vars")
        self.k = state.pop("k")
        self.design_scale_bias = state.pop("design_scale_bias")
        super().set_extra_state(state)

    def save(self, fname: os.PathLike) -> None:
        r"""
        Save the design module to file

        Parameters
        ----------
        fname
            Path to save the design module (.pt); parent directories are
            created if missing
        """
        fname = Path(fname)
        fname.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "__version__": __version__,
                "state_dict": self.state_dict(),
            },
            fname,
        )

    @classmethod
    def load(cls, fname: os.PathLike) -> "IntervDesign":
        r"""
        Load design module from file

        Parameters
        ----------
        fname
            Path to load the design module (.pt)

        Returns
        -------
        Loaded design module
        """
        loaded = torch.load(fname, weights_only=True)
        version = loaded.pop("__version__", "unknown")
        if version != __version__:
            logger.warning(  # pragma: no cover
                "Loaded module version {} differs from current version {}.",
                version,
                __version__,
            )
        state_dict = loaded.pop("state_dict")
        extra_state = state_dict["_extra_state"]
        mod = cls(
            extra_state["n_vars"],
            extra_state["k"],
            extra_state["design_scale_bias"],
            state_dict["mask"],
            state_dict["interv_scale"],
            state_dict["interv_bias"],
            state_dict["target_weight"],
        )
        mod.load_state_dict(state_dict)
        return mod