cascade.nn.EmbLatent

class cascade.nn.EmbLatent(n_particles, latent_dim, vmap, emb=None)

Bases: Latent

Intervention latent module whose encoding is derived from fixed embeddings

Parameters:
  • n_particles (int) – Number of Stein variational gradient descent (SVGD) particles

  • latent_dim (int) – Dimensionality of the latent variable

  • vmap (LongTensor) – Variable index mapping relative to the parent CausalNetwork module

  • emb (Tensor) – Fixed embedding tensor
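The parameters above suggest how such a module could fit together. The following is a minimal, hypothetical PyTorch sketch, not CASCADE's actual implementation. It assumes that `vmap` selects the rows of `emb` belonging to the variables the module covers, that the selected embeddings are kept fixed as a non-trainable buffer, and that each SVGD particle owns its own linear projection into `latent_dim`. The class name `EmbLatentSketch` and the attribute `weight` are illustrative only, and the sketch requires `emb` rather than handling `emb=None`.

```python
import torch
from torch import nn


class EmbLatentSketch(nn.Module):
    """Hypothetical sketch (not the CASCADE API): fixed embeddings -> per-particle latents.

    Assumptions: vmap indexes rows of emb, and every SVGD particle has its own
    linear projection from the embedding space into latent_dim.
    """

    def __init__(self, n_particles, latent_dim, vmap, emb):
        super().__init__()
        # Store the selected embedding rows as a buffer so the optimizer never updates them.
        self.register_buffer("emb", emb[vmap])  # shape: (n_vars, emb_dim)
        emb_dim = emb.shape[1]
        # Assumed parameterization: one projection matrix per particle.
        self.weight = nn.Parameter(torch.randn(n_particles, emb_dim, latent_dim) * 0.01)

    def forward(self):
        # (n_vars, emb_dim) x (n_particles, emb_dim, latent_dim)
        # -> (n_particles, n_vars, latent_dim)
        return torch.einsum("ve,ped->pvd", self.emb, self.weight)


# Usage with made-up sizes: 5 variables with 8-dim fixed embeddings,
# of which the module covers variables 0, 2 and 4.
torch.manual_seed(0)
emb = torch.randn(5, 8)
vmap = torch.tensor([0, 2, 4])
latent = EmbLatentSketch(n_particles=4, latent_dim=2, vmap=vmap, emb=emb)
z = latent()
print(tuple(z.shape))  # (4, 3, 2)
```

Keeping the embeddings in a buffer rather than a parameter is one way to express "fixed": they move with the module across devices and are saved in its state dict, but only the projection receives gradients.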

Methods

forward