r"""
Core PyTorch Lightning module and training callbacks for the CASCADE model
"""
import os
from collections.abc import Iterable
from enum import IntEnum
from itertools import zip_longest
from pathlib import Path
from typing import Any
import torch
import torch.distributions as D
from loguru import logger
from pytorch_lightning import LightningModule, Trainer, callbacks
from torch.nn import Parameter, init
from . import nn
from .data import EPS
from .nn import (
AcycPrior,
Func,
IntervDesign,
Kernel,
Latent,
Likelihood,
Module,
MultiLinear,
Scaffold,
SparsePrior,
copy_like,
)
from .typing import Kws
from .utils import console, internal
[docs]
class FitStage(IntEnum):
    r"""
    Model fitting stage

    The members are ordered: the ``fit_stage`` setter of
    :class:`CausalNetwork` compares stages with ``>=``, so every later stage
    also applies the freezing performed for the earlier ones.

    Attributes
    ----------
    discover
        Causal discovery stage
    tune
        Self-reconstruction model tuning stage (after graph acyclification)
    ctfact
        Counterfactual model tuning stage (after graph acyclification)
    design
        Intervention design stage
    """

    discover = 0
    tune = 1
    ctfact = 2
    design = 3
[docs]
class PredictMode(IntEnum):
    r"""
    Model prediction mode

    Modes at or above ``dsgnerr`` go through the cascade pass and therefore
    need the topological generations of the causal graph.

    Attributes
    ----------
    recon
        Predict self-reconstruction
    jac
        Compute the Jacobian matrix
    explain
        Explain counterfactual prediction by components
    dsgnerr
        Compute the design loss between the cascaded prediction and the
        counterfactual target
    ctmean
        Predict counterfactual state with the mean
    ctsamp
        Predict counterfactual state with random sampling
    """

    recon = 0
    jac = 1
    explain = 2
    dsgnerr = 3
    ctmean = 4
    ctsamp = 5
[docs]
class LogAdj(IntEnum):
    r"""
    Logging mode of the adjacency matrix in tensorboard

    Attributes
    ----------
    none
        Disable adjacency matrix logging
    mean
        Only log the mean adjacency matrix across SVGD particles
    particles
        Log adjacency matrices of individual SVGD particles
    both
        Log both the mean and individual adjacency matrices
    """

    none = 0
    mean = 1
    particles = 2
    both = 3
[docs]
class CausalNetwork(LightningModule, Module):
r"""
Causal discovery neural network
Parameters
----------
n_vars
Number of variables to model
n_particles
Number of SVGD particles
n_covariates
Dimension of covariates
n_layers
Number of MLP layers in the structural equations
hidden_dim
MLP hidden layer dimension in the structural equations
latent_dim
Dimension of the latent variable
dropout
Dropout rate
beta
KL weight of the latent variable
scaffold_mod
Scaffold graph module, must be one of {"Edgewise", "Bilinear"}
sparse_mod
Sparse prior module, must be one of {"L1", "ScaleFree"}
acyc_mod
Acyclic prior module, must be one of {"TrExp", "SpecNorm", "LogDet"}
latent_mod
Latent module, must be one of {"NilLatent", "EmbLatent", "GCNLatent"}
lik_mod
Causal likelihood module, must be one of {"Normal", "NegBin"}
kernel_mod
SVGD kernel module, must be one of {"KroneckerDelta", "RBF"}
scaffold_kws
Keyword arguments to the scaffold graph module, see
:class:`~cascade.nn.Edgewise` or :class:`~cascade.nn.Bilinear` for
details
sparse_kws
Keyword arguments to the sparse prior module, see
:class:`~cascade.nn.L1` or :class:`~cascade.nn.ScaleFree` for details
acyc_kws
Keyword arguments to the acyclic prior module, see
:class:`~cascade.nn.TrExp`, :class:`~cascade.nn.SpecNorm`, or
:class:`~cascade.nn.LogDet` for details
latent_kws
Keyword arguments to the latent module, see
:class:`~cascade.nn.NilLatent`, :class:`~cascade.nn.EmbLatent`, or
:class:`~cascade.nn.GCNLatent` for details
lik_kws
Keyword arguments to the causal likelihood module, see
:class:`~cascade.nn.Normal` or :class:`~cascade.nn.NegBin` for details
kernel_kws
Keyword arguments to the SVGD kernel module, see
:class:`~cascade.nn.KroneckerDelta` or :class:`~cascade.nn.RBF` for
details
design
Optional intervention design module, see
:class:`~cascade.nn.IntervDesign` for details
"""
# Decay factor of the exponential moving averages kept for the likelihood,
# sparse, acyclic and kernel gradients w.r.t. the scaffold logits; the very
# first update uses a factor of 0 so the average starts at the raw gradient
EXP_AVG: float = 0.5
def __init__(
    self,
    n_vars: int,
    n_particles: int,
    n_covariates: int,
    n_layers: int,
    hidden_dim: int,
    latent_dim: int,
    dropout: float,
    beta: float,
    scaffold_mod: str,
    sparse_mod: str,
    acyc_mod: str,
    latent_mod: str,
    lik_mod: str,
    kernel_mod: str,
    scaffold_kws: Kws = None,
    sparse_kws: Kws = None,
    acyc_kws: Kws = None,
    latent_kws: Kws = None,
    lik_kws: Kws = None,
    kernel_kws: Kws = None,
    design: IntervDesign | None = None,
) -> None:
    # Parameters are documented in the class docstring. The ``*_mod``
    # arguments are class names resolved via ``_get_mod``; the matching
    # ``*_kws`` are forwarded to those classes as keyword arguments.
    super().__init__()
    # The design module is not a plain hyperparameter, so it is excluded
    self.save_hyperparameters(ignore="design")
    # Memoizes the results of compute_prior() and compute_kernel()
    self.cache = {}
    self.n_vars = n_vars
    self.n_particles = n_particles
    self.n_covariates = n_covariates
    self.n_layers = n_layers
    self.hidden_dim = hidden_dim
    self.latent_dim = latent_dim
    self.dropout = dropout
    self.beta = beta
    # Weights of the sparse prior (lam), acyclic prior (alpha) and kernel
    # gradient (gamma); they start at zero and are raised by
    # DiscoverScheduler during causal discovery
    self.lam = 0.0
    self.alpha = 0.0
    self.gamma = 0.0
    # Submodules
    self.scaffold: Scaffold = self._get_mod(
        mod=scaffold_mod,
        mod_type=Scaffold,
        n_vars=n_vars,
        n_particles=n_particles,
        **(scaffold_kws or {}),
    )
    self.sparse: SparsePrior = self._get_mod(
        mod=sparse_mod,
        mod_type=SparsePrior,
        n_vars=n_vars,
        n_particles=n_particles,
        **(sparse_kws or {}),
    )
    self.acyc: AcycPrior = self._get_mod(
        mod=acyc_mod,
        mod_type=AcycPrior,
        n_vars=n_vars,
        n_particles=n_particles,
        **(acyc_kws or {}),
    )
    self.kernel: Kernel = self._get_mod(
        mod=kernel_mod,
        mod_type=Kernel,
        **(kernel_kws or {}),
    )
    self.latent: Latent = self._get_mod(
        mod=latent_mod,
        mod_type=Latent,
        n_particles=n_particles,
        latent_dim=latent_dim,
        **(latent_kws or {}),
    )
    self.lik: Likelihood = self._get_mod(
        mod=lik_mod,
        mod_type=Likelihood,
        n_vars=n_vars,
        **(lik_kws or {}),
    )
    # Structural equations: one function per (particle, variable) pair that
    # maps parent values plus latent/covariate features to two outputs,
    # which forward() unpacks as (mean, disp)
    self.func: Func = self._get_mod(
        mod="Func",
        mod_type=Func,
        in_features=self.scaffold.max_indegree,
        cov_features=self.latent_dim + self.n_covariates,
        out_features=2,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        multi_dims=(n_particles, n_vars),
        dropout=dropout,
    )
    self.func.layers[-1].weight.data[..., 1, :].fill_(0.0)  # Stabilize disp
    self.design: IntervDesign | None = design
    # Parameters
    self.interv_scale = Parameter(torch.empty(n_particles, n_vars))
    self.interv_bias = Parameter(torch.empty(n_particles, n_vars))
    # Buffers: running averages of the gradients w.r.t. the scaffold logits,
    # one row per particle and one column per scaffold edge
    n_edges = self.scaffold.n_edges
    self.register_buffer("lik_grad_avg", torch.zeros(n_particles, n_edges))
    self.register_buffer("sparse_grad_avg", torch.zeros(n_particles, n_edges))
    self.register_buffer("acyc_grad_avg", torch.zeros(n_particles, n_edges))
    self.register_buffer("kernel_grad_avg", torch.zeros(n_particles, n_edges))
    # Optimizer stepping is done by hand in training_step()
    self.automatic_optimization = False
    self.reset_parameters()
    self.reset_properties()
