cascade.core.CausalNetwork.forward
- CausalNetwork.forward(x, r, s, l, l_=None, z=None, oidx=None)
Forward pass of the model
- Parameters:
x (Tensor) – Sample data ([n_particles,] batch_size, n_vars)
r (Tensor) – Intervention regime ([n_particles,] batch_size, n_vars)
s (Tensor) – Covariate (batch_size, n_covariates)
l (Tensor) – Library size (batch_size, 1)
l_ (Tensor|None) – Counterfactual library size (batch_size, 1)
z (Normal|None) – Latent variable (n_particles, batch_size, latent_dim)
oidx (LongTensor|None) – Output variable index
- Return type:
- Returns:
Latent variable (n_particles, batch_size, latent_dim)
Data reconstruction distribution
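
The shape contract in the parameter list above can be sketched as a small pure-Python check. This is an illustrative helper, not part of cascade: the name check_forward_shapes is hypothetical, and it assumes that arguments sharing the label batch_size (or n_vars) must agree on that dimension, which the documentation implies but does not state outright. It works on plain shape tuples, so it does not need PyTorch.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_forward_shapes(
    x: Shape,
    r: Shape,
    s: Shape,
    l: Shape,
    l_: Optional[Shape] = None,
    z: Optional[Shape] = None,
) -> Tuple[Optional[int], int, int]:
    """Hypothetical helper: validate shape tuples against the documented layout.

    Returns (n_particles or None, batch_size, n_vars) as read from x.
    """
    # x is ([n_particles,] batch_size, n_vars): 2 or 3 dimensions.
    if len(x) not in (2, 3):
        raise ValueError(f"x must have 2 or 3 dimensions, got {x}")
    n_particles = x[0] if len(x) == 3 else None
    batch_size, n_vars = x[-2], x[-1]

    # r has the same layout; assumed to match x on the last two dimensions.
    if len(r) not in (2, 3) or r[-2:] != (batch_size, n_vars):
        raise ValueError(f"r must end in {(batch_size, n_vars)}, got {r}")

    # s is (batch_size, n_covariates).
    if len(s) != 2 or s[0] != batch_size:
        raise ValueError(f"s must be ({batch_size}, n_covariates), got {s}")

    # l and the optional l_ are both (batch_size, 1).
    for name, shape in (("l", l), ("l_", l_)):
        if shape is not None and shape != (batch_size, 1):
            raise ValueError(f"{name} must be {(batch_size, 1)}, got {shape}")

    # z, when given, is (n_particles, batch_size, latent_dim).
    if z is not None and (len(z) != 3 or z[1] != batch_size):
        raise ValueError(
            f"z must be (n_particles, {batch_size}, latent_dim), got {z}"
        )

    return n_particles, batch_size, n_vars


# Without the optional leading particle dimension.
plain = check_forward_shapes((128, 2000), (128, 2000), (128, 3), (128, 1))

# With 4 particles, a counterfactual library size, and a latent variable.
particles = check_forward_shapes(
    (4, 128, 2000), (4, 128, 2000), (128, 3), (128, 1),
    l_=(128, 1), z=(4, 128, 16),
)
```

Here plain is (None, 128, 2000) and particles is (4, 128, 2000). The helper deliberately leaves oidx unchecked, since the documentation gives no shape for it.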