cascade.nn.Edgewise

class cascade.nn.Edgewise(n_vars, n_particles, eidx, tau=10.0)

Bases: Scaffold

Edgewise-parameterized edge logits.

Parameters:
  • n_vars (int) – Number of variables in the graph

  • n_particles (int) – Number of SVGD particles

  • eidx (LongTensor) – Scaffold edge indices of shape (2, n_edges)

  • tau (float) – Gumbel-sigmoid temperature (default: 10.0)
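Example

The sketch below illustrates two of the parameters above in plain Python, without importing cascade: the (2, n_edges) layout of eidx and the effect of the temperature tau on a Gumbel-sigmoid sample. It is an illustration under assumptions, not the library's implementation. The helper gumbel_sigmoid, the example logits, and the convention that row 0 of eidx holds source nodes and row 1 holds target nodes are all hypothetical; the relaxation shown is the standard binary Concrete form, sigmoid((logit + logistic noise) / tau), and how Edgewise applies tau internally is not stated here.

```python
import math
import random


def gumbel_sigmoid(logit, tau, rng):
    """Hypothetical helper: relaxed Bernoulli sample from one edge logit."""
    # Clamp u away from 0 and 1 so both logarithms stay finite.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    noise = math.log(u) - math.log(1.0 - u)  # Logistic(0, 1) noise
    z = (logit + noise) / tau
    # Numerically stable sigmoid.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


n_vars = 4

# Scaffold edge indices with shape (2, n_edges).
# Assumed convention: row 0 = source nodes, row 1 = target nodes.
eidx = [
    [0, 0, 1, 2],
    [1, 2, 3, 3],
]
n_edges = len(eidx[0])

# One illustrative logit per scaffold edge (made-up values).
logits = [0.5, -1.0, 2.0, 0.0]

# Reuse the same seed so both temperatures see identical noise.
soft = [gumbel_sigmoid(l, 10.0, random.Random(i)) for i, l in enumerate(logits)]
sharp = [gumbel_sigmoid(l, 0.1, random.Random(i)) for i, l in enumerate(logits)]

for (src, dst), s_hi, s_lo in zip(zip(*eidx), soft, sharp):
    print(f"edge {src}->{dst}: tau=10.0 -> {s_hi:.3f}, tau=0.1 -> {s_lo:.3f}")
```

With identical noise, the samples drawn at tau=10.0 stay close to 0.5, while those drawn at tau=0.1 are pushed toward 0 or 1. A large temperature therefore gives smoother, less decisive edge samples, and a small one gives nearly discrete samples.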

Methods

compute_logit

prune

Attributes

INIT_STD