@staticmethod
def _get_mod(mod: str, mod_type: type, *args, **kwargs) -> Module:
    r"""
    Instantiate a module class looked up by name in the ``nn`` subpackage

    Parameters
    ----------
    mod
        Name of the class to look up
    mod_type
        Base class the looked-up class must derive from
    *args
        Positional arguments passed to the class constructor
    **kwargs
        Keyword arguments passed to the class constructor

    Returns
    -------
    Instantiated module

    Raises
    ------
    TypeError
        If ``mod`` does not name a class deriving from ``mod_type``
    """
    # A missing attribute or a non-class attribute (e.g. a helper function)
    # must end up in the same "unrecognized module" error as a class of the
    # wrong family, rather than a bare AttributeError or an unrelated
    # TypeError raised by issubclass()
    cls = getattr(nn, mod, None)
    if isinstance(cls, type) and issubclass(cls, mod_type):
        return cls(*args, **kwargs)
    raise TypeError(f"Unrecognized {mod_type.__name__} module: {mod!r}")
[docs]
def reset_parameters(self) -> None:
    r"""
    Reset the interventional scale and bias parameters to all zeros
    """
    for param in (self.interv_scale, self.interv_bias):
        init.zeros_(param)
[docs]
def reset_properties(self) -> None:
    r"""
    Reset every transient run-time property to ``None``

    Assignments go through the regular attribute protocol, so properties
    with a setter (e.g. ``fit_stage``, ``predict_mode``, ``prefit``,
    ``fixed_vars``) still run their setter logic.
    """
    transient = (
        # Optimization
        "opt",
        "lr",
        "weight_decay",
        "accumulate_grad_batches",
        # Model flags
        "fit_stage",
        "predict_mode",
        "prefit",
        # Prediction & logging
        "fixed_vars",
        "ablate_latent",
        "ablate_interv",
        "ablate_graph",
        "log_adj",
        "vars",
    )
    for name in transient:
        setattr(self, name, None)
[docs]
def set_design(
    self,
    mask: torch.BoolTensor,
    k: int,
    design_scale_bias: bool,
    target_weight: torch.Tensor,
) -> None:
    r"""
    Build a fresh intervention design module and attach it to the model

    The module shares the model's interventional scale and bias parameters.

    Parameters
    ----------
    mask
        Boolean mask that marks variables in the design candidate pool
    k
        Maximal combination order to consider
    design_scale_bias
        Whether to optimize interventional scale and bias as well
    target_weight
        Variable weights for computing target deviation
    """
    design_kws = {
        "n_vars": self.n_vars,
        "k": k,
        "design_scale_bias": design_scale_bias,
        "mask": mask,
        "interv_scale": self.interv_scale,
        "interv_bias": self.interv_bias,
        "target_weight": target_weight,
    }
    self.design = self._get_mod("IntervDesign", IntervDesign, **design_kws)
@property
def fit_stage(self) -> FitStage | None:
    r"""
    Fitting stage, see :class:`FitStage` for details
    """
    return self._fit_stage

@fit_stage.setter
def fit_stage(self, fit_stage: FitStage | None) -> None:
    # Entering a fitting stage clears the prediction mode and re-derives
    # which parameters are trainable; stages are cumulative (see the ``>=``
    # comparisons), so later stages freeze everything earlier ones froze.
    if fit_stage is not None:
        self._predict_mode = None
        self.unfreeze_params()
        if fit_stage >= FitStage.tune:
            # The graph-related modules and the likelihood stay fixed once
            # discovery is over
            self.scaffold.freeze_params()
            self.sparse.freeze_params()
            self.acyc.freeze_params()
            self.kernel.freeze_params()
            self.lik.freeze_params()
        if fit_stage >= FitStage.ctfact:
            topo_gens = self.scaffold.topo_gens()
            logger.info(
                f"Number of topological generations: "
                f"{[len(gens) for gens in topo_gens]}"
            )
            # The last generation is dropped here: cascade() covers it with
            # its final full forward pass
            self._topo_gens = [gens[:-1] for gens in topo_gens]
        if fit_stage >= FitStage.design:
            # Only the design module is optimized in the design stage
            self.latent.freeze_params()
            self.func.freeze_params()
            self.interv_scale.requires_grad_(False)
            self.interv_bias.requires_grad_(False)
            if self.design is None:
                raise ValueError("Design module not initialized")
        elif self.design is not None:
            # A leftover design module is discarded in non-design stages
            self.design = None
    else:
        self._topo_gens = None
    self._fit_stage = fit_stage
@property
def predict_mode(self) -> PredictMode | None:
    r"""
    Prediction mode, see :class:`PredictMode` for details
    """
    return self._predict_mode

@predict_mode.setter
def predict_mode(self, predict_mode: PredictMode | None) -> None:
    # Entering a prediction mode clears the fitting stage
    if predict_mode is not None:
        self._fit_stage = None
        if predict_mode >= PredictMode.dsgnerr:
            # Modes that use cascade() need the topological generations
            topo_gens = self.scaffold.topo_gens()
            logger.info(
                f"Number of topological generations: "
                f"{[len(gens) for gens in topo_gens]}"
            )
            # Drop the last generation (handled by the final pass of
            # cascade()) and remove fixed variables from every generation
            # so they are never overwritten by predictions.
            # NOTE(review): this reads ``fixed_vars`` at assignment time, so
            # it presumably must be set before ``predict_mode`` — confirm
            self._topo_gens = [
                [gen[~torch.isin(gen, self.fixed_vars)] for gen in gens[:-1]]
                for gens in topo_gens
            ]
    else:
        self._topo_gens = None
    self._predict_mode = predict_mode
