cascade.core.CausalNetwork.explain

CausalNetwork.explain(x, r, s, l, x_, r_, s_, l_)

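Explanation pass of the model. Computes a series of predictions in which the factual components are replaced, one group at a time, by their counterfactual counterparts.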

Parameters:
  • x (Tensor) – Factual data ([n_particles,] batch_size, n_vars)

  • r (Tensor) – Factual intervention regime ([n_particles,] batch_size, n_vars)

  • s (Tensor) – Factual covariates (batch_size, n_covariates)

  • l (Tensor) – Factual library size (batch_size, 1)

  • x_ (Tensor) – Counterfactual data ([n_particles,] batch_size, n_vars)

  • r_ (Tensor) – Counterfactual intervention regime ([n_particles,] batch_size, n_vars)

  • s_ (Tensor) – Counterfactual covariates (batch_size, n_covariates)

  • l_ (Tensor) – Counterfactual library size (batch_size, 1)
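The shape conventions above can be checked before calling the method. The helper below is a hypothetical sketch and is not part of CASCADE. It assumes that every argument exposes a `.shape` attribute (as `torch.Tensor` does) and that the factual and counterfactual inputs share the same `batch_size`.

```python
def check_explain_shapes(x, r, s, l, x_, r_, s_, l_):
    """Hypothetical helper (not part of CASCADE): return a list of shape
    problems for the inputs of CausalNetwork.explain (empty list if none).

    Assumes each argument has a ``.shape`` attribute and that factual and
    counterfactual inputs share batch_size.
    """
    shp = lambda t: tuple(t.shape)
    problems = []
    batch_size, n_covariates = shp(s)[0], shp(s)[-1]
    n_vars = shp(x)[-1]
    # Data and regime: ([n_particles,] batch_size, n_vars), and must match each other
    for (dn, data), (rn, regime) in ((("x", x), ("r", r)), (("x_", x_), ("r_", r_))):
        d, g = shp(data), shp(regime)
        if len(d) not in (2, 3):
            problems.append(f"{dn} must be 2-D or 3-D, got {d}")
        elif d[-2:] != (batch_size, n_vars):
            problems.append(f"{dn} expected (..., {batch_size}, {n_vars}), got {d}")
        if d != g:
            problems.append(f"{dn} and {rn} differ: {d} vs {g}")
    # Covariates: (batch_size, n_covariates); library size: (batch_size, 1)
    for (cn, cov), (ln, lib) in ((("s", s), ("l", l)), (("s_", s_), ("l_", l_))):
        if shp(cov) != (batch_size, n_covariates):
            problems.append(f"{cn} expected ({batch_size}, {n_covariates}), got {shp(cov)}")
        if shp(lib) != (batch_size, 1):
            problems.append(f"{ln} expected ({batch_size}, 1), got {shp(lib)}")
    return problems
```

With real inputs, an empty result such as `check_explain_shapes(x, r, s, l, x_, r_, s_, l_) == []` indicates that the shapes agree with the documented conventions. Returning a list instead of raising on the first error reports every mismatch at once.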

Return type:

tuple[Tensor, ...]

Returns:

  • Prediction with all factual components

  • Prediction with only the counterfactual intervention scaling and bias

  • Prediction with only the counterfactual covariates

  • Prediction with only the counterfactual latent variable

  • Prediction with the counterfactual value of each parent variable

  • Prediction with all counterfactual components
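One plausible way to read the returned tuple is to compare each partial prediction with the all-factual prediction. This attributes part of the factual-to-counterfactual change to each component. The sketch below is a hypothetical helper and is not part of CASCADE. The labels are illustrative names for the documented output order. Because the model is generally nonlinear, the individual shifts need not add up to the total counterfactual shift.

```python
from typing import Dict, Sequence

# Illustrative labels, in the documented order of the returned tuple
EXPLAIN_LABELS = (
    "factual",         # all factual components
    "intervention",    # only counterfactual intervention scaling and bias
    "covariates",      # only counterfactual covariates
    "latent",          # only counterfactual latent variable
    "parents",         # counterfactual value of each parent variable
    "counterfactual",  # all counterfactual components
)


def component_shifts(outputs: Sequence) -> Dict[str, object]:
    """Hypothetical helper: subtract the all-factual prediction from every
    other prediction. Works elementwise on torch.Tensor or on plain numbers."""
    if len(outputs) != len(EXPLAIN_LABELS):
        raise ValueError(f"expected {len(EXPLAIN_LABELS)} outputs, got {len(outputs)}")
    factual = outputs[0]
    return {label: out - factual for label, out in zip(EXPLAIN_LABELS[1:], outputs[1:])}
```

For example, with `outputs = model.explain(x, r, s, l, x_, r_, s_, l_)` on a trained `CausalNetwork` instance `model`, `component_shifts(outputs)["parents"]` gives the change in prediction attributed to the counterfactual parent values under this reading.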