cascade.nn.GCNLatent
- class cascade.nn.GCNLatent(n_particles, latent_dim, vmap, eidx=None, ewt=None, emb_dim=None, n_layers=1)
Bases: Latent
Intervention latent module encoding from a graph.
- Parameters:
  - n_particles (int) – Number of SVGD particles
  - latent_dim (int) – Dimensionality of the latent variable
  - vmap (LongTensor) – Variable index mapping with the parent module CausalNetwork
  - eidx (LongTensor) – Graph edge index of shape (2, n_edges)
  - ewt (FloatTensor) – Graph edge weight of shape (n_edges,)
  - emb_dim (int) – Dimensionality of the learnable node embedding
  - n_layers (int) – Number of graph convolution layers
- Attributes:
  - INIT_STD
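To illustrate how `eidx` (shape `(2, n_edges)`) and `ewt` (shape `(n_edges,)`) together describe a weighted graph, and how `n_layers` stacks propagation over it, here is a minimal stdlib-only sketch. It is not the module's implementation. The convolution operator GCNLatent actually uses, along with any learned weights, degree normalization, nonlinearity, projection to `latent_dim`, and per-particle handling, is not documented on this page and is omitted. Several pieces are assumptions: treating row 0 of `eidx` as source nodes and row 1 as target nodes (the common PyTorch Geometric convention), the self-loop-plus-weighted-sum aggregation, the role of `vmap` as a row selector, and the helper names `propagate` and `encode`, which are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def propagate(x: Matrix, eidx: Sequence[Sequence[int]], ewt: Sequence[float]) -> Matrix:
    """One simplified message-passing step (hypothetical helper).

    Every node keeps its own features (a self-loop) and adds the
    features of each in-neighbour, scaled by that edge's weight.
    Assumes eidx[0] holds source nodes and eidx[1] holds target nodes.
    """
    src, dst = eidx
    if not (len(src) == len(dst) == len(ewt)):
        raise ValueError("eidx must have shape (2, n_edges) and ewt shape (n_edges,)")
    out = [row[:] for row in x]  # copy so neighbours read the pre-update features
    for s, d, w in zip(src, dst, ewt):
        for k, v in enumerate(x[s]):
            out[d][k] += w * v
    return out


def encode(emb: Matrix, eidx: Sequence[Sequence[int]], ewt: Sequence[float],
           n_layers: int, vmap: Sequence[int]) -> Matrix:
    """Stack n_layers propagation steps, then gather rows by vmap (hypothetical)."""
    h = emb
    for _ in range(n_layers):
        h = propagate(h, eidx, ewt)
    # Assumed role of vmap: pick the node rows matching the parent
    # network's variables, in the parent's order
    return [h[i] for i in vmap]


# Three nodes with 2-dimensional embeddings; edges 0 -> 1 (weight 0.5), 1 -> 2 (weight 2.0)
emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
eidx = [[0, 1], [1, 2]]
ewt = [0.5, 2.0]
print(encode(emb, eidx, ewt, n_layers=1, vmap=[2, 0]))  # → [[1.0, 3.0], [1.0, 0.0]]
```

Each extra layer widens the receptive field by one hop. With `n_layers=2`, node 2 also receives node 0's features through node 1. A real graph convolution layer would additionally apply a learned linear map and typically normalize by node degree between hops.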