Stage 2: Model training
In this tutorial, we will walk through how to train the CASCADE model using the preprocessed data from stage 1.
[1]:
import networkx as nx
import pandas as pd
import scanpy as sc
from cascade.graph import acyclify, demultiplex, filter_edges, multiplex
from cascade.model import CASCADE
Read preprocessed data
[2]:
adata = sc.read_h5ad("adata.h5ad")
[3]:
scaffold = nx.read_gml("scaffold.gml.gz")
[4]:
latent_emb = pd.read_csv("latent_emb.csv.gz", index_col=0)
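Before building the model, it can help to check that the three inputs agree with each other. The sketch below uses a hypothetical helper, check_inputs, which is not part of CASCADE. It assumes that the scaffold nodes and the row index of latent_emb are both variable (gene) names matching adata.var_names, as the vars and latent_data arguments used below suggest. Adjust it if your stage 1 output is laid out differently.

```python
import networkx as nx
import pandas as pd


def check_inputs(var_names: pd.Index, scaffold: nx.Graph, latent_emb: pd.DataFrame) -> dict:
    """Hypothetical sanity check, not part of the CASCADE API.

    Assumes scaffold nodes and latent_emb row labels are variable names.
    """
    var_set = set(var_names)
    # Modeled variables that have no row in the latent embedding
    missing_latent = var_set - set(latent_emb.index)
    # Scaffold nodes that are not among the modeled variables
    extra_nodes = set(scaffold.nodes) - var_set
    # Scaffold edges whose two endpoints are both modeled variables
    # (subgraph silently ignores names that are not nodes of the graph)
    n_edges = scaffold.subgraph(var_set).number_of_edges()
    return {
        "n_vars": len(var_set),
        "n_scaffold_edges": n_edges,
        "missing_latent": sorted(missing_latent),
        "extra_scaffold_nodes": sorted(extra_nodes),
    }


if __name__ == "__main__":
    # Toy inputs standing in for adata.var_names, scaffold and latent_emb
    toy_vars = pd.Index(["A", "B", "C"])
    toy_scaffold = nx.DiGraph([("A", "B"), ("B", "C"), ("C", "D")])
    toy_latent = pd.DataFrame([[0.1, 0.2], [0.3, 0.4]], index=["A", "B"])
    print(check_inputs(toy_vars, toy_scaffold, toy_latent))
```

On the tutorial data you would call check_inputs(adata.var_names, scaffold, latent_emb). The resulting counts can be compared with the "Training on ... variables with ... scaffold edges" banner printed during discovery, although we have not verified that CASCADE counts scaffold edges in exactly this way.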
Build the CASCADE model
The first step is to build a CASCADE model:
[5]:
cascade = CASCADE(
    vars=adata.var_names,
    n_covariates=adata.obsm["covariate"].shape[1],
    scaffold_graph=scaffold,
    latent_data=latent_emb,
    log_dir="log_dir",
)
This creates a CASCADE model with the default settings. For advanced options, see the documentation of CASCADE, which describes its tunable hyperparameters and modules and how to use them.
Run causal discovery
(Estimated time: 30 min – 1 hour, depending on the computation device)
To run causal discovery using the CASCADE model, use the discover method:
[6]:
cascade.discover(adata)
cascade.save("discover.pt")
15:00:33.493 | INFO | 1484074:utils:autodevice - Using GPU [6] as computation device.
15:00:38.761 | INFO | 1484074:nn:set_empirical - Using theta coefficient = 4.150
15:00:38.763 | INFO | 1484074:nn:set_empirical - Using theta intercept = 0.819
╭────────────────────────────── cascade-reg ───────────────────────────────╮
│                                                                          │
│  Training on 1064 variables with 32264 scaffold edges and 86744 samples  │
│                                                                          │
╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯
  | Name         | Type      | Params | Mode
---------------------------------------------------
0 | scaffold     | Edgewise  | 129 K  | train
1 | sparse       | L1        | 0      | train
2 | acyc         | SpecNorm  | 0      | train
3 | kernel       | RBF       | 0      | train
4 | latent       | EmbLatent | 6.3 K  | train
5 | lik          | NegBin    | 0      | train
6 | func         | Func      | 18.9 M | train
  | other params | n/a       | 8.5 K  | n/a
---------------------------------------------------
19.1 M    Trainable params
0         Non-trainable params
19.1 M    Total params
76.366    Total estimated model params size (MB)
16        Modules in train mode
0         Modules in eval mode
Restoring best model: log_dir/discover/lightning_logs/version_0/checkpoints/epoch=45-step=27600.ckpt.
15:45:48.805 | INFO | 1484074:model:_extrapolate_interv - Extrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.
This runs CASCADE causal discovery with the default settings. For advanced options, see the documentation of discover.
Alternatively, the same step can be run from the command line interface:
cascade discover -d adata.h5ad -m discover.pt \
--scaffold-graph scaffold.gml.gz \
--latent-data latent_emb.csv.gz [other options]
You may monitor the training process with TensorBoard:
tensorboard --logdir .
Remove remaining cycles
Due to numerical limitations, some cycles may remain in the learned causal graph. We therefore use graph utility functions to remove them and obtain directed acyclic graphs, which are required for downstream inference.
[7]:
graph = cascade.export_causal_graph()
graph = multiplex(*[acyclify(filter_edges(g, cutoff=0.5)) for g in demultiplex(graph)])
nx.write_gml(graph, "discover.gml.gz")
Alternatively, the same step can be run from the command line interface:
cascade acyclify -m discover.pt -g discover.gml.gz [other options]
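The cell above applies filter_edges(g, cutoff=0.5) and then acyclify to every graph produced by demultiplex, and recombines the results with multiplex. To build intuition for those two steps, here is a self-contained networkx sketch on a toy graph. It is not the implementation of cascade.graph.filter_edges or cascade.graph.acyclify. The edge attribute name ("weight"), the inclusive cutoff, and the greedy rule (repeatedly delete the lowest-weight edge on a detected cycle) are our own assumptions for illustration.

```python
import networkx as nx


def filter_edges_sketch(g: nx.DiGraph, cutoff: float, attr: str = "weight") -> nx.DiGraph:
    """Hypothetical helper: keep edges whose `attr` is at least `cutoff` (assumed inclusive)."""
    kept = [(u, v, d) for u, v, d in g.edges(data=True) if d.get(attr, 0.0) >= cutoff]
    out = nx.DiGraph()
    out.add_nodes_from(g.nodes(data=True))  # keep every node, even if it loses all its edges
    out.add_edges_from(kept)
    return out


def acyclify_sketch(g: nx.DiGraph, attr: str = "weight") -> nx.DiGraph:
    """Hypothetical helper: greedily drop the weakest edge of each cycle until g is a DAG."""
    g = g.copy()
    while not nx.is_directed_acyclic_graph(g):
        cycle = nx.find_cycle(g)  # list of (u, v) edges forming one directed cycle
        u, v = min(cycle, key=lambda e: g.edges[e[0], e[1]].get(attr, 0.0))[:2]
        g.remove_edge(u, v)
    return g


if __name__ == "__main__":
    toy = nx.DiGraph()
    toy.add_weighted_edges_from(
        [("A", "B", 0.9), ("B", "C", 0.8), ("C", "A", 0.6), ("C", "D", 0.3)]
    )
    dag = acyclify_sketch(filter_edges_sketch(toy, cutoff=0.5))
    print(sorted(dag.edges()))  # [('A', 'B'), ('B', 'C')]
```

Greedy weakest-edge removal is a simple heuristic rather than a minimum feedback arc set, so CASCADE's own acyclify may remove different edges.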
Model tuning
(Estimated time: 15 min – 30 min, depending on the computation device)
Next, we import the acyclified graph back into the model:
[8]:
cascade.import_causal_graph(graph)
Now we can fine-tune the structural equations in the model with the tune method, so that they adapt to the edges removed during acyclification. We also recommend enabling the counterfactual tuning mode (tune_ctfact=True), which optimizes the tuning process specifically for counterfactual prediction.
[9]:
cascade.tune(adata, tune_ctfact=True)
cascade.save("tune.pt")
15:46:28.404 | INFO | 1484074:model:tune - Pruning model...
╭────────────────────────────── cascade-reg ───────────────────────────────╮
│                                                                          │
│  Training on 1064 variables with 12294 scaffold edges and 86744 samples  │
│                                                                          │
╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯
15:46:29.371 | INFO | 1484074:core:fit_stage - Number of topological generations: [68, 88, 70, 101]
  | Name         | Type      | Params | Mode
---------------------------------------------------
0 | scaffold     | Edgewise  | 49.2 K | eval
1 | sparse       | L1        | 0      | eval
2 | acyc         | SpecNorm  | 0      | eval
3 | kernel       | RBF       | 0      | eval
4 | latent       | EmbLatent | 6.3 K  | train
5 | lik          | NegBin    | 0      | eval
6 | func         | Func      | 7.9 M  | train
  | other params | n/a       | 8.5 K  | n/a
---------------------------------------------------
7.9 M     Trainable params
49.2 K    Non-trainable params
8.0 M     Total params
31.921    Total estimated model params size (MB)
11        Modules in train mode
5         Modules in eval mode
Restoring best model: log_dir/tune/lightning_logs/version_0/checkpoints/epoch=2-step=1800.ckpt.
16:11:39.865 | INFO | 1484074:model:_extrapolate_interv - Extrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.
For advanced options, see the documentation of tune.
Alternatively, the same step can be run from the command line interface:
cascade tune -d adata.h5ad -g discover.gml.gz -m discover.pt -o tune.pt \
--tune-ctfact [other options]
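The tuning log above reports a "Number of topological generations", apparently one value per graph in the multiplexed model. In a DAG, a topological generation is a set of nodes whose ancestors all lie in earlier generations (see networkx's topological_generations). The count of generations therefore equals the number of nodes on the longest directed path. The sketch below computes this count with networkx for any DiGraph. The helper name count_generations is hypothetical, and we have not checked that CASCADE computes the logged numbers in the same way.

```python
import networkx as nx


def count_generations(g: nx.DiGraph) -> int:
    """Hypothetical helper: number of topological generations of a DAG.

    Equals the number of nodes on the longest directed path.
    """
    if not nx.is_directed_acyclic_graph(g):
        raise ValueError("graph contains cycles; acyclify it first")
    # Each item yielded by topological_generations is one generation (a list of nodes)
    return sum(1 for _ in nx.topological_generations(g))


if __name__ == "__main__":
    # Generations: [A, D] -> [B] -> [C]
    toy = nx.DiGraph([("A", "B"), ("B", "C"), ("A", "C"), ("D", "C")])
    print(count_generations(toy))  # prints 3
```

Assuming each element yielded by demultiplex(graph) is a networkx DiGraph, [count_generations(g) for g in demultiplex(graph)] would give comparable per-graph numbers for the acyclified graph.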
The tuned model is now ready for counterfactual prediction in stage 3 and intervention design in stage 4.