@property
def prefit(self) -> bool:
    r"""
    Whether to run prefit on the covariates only

    Always ``False`` outside of the discover stage, whatever value was set.
    """
    return self._prefit if self.fit_stage == FitStage.discover else False

@prefit.setter
def prefit(self, prefit: bool) -> None:
    # Prefitting relies on covariates, so it cannot be switched on without
    # them; switching it off is always allowed
    if not prefit or self.n_covariates:
        self._prefit = prefit
        return
    raise ValueError("Cannot prefit without covariates")
@property
def topo_gens(self) -> list[list[torch.LongTensor]]:
    r"""
    Topological generations of the causal graph

    Raises
    ------
    AttributeError
        If no topological generations have been computed
    """
    gens = self._topo_gens
    if gens is not None:
        return gens
    raise AttributeError
@property
def fixed_vars(self) -> torch.LongTensor:
r"""
Fixed variables during counterfactual prediction
"""
return self._fixed_vars
@fixed_vars.setter
def fixed_vars(self, fixed_vars: torch.LongTensor | None) -> None:
if fixed_vars is None:
self._fixed_vars = torch.empty(0, dtype=torch.long)
return
if fixed_vars.min() < 0 or fixed_vars.max() >= self.n_vars:
raise ValueError("Fixed variables out of bounds")
self._fixed_vars = fixed_vars
def _coordinate_device(self) -> None:
if self._topo_gens is not None:
self._topo_gens = [
[gen.to(self.interv_scale.device) for gen in gens]
for gens in self._topo_gens
]
@internal
def configure_optimizers(self) -> torch.optim.Optimizer:
    # The optimizer class is looked up by name in torch.optim. Weight decay
    # is applied to the "decay" parameter group only; the "regular" group
    # is never decayed.
    optimizer_cls = getattr(torch.optim, self.opt)
    no_decay_group = {"params": self.regular_params(), "weight_decay": 0.0}
    decay_group = {
        "params": self.decay_params(),
        "weight_decay": self.weight_decay,
    }
    return optimizer_cls([no_decay_group, decay_group], lr=self.lr)
[docs]
def forward(
    self,
    x: torch.Tensor,
    r: torch.Tensor,
    s: torch.Tensor,
    l: torch.Tensor,
    l_: torch.Tensor | None = None,
    z: D.Normal | None = None,
    oidx: torch.LongTensor | None = None,
) -> tuple[D.Normal, D.Distribution]:
    r"""
    Forward pass of the model

    Parameters
    ----------
    x
        Sample data ([n_particles,] batch_size, n_vars)
    r
        Intervention regime ([n_particles,] batch_size, n_vars)
    s
        Covariate (batch_size, n_covariates)
    l
        Library size (batch_size, 1)
    l\_
        Counterfactual library size (batch_size, 1), used in place of ``l``
        for the output distribution when given
    z
        Latent variable (n_particles, batch_size, latent_dim), computed
        from the intervention regime when not given
    oidx
        Output variable index, restricts the pass to these variables

    Returns
    -------
    Latent variable (n_particles, batch_size, latent_dim)
    Data reconstruction distribution
    """
    n_vars = self.n_vars if oidx is None else oidx.numel()
    if z is None:
        z = self.latent(
            torch.zeros_like(r) if self.ablate_latent else r
        )  # (n_particles, bs, latent_dim)
    if self.prefit:
        # Prefit: parent inputs and latent sample are all zeros, so only
        # the covariates carry information into the structural equations
        ptr = x.new_zeros(
            (self.n_particles, n_vars, x.size(-2), self.scaffold.max_indegree)
        )
        z_samp = x.new_zeros((self.n_particles, x.size(-2), self.latent_dim))
    else:
        ptr = self.lik.tone(x, l)  # ([n_particles,] bs, n_vars)
        ptr = self.scaffold.mask_data(ptr, oidx=oidx)
        # Sample the latent while training, use its mean otherwise
        z_samp = z.rsample() if self.training else z.mean
    z_samp = z_samp.unsqueeze(1).expand(-1, n_vars, -1, -1)
    s = s.expand(self.n_particles, n_vars, -1, -1)
    cov = torch.cat([z_samp, s], dim=-1)
    # (n_particles, n_vars, bs, *)
    if oidx is None:
        func = self.func(ptr, cov)
    else:
        func = self.func(ptr, cov, slice(None), oidx)
    # Moving the last axis to the front lets it be unpacked into the two
    # structural-equation outputs
    mean, disp = func.permute(3, 0, 2, 1)  # (n_particles, bs, n_vars)
    # The design module, when present, supplies the scale and bias
    interv_scale = (
        self.interv_scale if self.design is None else self.design.scale
    ).unsqueeze(1)
    interv_bias = (
        self.interv_bias if self.design is None else self.design.bias
    ).unsqueeze(1)
    # (n_particles, 1, n_vars)
    if oidx is not None:
        interv_scale = interv_scale[..., oidx]
        interv_bias = interv_bias[..., oidx]
        r = r[..., oidx]
    r = torch.zeros_like(r) if self.ablate_interv else r
    # Where r == 0 the scale is exp(0) = 1 and the bias is 0, so variables
    # outside the intervention regime are left untouched
    interv_scale = (interv_scale * r).exp()
    interv_bias = interv_bias * r
    x_est = self.lik(
        mean * interv_scale + interv_bias,
        disp,
        l if l_ is None else l_,
        oidx=oidx,
    )
    return z, x_est
