cascade.nn.GCNLatent

class cascade.nn.GCNLatent(n_particles, latent_dim, vmap, eidx=None, ewt=None, emb_dim=None, n_layers=1)

Bases: Latent

Intervention latent module that encodes the latent variable from a graph using graph convolution

Parameters:
  • n_particles (int) – Number of SVGD (Stein variational gradient descent) particles

  • latent_dim (int) – Dimensionality of the latent variable

  • vmap (LongTensor) – Variable index mapping relative to the parent CausalNetwork module

  • eidx (LongTensor, optional) – Graph edge index of shape (2, n_edges)

  • ewt (FloatTensor, optional) – Graph edge weight of shape (n_edges,)

  • emb_dim (int, optional) – Dimensionality of the learnable node embedding

  • n_layers (int) – Number of graph convolution layers
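The eidx and ewt parameters describe the graph as a coordinate (COO) edge list, and n_layers sets how many graph convolutions are stacked. The sketch below illustrates that layout with a plain-Python weighted message-passing step applied n_layers times. It assumes row 0 of eidx holds source nodes and row 1 holds target nodes (the convention used by PyTorch Geometric); the helper propagate is hypothetical, and it deliberately omits the learnable weights, edge normalization, and nonlinearities that a GCN layer typically includes.

```python
from typing import List, Sequence


def propagate(
    eidx: Sequence[Sequence[int]],  # shape (2, n_edges); assumed [sources, targets]
    ewt: Sequence[float],           # shape (n_edges,)
    x: List[List[float]],           # node features, shape (n_nodes, dim)
) -> List[List[float]]:
    """Hypothetical helper: out[t] += ewt[e] * x[s] for every edge e = (s, t)."""
    src, tgt = eidx
    if not (len(src) == len(tgt) == len(ewt)):
        raise ValueError("eidx must have shape (2, n_edges) and ewt shape (n_edges,)")
    dim = len(x[0]) if x else 0
    out = [[0.0] * dim for _ in x]
    for s, t, w in zip(src, tgt, ewt):
        for k in range(dim):
            out[t][k] += w * x[s][k]
    return out


# Toy graph: 3 nodes, edges 0->1 and 1->2, plus a self-loop on every node.
eidx = [[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]]
ewt = [0.5, 2.0, 1.0, 1.0, 1.0]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Stack the propagation step n_layers times (illustrative only).
n_layers = 2
h = x
for _ in range(n_layers):
    h = propagate(eidx, ewt, h)
print(h)  # [[1.0, 0.0], [1.0, 1.0], [2.0, 5.0]]
```

Each extra layer lets information travel one more hop along the edges, which is why node 2 picks up node 0's features only after the second pass.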

Methods

forward

normalize_edges

vertex_degrees
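The method names vertex_degrees and normalize_edges suggest the usual GCN preprocessing of edge weights. The following is a minimal sketch, assuming the symmetric normalization used in standard GCNs, where each weight becomes w_e / sqrt(deg(s) * deg(t)) and degrees are weighted in-degrees. These standalone functions reuse the method names for illustration only; their signatures and behavior (degree direction, self-loop and zero-degree handling) are assumptions, not the cascade implementation.

```python
import math
from typing import List, Sequence

# Hypothetical standalone versions of GCNLatent.vertex_degrees / normalize_edges;
# the real methods' signatures and behavior may differ.


def vertex_degrees(eidx: Sequence[Sequence[int]], ewt: Sequence[float], n_nodes: int) -> List[float]:
    """Weighted in-degree: sum of ewt over edges whose target (row 1) is the node."""
    deg = [0.0] * n_nodes
    for t, w in zip(eidx[1], ewt):
        deg[t] += w
    return deg


def normalize_edges(eidx: Sequence[Sequence[int]], ewt: Sequence[float], n_nodes: int) -> List[float]:
    """Symmetric normalization w_e / sqrt(deg[s] * deg[t]); zero-degree endpoints give 0."""
    deg = vertex_degrees(eidx, ewt, n_nodes)
    inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in deg]
    return [inv_sqrt[s] * w * inv_sqrt[t] for s, t, w in zip(eidx[0], eidx[1], ewt)]


# Two nodes joined in both directions (weight 3) plus self-loops (weight 1).
eidx = [[0, 1, 0, 1], [1, 0, 0, 1]]
ewt = [3.0, 3.0, 1.0, 1.0]
print(vertex_degrees(eidx, ewt, 2))   # [4.0, 4.0]
print(normalize_edges(eidx, ewt, 2))  # [0.75, 0.75, 0.25, 0.25]
```

Mapping zero-degree nodes to a weight of 0 instead of dividing by zero mirrors the common practice of masking infinite inverse square roots during GCN normalization.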

Attributes

INIT_STD