cascade.core.CausalNetwork.forward

CausalNetwork.forward(x, r, s, l, l_=None, z=None, oidx=None)

Forward pass of the model.

Parameters:
  • x (Tensor) – Sample data ([n_particles,] batch_size, n_vars)

  • r (Tensor) – Intervention regime ([n_particles,] batch_size, n_vars)

  • s (Tensor) – Covariates (batch_size, n_covariates)

  • l (Tensor) – Library size (batch_size, 1)

  • l_ (Tensor | None) – Counterfactual library size (batch_size, 1)

  • z (Normal | None) – Latent variable (n_particles, batch_size, latent_dim)

  • oidx (LongTensor | None) – Output variable index

Return type:

tuple[Normal, Distribution]

Returns:

  • Latent variable (n_particles, batch_size, latent_dim)

  • Data reconstruction distribution