cascade.core.CausalNetwork
- class cascade.core.CausalNetwork(n_vars, n_particles, n_covariates, n_layers, hidden_dim, latent_dim, dropout, beta, scaffold_mod, sparse_mod, acyc_mod, latent_mod, lik_mod, kernel_mod, scaffold_kws=None, sparse_kws=None, acyc_kws=None, latent_kws=None, lik_kws=None, kernel_kws=None, design=None)
Bases: LightningModule, Module
Causal discovery neural network.
- Parameters:
n_vars (int) – Number of variables to model
n_particles (int) – Number of SVGD particles
n_covariates (int) – Dimension of covariates
n_layers (int) – Number of MLP layers in the structural equations
hidden_dim (int) – MLP hidden layer dimension in the structural equations
latent_dim (int) – Dimension of the latent variable
dropout (float) – Dropout rate
beta (float) – KL weight of the latent variable
scaffold_mod (str) – Scaffold graph module, must be one of {“Edgewise”, “Bilinear”}
sparse_mod (str) – Sparse prior module, must be one of {“L1”, “ScaleFree”}
acyc_mod (str) – Acyclic prior module, must be one of {“TrExp”, “SpecNorm”, “LogDet”}
latent_mod (str) – Latent module, must be one of {“NilLatent”, “EmbLatent”, “GCNLatent”}
lik_mod (str) – Causal likelihood module, must be one of {“Normal”, “NegBin”}
kernel_mod (str) – SVGD kernel module, must be one of {“KroneckerDelta”, “RBF”}
scaffold_kws (Mapping[str, Any] | None) – Keyword arguments to the scaffold graph module, see Edgewise or Bilinear for details
sparse_kws (Mapping[str, Any] | None) – Keyword arguments to the sparse prior module, see L1 or ScaleFree for details
acyc_kws (Mapping[str, Any] | None) – Keyword arguments to the acyclic prior module, see TrExp, SpecNorm, or LogDet for details
latent_kws (Mapping[str, Any] | None) – Keyword arguments to the latent module, see NilLatent, EmbLatent, or GCNLatent for details
lik_kws (Mapping[str, Any] | None) – Keyword arguments to the causal likelihood module, see Normal or NegBin for details
kernel_kws (Mapping[str, Any] | None) – Keyword arguments to the SVGD kernel module, see KroneckerDelta or RBF for details
design (IntervDesign | None) – Optional intervention design module, see IntervDesign for details
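The acyc_mod options name different ways of penalizing cycles in the sampled graphs. As an illustration, TrExp most likely refers to the trace-exponential penalty h(W) = tr(exp(W ∘ W)) − d popularized by NOTEARS, which is zero exactly when the weighted adjacency matrix W encodes a DAG. The pure-Python sketch below assumes that formulation; trexp_acyclicity is a hypothetical helper, not part of cascade, and the real module presumably operates on batched tensors and accepts extra options through acyc_kws.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain square matrix product."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def trexp_acyclicity(w: Matrix, n_terms: int = 30) -> float:
    """Hypothetical helper assuming the NOTEARS-style TrExp penalty.

    Computes h(W) = tr(exp(W * W)) - d with a truncated Taylor series,
    where W * W is the elementwise (Hadamard) square of W.
    """
    d = len(w)
    # Hadamard square makes all weights nonnegative, so cycles cannot cancel out.
    a = [[w[i][j] ** 2 for j in range(d)] for i in range(d)]
    # term holds A^k / k!, starting from the identity (k = 0).
    term = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    trace = float(d)
    for k in range(1, n_terms):
        term = [[x / k for x in row] for row in matmul(term, a)]
        # tr(A^k) counts weighted closed walks of length k, i.e. cycles.
        trace += sum(term[i][i] for i in range(d))
    return trace - d
```

For the DAG [[0, 0.5], [0, 0]] the penalty is exactly 0, while the 2-cycle [[0, 1], [1, 0]] gives 2(cosh 1 − 1) ≈ 1.086, so minimizing h pushes sampled graphs toward acyclicity.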
Methods
Implementation of the main SVGD logic
Cascade pass of the model
Compute the SVGD kernel
Compute likelihood terms from a minibatch
Compute the prior energy terms
Explanation pass of the model
Forward pass of the model
Prediction step for a minibatch
Prune the scaffold and structural equations accordingly
Set the design module
Training step for a minibatch
Validation step for a minibatch
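Several of the methods above revolve around Stein variational gradient descent (SVGD) over n_particles graph samples, with the kernel selected by kernel_mod. The sketch below shows the standard SVGD update on scalar particles: each particle moves along a kernel-weighted average of the score plus a repulsive kernel-gradient term. How cascade's KroneckerDelta kernel enters the update is an assumption here: it is treated as k(x_i, x_j) = 1 if i = j else 0 with zero gradient, which drops the repulsion and leaves independent gradient ascent. All function names are hypothetical, and the real implementation works on graph parameters rather than floats.

```python
import math
from typing import Callable, List

Score = Callable[[float], float]  # gradient of log p at a point


def rbf_svgd_direction(xs: List[float], grad_logp: Score, bandwidth: float = 1.0) -> List[float]:
    """SVGD direction with the RBF kernel k(x, y) = exp(-(x - y)^2 / bandwidth)."""
    n = len(xs)
    scores = [grad_logp(x) for x in xs]
    phi = []
    for xi in xs:
        total = 0.0
        for xj, sj in zip(xs, scores):
            k = math.exp(-((xj - xi) ** 2) / bandwidth)
            # Driving term: kernel-weighted score pulls xi toward high density.
            total += k * sj
            # Repulsive term: d k(xj, xi) / d xj pushes xi away from xj.
            total += 2.0 * (xi - xj) / bandwidth * k
        phi.append(total / n)
    return phi


def delta_svgd_direction(xs: List[float], grad_logp: Score) -> List[float]:
    """Assumed KroneckerDelta kernel: only the i == j term survives, no repulsion."""
    n = len(xs)
    return [grad_logp(x) / n for x in xs]


def svgd(xs: List[float], grad_logp: Score, kernel: str = "RBF",
         step: float = 0.1, n_iter: int = 500, bandwidth: float = 1.0) -> List[float]:
    """Run plain SVGD with a fixed step size; kernel mirrors the kernel_mod names."""
    xs = list(xs)
    for _ in range(n_iter):
        if kernel == "RBF":
            phi = rbf_svgd_direction(xs, grad_logp, bandwidth)
        elif kernel == "KroneckerDelta":
            phi = delta_svgd_direction(xs, grad_logp)
        else:
            raise ValueError(f"unknown kernel: {kernel!r}")
        xs = [x + step * p for x, p in zip(xs, phi)]
    return xs
```

With a standard normal target (score −x), the KroneckerDelta variant collapses every particle onto the mode at 0, while the RBF variant keeps the particles spread around it; this repulsion is what lets the particle set represent a posterior over graphs rather than a single point estimate.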
Attributes
EXP_AVG
fit_stage – Fitting stage, see FitStage for details
fixed_vars – Fixed variables during counterfactual prediction
predict_mode – Prediction mode, see PredictMode for details
prefit – Whether to run prefit on the covariates only
topo_gens – Topological generations of the causal graph
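Topological generations group the variables of a DAG into layers: generation 0 holds variables with no parents, and every later generation holds variables whose parents all lie in earlier generations. One plausible use of topo_gens, which is an assumption here, is that a cascade or counterfactual pass can update the variables one generation at a time. The helper below is a hypothetical pure-Python version of Kahn's algorithm with layering, similar in spirit to networkx.topological_generations; it is not the cascade implementation.

```python
from typing import Hashable, List, Sequence, Tuple


def topological_generations(
    nodes: Sequence[Hashable], edges: Sequence[Tuple[Hashable, Hashable]]
) -> List[List[Hashable]]:
    """Hypothetical helper: layer the nodes of a DAG by topological generation."""
    indegree = {v: 0 for v in nodes}
    children = {v: [] for v in nodes}
    for parent, child in edges:
        children[parent].append(child)
        indegree[child] += 1
    # Generation 0: variables without parents.
    current = [v for v in nodes if indegree[v] == 0]
    generations: List[List[Hashable]] = []
    visited = 0
    while current:
        generations.append(current)
        visited += len(current)
        nxt = []
        for parent in current:
            for child in children[parent]:
                indegree[child] -= 1
                # A child joins the next generation once all its parents are placed.
                if indegree[child] == 0:
                    nxt.append(child)
        current = nxt
    if visited != len(indegree):
        raise ValueError("graph contains a cycle; topological generations are undefined")
    return generations
```

For edges a → c, b → c and c → d, the generations are [["a", "b"], ["c"], ["d"]]; a graph with a cycle raises ValueError because some variables never run out of unplaced parents.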