cascade.nn.Scaffold

class cascade.nn.Scaffold(n_vars, n_particles, eidx, tau=10.0, **kwargs)

Bases: Module

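Abstract graph scaffold. It restricts the graph to a fixed set of candidate edges (eidx) and holds one edge logit per candidate edge for each SVGD particle.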

Parameters:
  • n_vars (int) – Number of variables in the graph

  • n_particles (int) – Number of SVGD particles

  • eidx (LongTensor) – Scaffold edge indices of shape (2, n_edges)

  • tau (float) – Gumbel-sigmoid temperature

Methods

accumulate_grad

compute_logit

construct_sparse_tensor

export_graph

import_graph

make_k

mask_data

prune

topo_gens

zero_grad

Attributes

adj

Sparse adjacency matrix of all particles

complete_adj

Complete sparse adjacency matrix

logit

Edge logits of shape (n_particles, n_edges)

mask_map

Index map of shape (n_vars, max_indegree). Entry (j, k) holds i, the index of the input gene that occupies reshaped position k among the inputs of output gene j.

mean_adj

Sparse adjacency matrix averaged over particles

n_edges

Number of scaffold edges, i.e. the second dimension of eidx

prob

Edge probabilities of shape (n_particles, n_edges)
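The index bookkeeping behind eidx, mask_map, adj and mean_adj can be illustrated with a minimal pure-Python sketch. It assumes that row 0 of eidx holds source (input) genes and row 1 holds target (output) genes, pads short rows of the index map with a hypothetical value of -1, and uses dense nested lists where the class itself uses sparse tensors. The helpers make_mask_map, dense_adj and mean_over_particles are hypothetical and are not part of cascade.nn.Scaffold.

```python
def make_mask_map(eidx, n_vars, pad=-1):
    """Hypothetical helper: row j lists the input genes i of edges i -> j.

    eidx: [sources, targets], i.e. shape (2, n_edges).
    Returns (mask_map, max_indegree); short rows are filled with `pad`.
    """
    sources, targets = eidx
    parents = [[] for _ in range(n_vars)]
    for i, j in zip(sources, targets):
        parents[j].append(i)
    max_indegree = max((len(p) for p in parents), default=0)
    mask_map = [p + [pad] * (max_indegree - len(p)) for p in parents]
    return mask_map, max_indegree


def dense_adj(eidx, edge_values, n_vars):
    """Hypothetical helper: scatter per-particle edge values into
    dense matrices of shape (n_particles, n_vars, n_vars)."""
    sources, targets = eidx
    adj = []
    for row in edge_values:  # one row of length n_edges per particle
        a = [[0.0] * n_vars for _ in range(n_vars)]
        for e, (i, j) in enumerate(zip(sources, targets)):
            a[i][j] = row[e]
        adj.append(a)
    return adj


def mean_over_particles(adj):
    """Hypothetical helper: element-wise mean of per-particle matrices."""
    n = len(adj)
    n_vars = len(adj[0])
    return [[sum(a[i][j] for a in adj) / n for j in range(n_vars)]
            for i in range(n_vars)]


# Edges 0->1, 1->2, 0->2 and 2->0 over three genes.
eidx = [[0, 1, 0, 2], [1, 2, 2, 0]]
mask_map, max_indegree = make_mask_map(eidx, n_vars=3)
# mask_map == [[2, -1], [0, -1], [1, 0]], max_indegree == 2

prob = [[0.25, 0.5, 0.75, 1.0],   # particle 0
        [0.75, 1.0, 0.25, 0.5]]   # particle 1
adj = dense_adj(eidx, prob, n_vars=3)
mean = mean_over_particles(adj)
```

Grouping inputs by output gene in this way turns a ragged parent list into a rectangular (n_vars, max_indegree) table, at the cost of padding for genes whose in-degree is below the maximum.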