cascade.core.CausalNetwork

class cascade.core.CausalNetwork(n_vars, n_particles, n_covariates, n_layers, hidden_dim, latent_dim, dropout, beta, scaffold_mod, sparse_mod, acyc_mod, latent_mod, lik_mod, kernel_mod, scaffold_kws=None, sparse_kws=None, acyc_kws=None, latent_kws=None, lik_kws=None, kernel_kws=None, design=None)

Bases: LightningModule, Module

Causal discovery neural network

Parameters:
  • n_vars (int) – Number of variables to model

  • n_particles (int) – Number of SVGD particles

  • n_covariates (int) – Dimension of covariates

  • n_layers (int) – Number of MLP layers in the structural equations

  • hidden_dim (int) – MLP hidden layer dimension in the structural equations

  • latent_dim (int) – Dimension of the latent variable

  • dropout (float) – Dropout rate

  • beta (float) – KL weight of the latent variable

  • scaffold_mod (str) – Scaffold graph module, must be one of {“Edgewise”, “Bilinear”}

  • sparse_mod (str) – Sparse prior module, must be one of {“L1”, “ScaleFree”}

  • acyc_mod (str) – Acyclic prior module, must be one of {“TrExp”, “SpecNorm”, “LogDet”}

  • latent_mod (str) – Latent module, must be one of {“NilLatent”, “EmbLatent”, “GCNLatent”}

  • lik_mod (str) – Causal likelihood module, must be one of {“Normal”, “NegBin”}

  • kernel_mod (str) – SVGD kernel module, must be one of {“KroneckerDelta”, “RBF”}

  • scaffold_kws (Mapping[str, Any] | None) – Keyword arguments to the scaffold graph module, see Edgewise or Bilinear for details

  • sparse_kws (Mapping[str, Any] | None) – Keyword arguments to the sparse prior module, see L1 or ScaleFree for details

  • acyc_kws (Mapping[str, Any] | None) – Keyword arguments to the acyclic prior module, see TrExp, SpecNorm, or LogDet for details

  • latent_kws (Mapping[str, Any] | None) – Keyword arguments to the latent module, see NilLatent, EmbLatent, or GCNLatent for details

  • lik_kws (Mapping[str, Any] | None) – Keyword arguments to the causal likelihood module, see Normal or NegBin for details

  • kernel_kws (Mapping[str, Any] | None) – Keyword arguments to the SVGD kernel module, see KroneckerDelta or RBF for details

  • design (IntervDesign | None) – Optional intervention design module, see IntervDesign for details
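The constructor takes many required arguments, six of which are restricted to a fixed set of strings. The sketch below collects the documented choices into a lookup table and validates a configuration mapping before it is passed on. The helper `check_module_choices`, the table `ALLOWED_MODULES`, and all numeric values are hypothetical and chosen for illustration only; the import path in the trailing comment is inferred from the qualified class name, and whether each module works with empty keyword arguments is not stated on this page.

```python
from typing import Any, Mapping

# Allowed values, copied from the parameter list above.
ALLOWED_MODULES = {
    "scaffold_mod": {"Edgewise", "Bilinear"},
    "sparse_mod": {"L1", "ScaleFree"},
    "acyc_mod": {"TrExp", "SpecNorm", "LogDet"},
    "latent_mod": {"NilLatent", "EmbLatent", "GCNLatent"},
    "lik_mod": {"Normal", "NegBin"},
    "kernel_mod": {"KroneckerDelta", "RBF"},
}


def check_module_choices(config: Mapping[str, Any]) -> None:
    """Hypothetical helper: reject module names that are not documented."""
    for key, allowed in ALLOWED_MODULES.items():
        if key not in config:
            raise ValueError(f"missing required argument: {key}")
        if config[key] not in allowed:
            raise ValueError(
                f"{key}={config[key]!r} is not one of {sorted(allowed)}"
            )


# Illustrative values only; none of these numbers come from the documentation.
config = dict(
    n_vars=100,
    n_particles=4,
    n_covariates=0,
    n_layers=1,
    hidden_dim=16,
    latent_dim=16,
    dropout=0.2,
    beta=0.1,
    scaffold_mod="Edgewise",
    sparse_mod="L1",
    acyc_mod="SpecNorm",
    latent_mod="EmbLatent",
    lik_mod="NegBin",
    kernel_mod="RBF",
)
check_module_choices(config)

# With the package installed, the validated mapping could be unpacked directly:
# from cascade.core import CausalNetwork
# net = CausalNetwork(**config)
```

Validating the strings up front gives a readable error message before any model construction starts.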

Methods

backward

Implementation of the main SVGD logic

cascade

Cascade pass of the model

compute_kernel

Compute the SVGD kernel

compute_lik

Compute likelihood terms from a minibatch

compute_prior

Compute the prior energy terms

explain

Explanation pass of the model

forward

Forward pass of the model

predict_step

Prediction step for a minibatch

prune

Prune the scaffold and structural equations accordingly

reset_parameters

reset_properties

set_design

Set the design module

training_step

Training step for a minibatch

validation_step

Validation step for a minibatch
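Several of the methods above refer to SVGD (Stein variational gradient descent). For readers unfamiliar with it, the following is a minimal NumPy sketch of the textbook SVGD update direction with an RBF kernel. It is not the implementation in `backward` or `compute_kernel`; the function names, the fixed `bandwidth` argument, and the flat particle layout are assumptions made for illustration.

```python
import numpy as np


def rbf_kernel(x, bandwidth=1.0):
    """Return k[j, i] = k(x_j, x_i) and its gradient with respect to x_j."""
    diff = x[:, None, :] - x[None, :, :]          # diff[j, i] = x_j - x_i
    sq_dist = (diff ** 2).sum(axis=-1)
    k = np.exp(-sq_dist / (2.0 * bandwidth ** 2))
    grad_k = -diff / bandwidth ** 2 * k[..., None]
    return k, grad_k


def svgd_direction(x, grad_logp, bandwidth=1.0):
    """Textbook SVGD direction for particles x of shape (n, d).

    phi_i = 1/n * sum_j [ k(x_j, x_i) * grad_logp_j + grad_{x_j} k(x_j, x_i) ]
    The first term pulls particles toward high density; the second
    term pushes them apart.
    """
    n = x.shape[0]
    k, grad_k = rbf_kernel(x, bandwidth)
    return (k.T @ grad_logp + grad_k.sum(axis=0)) / n


# Two particles and a standard normal target, for which grad log p(x) = -x.
particles = np.array([[-1.0], [1.0]])
phi = svgd_direction(particles, -particles)
```

If the kernel were replaced by an identity matrix with zero gradient, each particle would follow its own gradient independently; whether the "KroneckerDelta" option behaves this way is a guess based on its name.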

Attributes

EXP_AVG

fit_stage

Fit stage, see FitStage for details

fixed_vars

Fixed variables during counterfactual prediction

predict_mode

Prediction mode, see PredictMode for details

prefit

Whether to run prefit on the covariates only

topo_gens

Topological generations of the causal graph
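As background for `topo_gens`: topological generations partition the nodes of a directed acyclic graph into layers, so that every node's parents lie in earlier layers. The sketch below computes them in plain Python from an edge list using Kahn's algorithm. The function name and the edge-list input are assumptions for illustration; the representation actually stored in `topo_gens` is not described on this page.

```python
from collections import defaultdict


def topological_generations(edges):
    """Group DAG nodes into generations from (parent, child) pairs."""
    children = defaultdict(list)
    indegree = defaultdict(int)
    nodes = set()
    for parent, child in edges:
        children[parent].append(child)
        indegree[child] += 1
        nodes.update((parent, child))

    # Generation 0 holds the nodes without any parent.
    current = sorted(n for n in nodes if indegree[n] == 0)
    generations = []
    visited = 0
    while current:
        generations.append(current)
        visited += len(current)
        following = []
        for node in current:
            for child in children[node]:
                indegree[child] -= 1
                # A child is ready once all of its parents are placed.
                if indegree[child] == 0:
                    following.append(child)
        current = sorted(following)

    if visited != len(nodes):
        raise ValueError("graph contains a cycle")
    return generations


gens = topological_generations(
    [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]
)
```

Nodes within one generation do not depend on each other, so they can be evaluated together, one generation after another.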