[docs]
def explain(
    self,
    x: torch.Tensor,  # Factual x
    r: torch.Tensor,  # Factual r
    s: torch.Tensor,  # Factual s
    l: torch.Tensor,  # Factual l
    x_: torch.Tensor,  # Counterfactual x
    r_: torch.Tensor,  # Counterfactual r
    s_: torch.Tensor,  # Counterfactual s
    l_: torch.Tensor,  # Counterfactual l
) -> tuple[torch.Tensor, ...]:
    r"""
    Explanation pass of the model

    Starting from the all-factual prediction, one component at a time is
    swapped for its counterfactual value while the rest stay factual. All
    outputs are passed through ``self.lik.tone`` with the counterfactual
    library size.

    Parameters
    ----------
    x
        Factual data ([n_particles,] batch_size, n_vars)
    r
        Factual intervention regime ([n_particles,] batch_size, n_vars)
    s
        Factual covariates (batch_size, n_covariates)
    l
        Factual library size (batch_size, 1)
    x\_
        Counterfactual data ([n_particles,] batch_size, n_vars)
    r\_
        Counterfactual intervention regime ([n_particles,] batch_size, n_vars)
    s\_
        Counterfactual covariates (batch_size, n_covariates)
    l\_
        Counterfactual library size (batch_size, 1)

    Returns
    -------
    Prediction with all factual components
    Prediction with only the counterfactual intervention scaling and bias
    Prediction with only the counterfactual covariates
    Prediction with only the counterfactual latent variable
    Prediction with the counterfactual value of each parent variable
    Prediction with all counterfactual components
    """
    # Latent means (no sampling) for the factual and counterfactual regimes
    z = self.latent(r).mean.unsqueeze(1).expand(-1, self.n_vars, -1, -1)
    z_ = self.latent(r_).mean.unsqueeze(1).expand(-1, self.n_vars, -1, -1)
    s = s.expand(self.n_particles, self.n_vars, -1, -1)
    s_ = s_.expand(self.n_particles, self.n_vars, -1, -1)
    cov = torch.cat([z, s], dim=-1)
    if x_.dim() == 3:  # (bs, n_vars, n_particles)
        x_ = x_.permute(2, 0, 1)  # (n_particles, bs, n_vars)
    x = self.lik.tone(x, l)
    # ``tot`` is the toned counterfactual input itself, returned unchanged
    x_ = tot = self.lik.tone(x_, l_)
    ptr = self.scaffold.mask_data(x)
    ptr_ = self.scaffold.mask_data(x_)
    interv_scale = (
        self.interv_scale if self.design is None else self.design.scale
    ).unsqueeze(1)
    interv_bias = (
        self.interv_bias if self.design is None else self.design.bias
    ).unsqueeze(1)
    scale = (interv_scale * r).exp()
    bias = interv_bias * r
    scale_ = (interv_scale * r_).exp()
    bias_ = interv_bias * r_
    # Baseline (all factual) and counterfactual intervention scale/bias
    mean, disp = self.func(ptr, cov).permute(3, 0, 2, 1)
    nil = self.lik.tone(self.lik(mean * scale + bias, disp, l_).mean, l_)
    ctrb_i = self.lik.tone(self.lik(mean * scale_ + bias_, disp, l_).mean, l_)
    # Counterfactual covariates only
    mean, disp = self.func(ptr, torch.cat([z, s_], dim=-1)).permute(3, 0, 2, 1)
    ctrb_s = self.lik.tone(self.lik(mean * scale + bias, disp, l_).mean, l_)
    # Counterfactual latent only
    mean, disp = self.func(ptr, torch.cat([z_, s], dim=-1)).permute(3, 0, 2, 1)
    ctrb_z = self.lik.tone(self.lik(mean * scale + bias, disp, l_).mean, l_)
    ctrb_ptr = []
    for i in range(ptr.size(-1)):
        ptr_use = ptr.clone()
        ptr_use[..., i] = ptr_[..., i]  # Plug in parents one by one
        mean, disp = self.func(ptr_use, cov).permute(3, 0, 2, 1)
        ctrb_ptr.append(
            self.lik.tone(self.lik(mean * scale + bias, disp, l_).mean, l_)
        )
    # One trailing slot per parent position
    ctrb_ptr = torch.stack(ctrb_ptr, dim=-1)
    return nil, ctrb_i, ctrb_s, ctrb_z, ctrb_ptr, tot
[docs]
def cascade(
    self,
    x: torch.Tensor,
    r: torch.Tensor,
    s: torch.Tensor,
    l: torch.Tensor,
    l_: torch.Tensor | None = None,
    z: D.Normal | None = None,
) -> tuple[D.Normal, D.Distribution]:
    r"""
    Cascade pass of the model

    Variables are updated one topological generation at a time: each
    generation is predicted from the current data and its predicted mean is
    written back before the next generation is processed. A final full
    forward pass produces the returned distribution.

    Parameters
    ----------
    x
        Sample data ([n_particles,] batch_size, n_vars)
    r
        Intervention regime ([n_particles,] batch_size, n_vars)
    s
        Covariate (batch_size, n_covariates)
    l
        Library size (batch_size, 1)
    l\_
        Counterfactual library size (batch_size, 1), only used by the final
        pass
    z
        Latent variable (n_particles, batch_size, latent_dim)

    Returns
    -------
    Latent variable (n_particles, batch_size, latent_dim)
    Data reconstruction distribution
    """
    if x.dim() == 2:
        x = x.unsqueeze(0).expand(self.n_particles, -1, -1)
    # remap[v] is the position of variable v within the current ``oidx``
    remap = torch.empty(self.n_vars, dtype=torch.long, device=x.device)
    empty_gen = torch.empty(0, dtype=torch.long, device=x.device)
    if self.ablate_graph:
        # Skip the generation-wise propagation entirely
        return self(x, r, s, l, l_=l_, z=z)  # (n_particles, bs, *)
    # Iterate generation-wise across particles; particles with fewer
    # generations are padded with an empty generation
    for gens in zip_longest(*self.topo_gens, fillvalue=empty_gen):
        # Union of the variables in this generation over all particles
        oidx = torch.cat(gens).unique()
        remap[oidx] = torch.arange(oidx.numel(), device=x.device)
        z, x_est = self(x, r, s, l, z=z, oidx=oidx)  # (n_particles, bs, *)
        x_mean = x_est.mean
        x = x.clone()  # Expanded dim becomes independent
        for i, gen in enumerate(gens):
            # Each particle only takes over its own generation's variables
            x[i, :, gen] = x_mean[i, :, remap[gen]]
    return self(x, r, s, l, l_=l_, z=z)  # (n_particles, bs, *)
[docs]
def compute_lik(self, batch: Iterable[torch.Tensor]) -> tuple[torch.Tensor, ...]:
    r"""
    Compute likelihood terms from a minibatch

    Parameters
    ----------
    batch
        Minibatch of data; 5 tensors before the counterfactual stage, 10
        tensors (factual followed by counterfactual) from it onwards

    Returns
    -------
    Negative log-likelihood
    Negative log-prior
    Latent KL divergence
    """
    fit_stage = self.fit_stage
    if fit_stage < FitStage.ctfact:
        x, r, s, l, w = batch
    else:
        x, r, s, l, w, x_, r_, s_, l_, _ = batch
    if fit_stage == FitStage.design:
        # The intervention regime is sampled from the design module and the
        # "likelihood" is the weighted design loss against the target
        r = self.design.rsample(r_.size(0))
        _, x_est = self.cascade(x, r, s_, l, l_=l_)
        x_est = self.lik.tone(x_est.mean, l_)  # (n_particles, bs, n_vars)
        x_tgt = self.lik.tone(x_, l_)  # (bs, n_vars)
        mse = self.design.loss(x_est, x_tgt)  # (n_particles, bs)
        nll = (mse * w).mean(dim=-1)  # (n_particles,)
        nlp = nll.new_zeros(())
        kl = nll.new_zeros(())
    else:
        z, x_est = self(x, r, s, l)  # (n_particles, bs, *)
        log_lik = x_est.log_prob(x)
        log_prior = self.lik.log_prior(x_est)
        # Average over variables, weight per sample, average over the batch
        nll = (log_lik.mean(dim=-1) * w).mean(dim=-1).neg()  # (n_particles,)
        nlp = (log_prior.mean(dim=-1) * w).mean(dim=-1).neg()  # (n_particles,)
        if fit_stage == FitStage.ctfact:
            # Equal mix of the self-reconstruction and counterfactual terms
            _, x_ctfact = self.cascade(x, r_, s_, l, l_=l_)
            log_lik_ctfact = x_ctfact.log_prob(x_)
            log_prior_ctfact = self.lik.log_prior(x_ctfact)
            nll_ctfact = (log_lik_ctfact.mean(dim=-1) * w).mean(dim=-1).neg()
            nlp_ctfact = (log_prior_ctfact.mean(dim=-1) * w).mean(dim=-1).neg()
            nll = 0.5 * nll + 0.5 * nll_ctfact
            nlp = 0.5 * nlp + 0.5 * nlp_ctfact
        kl = (
            D.kl_divergence(z, self.latent.prior()).mean()
            if self.latent_dim
            else nll.new_zeros(())
        )  # scalar: the mean is taken over all elements
    return nll, nlp, kl
