cascade.core.CausalNetwork.cascade
- CausalNetwork.cascade(x, r, s, l, l_=None, z=None)
Cascade pass of the model
- Parameters:
x (Tensor) – Sample data ([n_particles,] batch_size, n_vars)
r (Tensor) – Intervention regime ([n_particles,] batch_size, n_vars)
s (Tensor) – Covariate (batch_size, n_covariates)
l (Tensor) – Library size (batch_size, 1)
l_ (Tensor | None) – Counterfactual library size (batch_size, 1)
z (Normal | None) – Latent variable (n_particles, batch_size, latent_dim)
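The shape conventions above can be checked before calling `cascade`. The sketch below is a hypothetical helper, `check_cascade_shapes`, which is not part of the cascade library. It works on plain shape tuples and makes some assumptions that the parameter list does not state: `x` and `r` must agree on `(batch_size, n_vars)` but may independently omit the optional particle axis, and a supplied `z` must have the same `n_particles` as any 3-D `x` or `r`.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def check_cascade_shapes(
    x: Shape,
    r: Shape,
    s: Shape,
    l: Shape,
    l_: Optional[Shape] = None,
    z: Optional[Shape] = None,
) -> Dict[str, Optional[int]]:
    """Hypothetical helper (not part of cascade): check shape tuples against
    the conventions documented for CausalNetwork.cascade."""
    # x and r: ([n_particles,] batch_size, n_vars); the particle axis is optional.
    for name, shape in (("x", x), ("r", r)):
        if len(shape) not in (2, 3):
            raise ValueError(f"{name} must be 2-D or 3-D, got shape {shape}")
    # Assumption: x and r must agree on (batch_size, n_vars).
    if tuple(r[-2:]) != tuple(x[-2:]):
        raise ValueError(f"r {r} and x {x} disagree on (batch_size, n_vars)")
    if len(x) == 3 and len(r) == 3 and r[0] != x[0]:
        raise ValueError(f"r {r} and x {x} disagree on n_particles")
    batch_size, n_vars = x[-2], x[-1]
    n_particles = x[0] if len(x) == 3 else (r[0] if len(r) == 3 else None)

    # s: (batch_size, n_covariates)
    if len(s) != 2 or s[0] != batch_size:
        raise ValueError(f"s must be ({batch_size}, n_covariates), got {s}")

    # l and the optional counterfactual l_: (batch_size, 1)
    for name, shape in (("l", l), ("l_", l_)):
        if shape is not None and tuple(shape) != (batch_size, 1):
            raise ValueError(f"{name} must be ({batch_size}, 1), got {shape}")

    # z (optional): (n_particles, batch_size, latent_dim); always has a particle axis.
    latent_dim = None
    if z is not None:
        if len(z) != 3 or z[1] != batch_size:
            raise ValueError(f"z must be (n_particles, {batch_size}, latent_dim), got {z}")
        # Assumption: z's particle count must match any particle axis on x or r.
        if n_particles is not None and z[0] != n_particles:
            raise ValueError(f"z has {z[0]} particles but x/r have {n_particles}")
        n_particles, latent_dim = z[0], z[2]

    return {
        "n_particles": n_particles,
        "batch_size": batch_size,
        "n_vars": n_vars,
        "n_covariates": s[1],
        "latent_dim": latent_dim,
    }
```

With PyTorch tensors you could pass `x.shape` directly, because `torch.Size` is a subclass of `tuple`. If `z` is a `torch.distributions.Normal`, its `batch_shape` would serve the same purpose. Both uses are illustrative and are not documented by cascade.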
- Returns:
Latent variable (n_particles, batch_size, latent_dim)
Data reconstruction distribution
- Return type: