cascade.nn.Latent

class cascade.nn.Latent(n_particles, latent_dim, vmap, **kwargs)

Bases: Module

Interventional latent module

Parameters:
  • n_particles (int) – Number of SVGD (Stein variational gradient descent) particles

  • latent_dim (int) – Dimensionality of the latent variable

  • vmap (LongTensor) – Variable index mapping relative to the parent module, CausalNetwork

Methods

forward

prior