[docs]
def compute_prior(self) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute the prior energy terms

    The first call after the cache is cleared returns the live tensors and
    stores detached copies; later calls return the detached copies.

    Returns
    -------
    Sparse prior energy
    Acyclic prior energy
    """
    cached = self.cache.get("compute_prior")
    if cached is not None:
        return cached
    energies = (
        self.sparse.energy(self.scaffold),
        self.acyc.energy(self.scaffold),
    )
    self.cache["compute_prior"] = tuple(enrg.detach() for enrg in energies)
    return energies
[docs]
def compute_kernel(self) -> torch.Tensor:
    r"""
    Compute the SVGD kernel

    The kernel is evaluated between the scaffold probabilities and a
    detached copy of them. The first call after the cache is cleared
    returns the live tensor and stores a detached copy; later calls return
    the detached copy.

    Returns
    -------
    SVGD kernel
    """
    if (cached := self.cache.get("compute_kernel")) is not None:
        return cached
    live = self.kernel(self.scaffold.prob, self.scaffold.prob.detach())
    self.cache["compute_kernel"] = live.detach()
    return live
[docs]
def training_step(
    self, batch: Iterable[torch.Tensor], batch_idx: int
) -> None:
    r"""
    Training step for a minibatch

    Optimization is manual: gradients are produced by :meth:`backward` and
    the optimizer is stepped here once every ``accumulate_grad_batches``
    batches. Nothing is returned.
    """
    nll, nlp, kl = self.compute_lik(batch)
    sparse_enrg, acyc_enrg = self.compute_prior()
    kernel = self.compute_kernel()
    lik_enrg = nll + nlp + self.beta * kl
    if not self.scaffold.frozen and not self.prefit:
        # Track running averages of each energy term's gradient w.r.t. the
        # scaffold logits. While the likelihood average is still all zeros
        # the decay is 0, so the average starts at the raw gradient.
        EXP_AVG = (
            0.0
            if torch.allclose(self.lik_grad_avg, torch.as_tensor(0.0))
            else self.EXP_AVG
        )
        self.lik_grad_avg = (
            EXP_AVG * self.lik_grad_avg
            + (1 - EXP_AVG)
            * torch.autograd.grad(
                lik_enrg.sum(), self.scaffold.logit, retain_graph=True
            )[0]
        )
        self.sparse_grad_avg = (
            EXP_AVG * self.sparse_grad_avg
            + (1 - EXP_AVG)
            * torch.autograd.grad(
                sparse_enrg.sum(), self.scaffold.logit, retain_graph=True
            )[0]
        )
        self.acyc_grad_avg = (
            EXP_AVG * self.acyc_grad_avg
            + (1 - EXP_AVG)
            * torch.autograd.grad(
                acyc_enrg.sum(), self.scaffold.logit, retain_graph=True
            )[0]
        )
        self.log_dict(
            {
                "grad/lik_norm": self.lik_grad_avg.norm(dim=1).mean(),
                "grad/sparse_norm": self.sparse_grad_avg.norm(dim=1).mean(),
                "grad/acyc_norm": self.acyc_grad_avg.norm(dim=1).mean(),
                "grad/kernel_norm": self.kernel_grad_avg.norm(dim=1).mean(),
            },
            sync_dist=True,
        )
    prior_enrg = self.lam * sparse_enrg + self.alpha * acyc_enrg  # (n_particles,)
    post_enrg = lik_enrg + prior_enrg
    self.log_dict(
        {
            "train/nll": nll.mean(),
            "train/nlp": nlp.mean(),
            "train/kl": kl.mean(),
            "train/lik_enrg": lik_enrg.mean(),
            "train/sparse_enrg": sparse_enrg.mean(),
            "train/acyc_enrg": acyc_enrg.mean(),
            "train/prior_enrg": prior_enrg.mean(),
            "train/post_enrg": post_enrg.mean(),
            "train/kernel": kernel.mean(),
        },
        sync_dist=True,
    )  # NOTE: every logged value is a mean over the SVGD particles
    self.backward(post_enrg, kernel)
    if (batch_idx + 1) % self.accumulate_grad_batches == 0:
        opt = self.optimizers()
        if self.accumulate_grad_batches > 1:
            # Gradients were summed over the accumulated batches; turn them
            # into an average. ``seen`` guards against dividing a parameter
            # twice if it appears in more than one group.
            seen = set()
            for group in opt.param_groups:
                for p in group["params"]:
                    if p.grad is None or id(p) in seen:
                        continue
                    p.grad.div_(self.accumulate_grad_batches)
                    seen.add(id(p))
        opt.step()
        opt.zero_grad()
[docs]
def backward(self, enrg: torch.Tensor, kernel: torch.Tensor) -> None:
    r"""
    Implementation of the main SVGD logic

    All parameters receive the plain gradient of the summed energy. Unless
    the scaffold is frozen, the gradient of the scaffold logits is then
    replaced by the kernel-mixed energy gradient plus ``gamma`` times the
    kernel gradient, divided by the number of particles.

    .. caution::
        This implementation is only valid for symmetric and
        translation-invariant kernels.

    Parameters
    ----------
    enrg
        Posterior energy per particle (n_particles,)
    kernel
        SVGD kernel
    """
    # NOTE(review): ``backup=True`` presumably stashes gradients accumulated
    # from earlier batches so they can be restored by accumulate_grad()
    # below — confirm against Scaffold
    self.scaffold.zero_grad(backup=True)
    enrg_sum = enrg.sum()
    enrg_sum.backward(retain_graph=True)
    if self.scaffold.frozen:
        return
    # Take the raw energy gradient of the logits and clear the slot
    enrg_grad, self.scaffold.logit.grad = self.scaffold.logit.grad, None
    kernel_grad = torch.autograd.grad(
        kernel.sum(), self.scaffold.logit, allow_unused=True  # Allow KroneckerDelta
    )[0]
    if kernel_grad is None:
        kernel_grad = torch.zeros_like(self.scaffold.logit)
    # Same first-update rule as in training_step(): no decay while the
    # running average is still all zeros
    EXP_AVG = (
        0.0
        if torch.allclose(self.kernel_grad_avg, torch.as_tensor(0.0))
        else self.EXP_AVG
    )
    self.kernel_grad_avg = (
        EXP_AVG * self.kernel_grad_avg + (1 - EXP_AVG) * kernel_grad
    )
    logit_grad = (
        kernel.detach().matmul(enrg_grad) + self.gamma * kernel_grad
    ) / self.n_particles
    self.scaffold.zero_grad(backup=False)
    # Push the SVGD gradient through the logits into the scaffold parameters
    self.scaffold.logit.backward(gradient=logit_grad, retain_graph=True)
    self.scaffold.accumulate_grad()
[docs]
def validation_step(
    self, batch: Iterable[torch.Tensor], batch_idx: int
) -> None:
    r"""
    Validation step for a minibatch

    Computes the same energy terms as the training step and logs their
    means over the SVGD particles, together with the current ``lam``,
    ``alpha`` and ``gamma`` weights. Nothing is returned.

    Parameters
    ----------
    batch
        Minibatch of data
    batch_idx
        Index of the minibatch (unused)
    """
    nll, nlp, kl = self.compute_lik(batch)
    sparse_enrg, acyc_enrg = self.compute_prior()
    kernel = self.compute_kernel()
    lik_enrg = nll + nlp + self.beta * kl
    prior_enrg = self.lam * sparse_enrg + self.alpha * acyc_enrg  # (n_particles,)
    post_enrg = lik_enrg + prior_enrg
    # NOTE: every energy value is logged as a mean over the SVGD particles
    metrics = {
        "val/nll": nll.mean(),
        "val/nlp": nlp.mean(),
        "val/kl": kl.mean(),
        "val/lik_enrg": lik_enrg.mean(),
        "val/sparse_enrg": sparse_enrg.mean(),
        "val/acyc_enrg": acyc_enrg.mean(),
        "val/prior_enrg": prior_enrg.mean(),
        "val/post_enrg": post_enrg.mean(),
        "val/kernel": kernel.mean(),
        "hparam/lam": self.lam,
        "hparam/alpha": self.alpha,
        "hparam/gamma": self.gamma,
    }
    self.log_dict(metrics, sync_dist=True)
[docs]
def predict_step(
    self, batch: Iterable[torch.Tensor], batch_idx: int
) -> tuple[torch.Tensor, ...]:
    r"""
    Prediction step for a minibatch

    The content of the returned tuple depends on ``predict_mode``:

    - ``recon``: latent mean, latent std, reconstruction mean,
      reconstruction std, dispersion
    - ``jac``: a single tensor of gradients of every output variable's
      predicted mean w.r.t. the input data
    - ``explain``: the six outputs of :meth:`explain`
    - ``dsgnerr``: a single tensor with the design loss
    - ``ctmean`` / ``ctsamp``: as ``recon`` but from the cascade pass, with
      the third element being the mean or a random sample, respectively
    """
    predict_mode = self.predict_mode
    if predict_mode == PredictMode.recon:
        x, r, s, l, _ = batch
        z, x_est = self(x, r, s, l)
        return (
            z.mean.cpu(),
            z.variance.sqrt().cpu(),
            x_est.mean.cpu(),
            x_est.variance.sqrt().cpu(),
            self.lik.get_disp(x_est).cpu(),
        )
    if predict_mode == PredictMode.jac:
        x, r, s, l, _ = batch
        x = x.expand(self.n_particles, -1, -1)
        # Gradient tracking is switched back on explicitly for this pass
        with torch.enable_grad():
            x = x.requires_grad_()
            _, x_est = self(x, r, s, l)
            x_mean = x_est.mean  # (n_particles, bs, n_vars)
        # One one-hot "grad_outputs" slice per output variable
        grad_outputs = torch.eye(
            x_mean.size(-1), dtype=x_mean.dtype, device=x_mean.device
        )  # (n_vars, n_vars)
        grad_outputs = grad_outputs.view(
            grad_outputs.size(0), 1, 1, grad_outputs.size(1)
        ).expand(-1, x_mean.size(0), x_mean.size(1), -1)
        # (n_vars, n_particles, bs, n_vars)
        return (
            torch.stack(
                [
                    torch.autograd.grad(
                        x_mean,
                        x,
                        grad_outputs=g,  # (n_particles, bs, n_vars)
                        retain_graph=True,
                    )[0].cpu()
                    for g in grad_outputs
                ],
                dim=-2,
            ),
        )
        # This is slower than the is_grads_batched approach but
        # much more memory-efficient. The latter easily runs OOM.
    if predict_mode == PredictMode.explain:
        x, r, s, l, _, x_, r_, s_, l_, _ = batch
        nil, ctrb_i, ctrb_s, ctrb_z, ctrb_ptr, tot = self.explain(
            x, r, s, l, x_, r_, s_, l_
        )
        return (
            nil.cpu(),
            ctrb_i.cpu(),
            ctrb_s.cpu(),
            ctrb_z.cpu(),
            ctrb_ptr.cpu(),
            tot.cpu(),
        )
    if predict_mode == PredictMode.dsgnerr:
        x, r, s, l, _, x_, _, s_, l_, _ = batch
        _, x_est = self.cascade(x, r, s_, l, l_=l_)
        x_est = self.lik.tone(x_est.mean, l_)  # (n_particles, bs, n_vars)
        x_tgt = self.lik.tone(x_, l_)  # (bs, n_vars)
        # NOTE(review): unlike the other modes this result is not moved to
        # the CPU — confirm that is intended
        return (self.design.loss(x_est, x_tgt),)  # (n_particles, bs)
    # predict_mode >= PredictMode.ctmean:
    x, r, s, l, _ = batch
    z, x_est = self.cascade(x, r, s, l)
    return (
        z.mean.cpu(),
        z.variance.sqrt().cpu(),
        (
            x_est.mean.cpu()
            if predict_mode == PredictMode.ctmean
            else x_est.sample().cpu()  # PredictMode.ctsamp
        ),
        x_est.variance.sqrt().cpu(),
        self.lik.get_disp(x_est).cpu(),
    )
@internal
def on_train_batch_start(
    self, batch: Iterable[torch.Tensor], batch_idx: int
) -> None:
    # Discover stage: drop cached scaffold properties and memoized
    # priors/kernel before each batch.
    # Design stage: keep the model in eval mode and log the currently
    # top-ranked intervention combinations as text.
    stage = self.fit_stage
    if stage == FitStage.discover:
        self.scaffold.clear_cached_properties()
        self.cache.clear()
        return
    if stage != FitStage.design:
        return
    self.eval()
    combos = self.design.comb_lists
    logits = self.design.logits
    # At most the ten highest-scoring combinations
    best = logits.data.topk(min(10, logits.size(0)))
    entries = []
    for idx, val in zip(best.indices.tolist(), best.values.tolist()):
        names = ",".join(sorted(self.vars[combos[idx]]))
        entries.append(f"{names} ({val:.2f})")
    self.logger.experiment.add_text(
        "design/top", " | ".join(entries), global_step=self.global_step
    )
@internal
def on_validation_start(self) -> None:
    # Drop cached scaffold properties and memoized priors/kernel while the
    # graph is still being learned, then log adjacency matrices as images
    # according to ``log_adj``.
    if self.fit_stage == FitStage.discover:
        self.scaffold.clear_cached_properties()
        self.cache.clear()
    mode = self.log_adj
    want_mean = mode in (LogAdj.mean, LogAdj.both)
    want_particles = mode in (LogAdj.particles, LogAdj.both)
    if want_mean:
        mean_adj = self.scaffold.mean_adj.detach().cpu().float().to_dense()
        self.logger.experiment.add_image(
            "adj/mean",
            mean_adj,
            global_step=self.global_step,
            dataformats="HW",
        )
    if want_particles:
        dense = self.scaffold.adj.detach().cpu().float().to_dense()
        # Bring the last axis to the front so iterating yields one matrix
        # per particle
        for i, particle in enumerate(dense.permute(2, 0, 1)):
            self.logger.experiment.add_image(
                f"adj/particle_{i}",
                particle,
                global_step=self.global_step,
                dataformats="HW",
            )
@internal
def on_fit_start(self) -> None:
    # Before fitting: move the stored topological generations to the
    # parameter device and start from empty caches
    self._coordinate_device()
    self.scaffold.clear_cached_properties()
    self.cache.clear()
@internal
def on_fit_end(self) -> None:
    # After fitting: drop caches, reset all transient run-time properties
    # (optimizer settings, stage flags, logging options) and release cached
    # CUDA memory
    self.scaffold.clear_cached_properties()
    self.cache.clear()
    self.reset_properties()
    torch.cuda.empty_cache()
@internal
def on_predict_start(self) -> None:
    # Before prediction: move the stored topological generations to the
    # parameter device and start from empty caches
    self._coordinate_device()
    self.scaffold.clear_cached_properties()
    self.cache.clear()
@internal
def on_predict_end(self) -> None:
    # After prediction: drop caches, reset all transient run-time
    # properties and release cached CUDA memory
    self.scaffold.clear_cached_properties()
    self.cache.clear()
    self.reset_properties()
    torch.cuda.empty_cache()
[docs]
def prune(self) -> None:
    r"""
    Prune the scaffold and structural equations accordingly

    After the scaffold is pruned, the per-edge gradient averages are
    subset to the surviving edges, the saved ``scaffold_kws`` hyperparameter
    is updated with the new edge index, and the first layer of the
    structural equations is rebuilt for the new maximal in-degree with the
    weights of surviving parents copied over.
    """
    # (source, target) -> parent slot, before and after pruning
    old_map = {(i, j): k for i, j, k in self.scaffold.idx.t().tolist()}
    mask = self.scaffold.prune()
    self.lik_grad_avg = self.lik_grad_avg[:, mask]
    self.sparse_grad_avg = self.sparse_grad_avg[:, mask]
    self.acyc_grad_avg = self.acyc_grad_avg[:, mask]
    self.kernel_grad_avg = self.kernel_grad_avg[:, mask]
    new_map = {(i, j): k for i, j, k in self.scaffold.idx.t().tolist()}
    # ``scaffold_kws`` defaults to None in the constructor; create the
    # mapping on demand so recording the edge index cannot fail half-way
    # through pruning
    scaffold_kws = self.hparams.get("scaffold_kws")
    if scaffold_kws is None:
        scaffold_kws = self.hparams["scaffold_kws"] = {}
    scaffold_kws["eidx"] = self.scaffold.idx[:2].cpu()
    # Number of trailing (latent + covariate) input features of layer 0
    trailing = self.latent_dim + self.n_covariates
    old = self.func.layers[0]
    new = MultiLinear(
        in_features=self.scaffold.max_indegree + trailing,
        out_features=old.out_features,
        multi_dims=(self.n_particles, self.n_vars),
    )
    copy_like(old.bias, new.bias)
    init.zeros_(new.weight)
    new.weight.data = new.weight.data.to(
        device=old.weight.device, dtype=old.weight.dtype
    )
    # Carry over the weights of every surviving parent to its new slot
    for (i, j), k in new_map.items():
        new.weight.data[:, j, :, k] = old.weight.data[:, j, :, old_map[(i, j)]]
    # Guarded because ``-0:`` would select the whole axis
    if trailing:
        new.weight.data[:, :, :, -trailing:] = old.weight.data[:, :, :, -trailing:]
    self.func.layers[0] = new
@internal
def get_extra_state(self) -> dict[str, Any]:
    # Collect the non-tensor run-time state that should travel with the
    # state dict; enum members are stored as plain integers. Entries coming
    # from the parent class take precedence on key clashes.
    def as_int(member):
        return None if member is None else int(member)

    plain = (
        "lam",
        "alpha",
        "gamma",
        "opt",
        "lr",
        "weight_decay",
        "accumulate_grad_batches",
    )
    state = {name: getattr(self, name) for name in plain}
    state["_fit_stage"] = as_int(self._fit_stage)
    state["_predict_mode"] = as_int(self._predict_mode)
    state["_prefit"] = self._prefit
    state["_topo_gens"] = self._topo_gens
    state["_fixed_vars"] = self._fixed_vars
    state["log_adj"] = as_int(self.log_adj)
    state.update(super().get_extra_state())
    return state
@internal
def set_extra_state(self, state: dict[str, Any]) -> None:
    # Restore the run-time state written by get_extra_state(). Every entry
    # is popped from ``state``; what remains is handed to the parent class.
    # The private ``_``-prefixed attributes are written directly, bypassing
    # the property setters.
    for name in (
        "lam",
        "alpha",
        "gamma",
        "opt",
        "lr",
        "weight_decay",
        "accumulate_grad_batches",
    ):
        setattr(self, name, state.pop(name))
    stage = state.pop("_fit_stage")
    self._fit_stage = FitStage(stage) if stage is not None else None
    mode = state.pop("_predict_mode")
    self._predict_mode = PredictMode(mode) if mode is not None else None
    for name in ("_prefit", "_topo_gens", "_fixed_vars"):
        setattr(self, name, state.pop(name))
    adj_mode = state.pop("log_adj")
    self.log_adj = LogAdj(adj_mode) if adj_mode is not None else None
    super().set_extra_state(state)
[docs]
class DiscoverScheduler(callbacks.EarlyStopping):
    r"""
    Hyperparameter scheduler for causal discovery

    Parameters
    ----------
    monitor
        Loss to be monitored
    constraint
        Loss that specifies the constraint
    patience
        Number of checks with no improvement after which training will be stopped
    tolerance
        Maximal tolerance of constraint violation to end the scheduler
    lam
        Sparse penalty rate (:math:`\eta_\lambda` in paper)
    alpha
        Acyclic penalty rate (:math:`\eta_\alpha` in paper)
    gamma
        Kernel gradient rate (:math:`\eta_\gamma`)
    **kwargs
        Additional keyword arguments are passed to
        :class:`~pytorch_lightning.callbacks.EarlyStopping`
    """

    # Worst possible score per monitoring mode, used to reset the best
    # score of this callback and of the checkpoint callback on a trigger
    inf = {"min": torch.tensor(torch.inf), "max": torch.tensor(-torch.inf)}

    def __init__(
        self,
        monitor: str,
        constraint: str,
        patience: int,
        tolerance: float | None = None,
        lam: float | None = None,
        alpha: float | None = None,
        gamma: float | None = None,
        **kwargs,
    ) -> None:
        # NOTE(review): ``tolerance``, ``lam``, ``alpha`` and ``gamma``
        # default to None but are used in arithmetic/comparisons once the
        # scheduler fires, so callers presumably always supply them — confirm
        if kwargs.get("check_on_train_epoch_end", False):
            raise ValueError(
                "Only supports checking on validation epoch end"
            )  # pragma: no cover
        kwargs["check_on_train_epoch_end"] = False
        super().__init__(monitor, patience=patience, **kwargs)
        self.constraint = constraint
        self.tolerance = tolerance
        self.lam = lam
        self.alpha = alpha
        self.gamma = gamma
        # The same patience bounds the number of consecutive stalls of the
        # constraint violation
        self.stall_patience = patience
        self.stall_count = 0
        self.prefit = None
        self.violation = None
        self.trigger_flag = None
        self.min_violation = float("inf")

    @property
    def state_key(self) -> str:
        return self._generate_state_key(
            monitor=self.monitor, constraint=self.constraint, mode=self.mode
        )

    @internal
    def state_dict(self) -> dict[str, Any]:
        # Scheduler-specific state first; entries from the parent class
        # take precedence on key clashes
        return {
            "tolerance": self.tolerance,
            "lam": self.lam,
            "alpha": self.alpha,
            "gamma": self.gamma,
            "stall_patience": self.stall_patience,
            "stall_count": self.stall_count,
            "prefit": self.prefit,
            "violation": self.violation,
            "trigger_flag": self.trigger_flag,
            "min_violation": self.min_violation,
            **super().state_dict(),
        }

    @internal
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:  # pragma: no cover
        # Scheduler-specific entries are popped (the passed dict is
        # modified); the remainder is handed to the parent class
        self.tolerance = state_dict.pop("tolerance")
        self.lam = state_dict.pop("lam")
        self.alpha = state_dict.pop("alpha")
        self.gamma = state_dict.pop("gamma")
        self.stall_patience = state_dict.pop("stall_patience")
        self.stall_count = state_dict.pop("stall_count")
        self.prefit = state_dict.pop("prefit")
        self.violation = state_dict.pop("violation")
        self.trigger_flag = state_dict.pop("trigger_flag")
        self.min_violation = state_dict.pop("min_violation")
        super().load_state_dict(state_dict)

    def on_validation_end(self, trainer: Trainer, pl_module: CausalNetwork) -> None:
        r"""
        Main logic of the scheduler

        .. note::
            The scheduler will adjust the hyperparameters of the model according
            to the gradient norms of the likelihood, sparse, acyclic, and kernel
            gradients, each time the early stopping criteria is met. The
            adjustment is based on the ratio of the likelihood gradient norm to
            the other gradient norms. The scheduler continues until the
            constraint is satisfied or constraint violation stops improving for
            a consecutive ``patience`` times.
        """
        # Snapshot what _evaluate_stopping_criteria() needs; that method is
        # overridden below and may switch ``self.prefit`` off
        self.prefit = pl_module.prefit
        self.violation = trainer.callback_metrics[self.constraint]
        super().on_validation_end(trainer, pl_module)
        pl_module.prefit = self.prefit
        if self.trigger_flag:
            lik_grad_norm = pl_module.lik_grad_avg.norm(dim=1).mean().item()
            sparse_grad_norm = pl_module.sparse_grad_avg.norm(dim=1).mean().item()
            acyc_grad_norm = pl_module.acyc_grad_avg.norm(dim=1).mean().item()
            kernel_grad_norm = pl_module.kernel_grad_avg.norm(dim=1).mean().item()
            # ``lam`` and ``gamma`` are re-derived on every trigger, whereas
            # ``alpha`` accumulates
            pl_module.lam = self.lam * lik_grad_norm / (sparse_grad_norm + EPS)
            pl_module.alpha += self.alpha * lik_grad_norm / (acyc_grad_norm + EPS)
            if pl_module.n_particles > 1:
                pl_module.gamma = self.gamma * lik_grad_norm / (kernel_grad_norm + EPS)
            # The objective changed, so earlier best scores are void
            self.best_score = self.inf[self.mode]
            if trainer.checkpoint_callback is not None:
                inf = self.inf[trainer.checkpoint_callback.mode]
                for m in trainer.checkpoint_callback.best_k_models:
                    trainer.checkpoint_callback.best_k_models[m] = inf
                trainer.checkpoint_callback.kth_value = inf
                trainer.checkpoint_callback.best_model_score = inf
                trainer.checkpoint_callback.skip_once = True
            if self.verbose:
                console.print(
                    f"Discover scheduler triggered: "
                    f"lam = [hl]{pl_module.lam:.2e}[/hl], "
                    f"alpha = [hl]{pl_module.alpha:.2e}[/hl], "
                    f"gamma = [hl]{pl_module.gamma:.2e}[/hl]"
                )
            self.trigger_flag = False

    def _evaluate_stopping_criteria(
        self, current: torch.Tensor
    ) -> tuple[bool, str | None]:
        # Wrap the parent's decision: when it would stop, either end the
        # prefit phase, ask for a hyperparameter update (trigger), or
        # really stop (constraint satisfied / violation stalled)
        should_stop, reason = super()._evaluate_stopping_criteria(current)
        self.trigger_flag = False
        if not should_stop:
            return should_stop, reason
        if self.prefit:
            self.prefit, should_stop = False, False
            reason = "Prefit concluded"
            return should_stop, reason
        # Relative improvement of the violation over the best seen so far;
        # with the initial infinite ``min_violation`` this is NaN, which
        # fails the ``<`` test below and therefore counts as "not stalled"
        improve_ratio = (self.min_violation - self.violation) / self.min_violation
        self.min_violation = min(self.violation, self.min_violation)
        if improve_ratio < 0.01:
            self.stall_count += 1
            console.print(f"Discover scheduler stall #{self.stall_count}")
            if self.stall_count >= self.stall_patience:
                reason = "Violation stalled"
                return should_stop, reason
        else:
            self.stall_count = 0
        if self.min_violation > self.tolerance:
            # Keep training with updated hyperparameters
            self.trigger_flag, should_stop = True, False
            reason = "Constraint unsatisfied"
        else:
            reason = "Constraint satisfied"
        return should_stop, reason
[docs]
class ModelCheckpoint(callbacks.ModelCheckpoint):
    r"""
    Model checkpoint callback with a one-shot ``skip_once`` switch: when the
    switch is set, the next validation end is skipped and the switch is
    cleared again
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_once = False

    @internal
    def on_validation_end(self, trainer: Trainer, pl_module: CausalNetwork) -> None:
        if not self.skip_once:
            super().on_validation_end(trainer, pl_module)
            return
        logger.debug("Skipping model checkpoint.")
        self.skip_once = False

    @internal
    def state_dict(self) -> dict[str, Any]:
        # Parent entries take precedence on key clashes
        state = {"skip_once": self.skip_once}
        state.update(super().state_dict())
        return state

    @internal
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # The switch is popped; the remainder goes to the parent class
        skip_once = state_dict.pop("skip_once")
        self.skip_once = skip_once
        super().load_state_dict(state_dict)
[docs]
class PredictionWriter(callbacks.BasePredictionWriter):
    r"""
    Prediction writer that saves, at the end of the prediction epoch, the
    predictions and batch indices of each rank to rank-specific files
    (``pred{rank}.pt`` and ``ind{rank}.pt``) in the output directory
    """

    def __init__(self, output_dir: os.PathLike) -> None:
        super().__init__(write_interval="epoch")
        target = Path(output_dir)
        self.output_dir = target
        target.mkdir(parents=True, exist_ok=True)

    @internal
    def write_on_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        predictions: Any,
        batch_indices: Any,
    ) -> None:
        rank = trainer.global_rank
        # Predictions first, then the matching batch indices
        for stem, payload in (("pred", predictions), ("ind", batch_indices)):
            torch.save(payload, self.output_dir / f"{stem}{rank}.pt")