cascade.core

Core PyTorch Lightning module and training callbacks for the CASCADE model

Classes

CausalNetwork

Causal discovery neural network

DiscoverScheduler

Hyperparameter scheduler for causal discovery

FitStage

Model fitting stage

LogAdj

Logging mode of the adjacency matrix in TensorBoard

ModelCheckpoint

Custom model checkpoint callback that can be configured to skip saving the model once
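The one-shot skip behaviour can be sketched in plain Python. This is a hypothetical stand-in and not the CASCADE class. The names `OneShotSkipCheckpoint`, `skip_once` and `maybe_save` are assumptions, and checkpoints are kept in a list instead of being written to disk. A real implementation would more likely subclass the PyTorch Lightning `ModelCheckpoint` callback.

```python
from typing import Any, Dict, List, Tuple


class OneShotSkipCheckpoint:
    """Hypothetical checkpoint helper whose next save can be skipped exactly once."""

    def __init__(self, skip_first: bool = False) -> None:
        # When True, the next call to maybe_save() is suppressed, then the flag resets.
        self.skip_next_save = skip_first
        # Stand-in for checkpoint files on disk: (step, copied state) pairs.
        self.saved: List[Tuple[int, Dict[str, Any]]] = []

    def skip_once(self) -> None:
        """Arm the flag so that only the next save is skipped."""
        self.skip_next_save = True

    def maybe_save(self, step: int, state: Dict[str, Any]) -> bool:
        """Save a copy of `state` unless a skip is pending; return whether it was saved."""
        if self.skip_next_save:
            self.skip_next_save = False  # consume the one-shot skip
            return False
        self.saved.append((step, dict(state)))
        return True


ckpt = OneShotSkipCheckpoint(skip_first=True)
print(ckpt.maybe_save(1, {"loss": 0.9}))  # False: the first save is skipped
print(ckpt.maybe_save(2, {"loss": 0.7}))  # True
print([step for step, _ in ckpt.saved])   # [2]
```

The skip applies to a single save only because `maybe_save` resets the flag as it consumes it. Later saves proceed normally until `skip_once` is called again.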

PredictMode

Model prediction mode

PredictionWriter

Custom prediction writer to enable multi-device prediction
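One common way to make prediction work across several devices is to have each process write its own shard, keyed by the dataset indices it received, and then merge the shards in original order. The following is a minimal plain-Python sketch of that pattern under stated assumptions. It is not CASCADE's implementation. The helper names `write_rank_predictions` and `merge_predictions` and the `predictions_rank{r}.pkl` file layout are hypothetical. In PyTorch Lightning, the per-rank write would typically live in the `write_on_epoch_end` hook of a `BasePredictionWriter` subclass. The simulation mimics PyTorch's `DistributedSampler` without shuffling, which gives rank `r` the indices `r, r + world_size, ...` and pads the index list by repeating its first entries.

```python
import os
import pickle
import tempfile
from typing import Any, List, Sequence


def write_rank_predictions(out_dir: str, rank: int, indices: Sequence[int], preds: Sequence[Any]) -> str:
    """Each process writes its own shard as (dataset index, prediction) pairs."""
    path = os.path.join(out_dir, f"predictions_rank{rank}.pkl")
    with open(path, "wb") as f:
        pickle.dump(list(zip(indices, preds)), f)
    return path


def merge_predictions(out_dir: str, world_size: int) -> List[Any]:
    """After all ranks finish, gather every shard and restore the original sample order."""
    pairs = []
    for rank in range(world_size):
        with open(os.path.join(out_dir, f"predictions_rank{rank}.pkl"), "rb") as f:
            pairs.extend(pickle.load(f))
    # Samples are interleaved across ranks, so sort by dataset index; building a dict
    # also drops duplicates introduced by padding shards to equal size.
    merged = dict(pairs)
    return [merged[i] for i in sorted(merged)]


# Simulate 2 ranks predicting "2 * x" on a 5-sample dataset.
data = [10, 20, 30, 40, 50]
world_size = 2
padded = list(range(len(data))) + [0]  # pad to 6 indices so both ranks get 3 samples
with tempfile.TemporaryDirectory() as tmp:
    for rank in range(world_size):
        idx = padded[rank::world_size]  # rank 0: [0, 2, 4], rank 1: [1, 3, 0]
        write_rank_predictions(tmp, rank, idx, [data[i] * 2 for i in idx])
    result = merge_predictions(tmp, world_size)
print(result)  # [20, 40, 60, 80, 100]
```

Keying every prediction by its dataset index is what lets the merge step undo both the interleaving and the padding. Concatenating shards in rank order alone would scramble the samples and duplicate the padded ones.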