cascade.nn

Neural network utilities

Functions

copy_like

Copy the device, dtype, and data of a source tensor to a target tensor in place

gumbel_sigmoid

Straight-through Gumbel sigmoid sampler

mean_squared_error

Compute the mean squared error along a specified dimension

multi_rbf

RBF kernel with support for multiplex dims

multi_trace

Compute matrix trace with support for multiplex dims
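The functions above are listed only by name and summary. The pure-Python sketch below is a rough, hedged illustration of what three of them compute in the simplest case: plain 2-D inputs with no multiplex dims, since this page does not define those. The names `mse_along_dim`, `rbf_gram` and `trace` are hypothetical stand-ins, not the cascade.nn signatures, which operate on tensors. The RBF form exp(-||x - y||^2 / (2 l^2)) with lengthscale l is an assumption.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def mse_along_dim(x: Matrix, y: Matrix, dim: int) -> List[float]:
    """Hypothetical sketch: MSE of two equally shaped 2-D inputs, reduced along `dim`."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for 2-D inputs")
    # Elementwise squared differences.
    sq = [[(a - b) ** 2 for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]
    if dim == 1:
        # Average across columns: one value per row.
        return [sum(row) / len(row) for row in sq]
    # Average across rows: one value per column.
    return [sum(col) / len(col) for col in zip(*sq)]


def rbf_gram(xs: Sequence[Sequence[float]], lengthscale: float = 1.0) -> Matrix:
    """Gram matrix K[i][j] = exp(-||x_i - x_j||^2 / (2 * lengthscale^2)) (assumed RBF form)."""
    def k(u: Sequence[float], v: Sequence[float]) -> float:
        sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
        return math.exp(-sq_dist / (2 * lengthscale ** 2))
    return [[k(u, v) for v in xs] for u in xs]


def trace(m: Matrix) -> float:
    """Sum of the diagonal entries of a square matrix."""
    return sum(m[i][i] for i in range(len(m)))


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0]]
    y = [[1.0, 0.0], [3.0, 5.0]]
    print(mse_along_dim(x, y, dim=1))  # [2.0, 0.5]
    print(mse_along_dim(x, y, dim=0))  # [0.0, 2.5]
    # Every diagonal entry of an RBF Gram matrix is exp(0) = 1, so the trace equals n.
    print(trace(rbf_gram([[0.0], [1.0], [3.0]])))  # 3.0
```

The reduction dimension decides which axis is averaged away. That is why the same pair of inputs gives per-row errors for `dim=1` and per-column errors for `dim=0`.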

Classes

AcycPrior

Prior that enforces acyclicity constraint

AttnPool

Attention-based pooling layer that combines multiple intervention embeddings

Bilinear

Bilinearly parameterized edge logits

Edgewise

Edgewise parameterized edge logits

EmbLatent

Interventional latent module that encodes interventions from fixed embeddings

Func

Structural equation with covariates

GCNLatent

Interventional latent module that encodes interventions from a graph

IntervDesign

Intervention design module

Kernel

Abstract class for kernels

KroneckerDelta

Kronecker delta kernel

L1

L1 penalized log prior probability

Latent

Interventional latent module

Likelihood

Abstract class for causal distributions

LogDet

Log-determinant penalized log prior probability

Module

Abstract module class supporting parameter freezing, iteration over decayed and non-decayed parameters, and clearing of cached properties

ModuleList

Module list with the capabilities of Module

MultiLinear

Linear layer with support for multi-dims

NegBin

Negative binomial causal distribution

NilLatent

Nil interventional latent module that always outputs the standard normal

Normal

Normal causal distribution

Prior

Compute the unnormalized negative log prior probability of a scaffold graph

RBF

Radial basis function kernel

Scaffold

Abstract graph scaffold

ScaleFree

Scale-free penalized log prior probability

SparsePrior

Prior that encourages sparsity

SpecNorm

Spectral norm penalized log prior probability

TrExp

Tr-Exp penalized log prior probability
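TrExp is described only as a Tr-Exp penalized log prior probability. A well-known acyclicity penalty of that name is the trace-exponential function from NOTEARS, h(W) = tr(exp(W ∘ W)) - d, where W is a d × d weighted adjacency matrix. Because W ∘ W is nonnegative, h(W) is zero exactly when the graph has no directed cycles. The pure-Python sketch below assumes TrExp penalizes something of this form, which this page does not confirm. It computes h(W) with a truncated Taylor series for the matrix exponential. `trexp_penalty` and its helpers are hypothetical names, not part of cascade.nn.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of an n x k and a k x m matrix."""
    k = len(b)
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(len(b[0]))]
            for i in range(len(a))]


def expm_taylor(a: Matrix, terms: int = 30) -> Matrix:
    """Matrix exponential via the truncated series sum_k a^k / k! (fine for small norms)."""
    n = len(a)
    result = [[float(i == j) for j in range(n)] for i in range(n)]  # k = 0 term: identity
    term = [row[:] for row in result]
    for k in range(1, terms):
        # term now holds a^k / k!, built from a^(k-1) / (k-1)!.
        term = [[v / k for v in row] for row in matmul(term, a)]
        result = [[r + t for r, t in zip(rr, tr)] for rr, tr in zip(result, term)]
    return result


def trexp_penalty(w: Matrix) -> float:
    """Hypothetical NOTEARS-style penalty h(W) = tr(exp(W * W)) - d (elementwise square)."""
    d = len(w)
    squared = [[v * v for v in row] for row in w]
    e = expm_taylor(squared)
    return sum(e[i][i] for i in range(d)) - d


if __name__ == "__main__":
    dag = [[0.0, 1.0], [0.0, 0.0]]    # single edge 0 -> 1, acyclic
    cycle = [[0.0, 1.0], [1.0, 0.0]]  # edges 0 -> 1 and 1 -> 0, a 2-cycle
    print(trexp_penalty(dag))
    print(trexp_penalty(cycle), 2 * math.cosh(1) - 2)
```

For the acyclic matrix the series terminates after the linear term, so the penalty is zero. For the 2-cycle it is 2 cosh(1) - 2 ≈ 1.086. A log prior built on such a penalty would therefore favor acyclic scaffolds.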