{ "cells": [ { "cell_type": "markdown", "id": "4576ce75-a8b2-4687-b65f-70d489de0000", "metadata": {}, "source": [ "# Stage 2: Model training\n", "\n", "In this tutorial, we will walk through how to train the CASCADE model using the\n", "preprocessed data from [stage 1](preprocessing.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "id": "a37dd09c-a6ab-4521-82d4-f0ef3e3dd0cb", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:00:19.647564Z", "iopub.status.busy": "2025-03-20T07:00:19.646930Z", "iopub.status.idle": "2025-03-20T07:00:24.530187Z", "shell.execute_reply": "2025-03-20T07:00:24.529097Z", "shell.execute_reply.started": "2025-03-20T07:00:19.647513Z" } }, "outputs": [], "source": [ "import networkx as nx\n", "import pandas as pd\n", "import scanpy as sc\n", "\n", "from cascade.graph import acyclify, demultiplex, filter_edges, multiplex\n", "from cascade.model import CASCADE" ] }, { "cell_type": "markdown", "id": "2907ebbf-45df-4b3f-8134-5060516f19f1", "metadata": {}, "source": [ "## Read preprocessed data" ] }, { "cell_type": "code", "execution_count": 2, "id": "ab477914-708d-4097-9af7-82738d93a265", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:00:24.531745Z", "iopub.status.busy": "2025-03-20T07:00:24.531295Z", "iopub.status.idle": "2025-03-20T07:00:29.457369Z", "shell.execute_reply": "2025-03-20T07:00:29.456291Z", "shell.execute_reply.started": "2025-03-20T07:00:24.531721Z" } }, "outputs": [], "source": [ "adata = sc.read_h5ad(\"adata.h5ad\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "38c8b9ec-69c4-40cc-9f94-2d4984a96434", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:00:29.459200Z", "iopub.status.busy": "2025-03-20T07:00:29.458951Z", "iopub.status.idle": "2025-03-20T07:00:32.308398Z", "shell.execute_reply": "2025-03-20T07:00:32.307557Z", "shell.execute_reply.started": "2025-03-20T07:00:29.459179Z" } }, "outputs": [], "source": [ "scaffold = nx.read_gml(\"scaffold.gml.gz\")" 
] }, { "cell_type": "code", "execution_count": 4, "id": "3800ef30-07b5-4207-a595-92676e5f8cfb", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:00:32.312427Z", "iopub.status.busy": "2025-03-20T07:00:32.312246Z", "iopub.status.idle": "2025-03-20T07:00:32.332232Z", "shell.execute_reply": "2025-03-20T07:00:32.331483Z", "shell.execute_reply.started": "2025-03-20T07:00:32.312408Z" } }, "outputs": [], "source": [ "latent_emb = pd.read_csv(\"latent_emb.csv.gz\", index_col=0)" ] }, { "cell_type": "markdown", "id": "ec1b8388-ad1e-44d0-ba05-dd10359faed8", "metadata": {}, "source": [ "## Build the CASCADE model\n", "\n", "The first step is to build a CASCADE model:" ] }, { "cell_type": "code", "execution_count": 5, "id": "494a65cb-e582-4dca-8501-3f6817abf23c", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:00:32.333135Z", "iopub.status.busy": "2025-03-20T07:00:32.332955Z", "iopub.status.idle": "2025-03-20T07:00:33.467268Z", "shell.execute_reply": "2025-03-20T07:00:33.466537Z", "shell.execute_reply.started": "2025-03-20T07:00:32.333118Z" } }, "outputs": [], "source": [ "cascade = CASCADE(\n", "    vars=adata.var_names,\n", "    n_covariates=adata.obsm[\"covariate\"].shape[1],\n", "    scaffold_graph=scaffold,\n", "    latent_data=latent_emb,\n", "    log_dir=\"log_dir\",\n", ")" ] }, { "cell_type": "markdown", "id": "3c65ded7-71e6-4853-8110-1036614ff48f", "metadata": {}, "source": [ "This creates a CASCADE model with the default settings. For advanced options,\n", "see the documentation of [CASCADE](api/cascade.model.CASCADE.rst), which covers\n", "the tunable hyperparameters, modules, and their usage.\n", "\n", "## Run causal discovery\n", "\n", "> (Estimated time: 30 min – 1 hour, depending on computation device)\n", "\n", "To run causal discovery using the CASCADE model, use the `discover` method:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a0698126-e973-4b65-895d-f65f962db29c", "metadata": { "editable": true, "execution": { "iopub.execute_input": "2025-03-20T07:00:33.468346Z", "iopub.status.busy": "2025-03-20T07:00:33.468080Z", "iopub.status.idle": "2025-03-20T07:45:49.071616Z", "shell.execute_reply": "2025-03-20T07:45:49.071043Z", "shell.execute_reply.started": "2025-03-20T07:00:33.468321Z" }, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m15:00:33.493\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mutils\u001b[0m:\u001b[36mautodevice\u001b[0m - \u001b[1mUsing GPU [6] as computation device.\u001b[0m\n", "\u001b[32m15:00:38.761\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mnn\u001b[0m:\u001b[36mset_empirical\u001b[0m - \u001b[1mUsing theta coefficient = 4.150\u001b[0m\n", "\u001b[32m15:00:38.763\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mnn\u001b[0m:\u001b[36mset_empirical\u001b[0m - \u001b[1mUsing theta intercept = 0.819\u001b[0m\n" ] }, { "data": { "text/html": [ "
╭────────────────────────────── cascade-reg ───────────────────────────────╮\n",
"│ │\n",
"│ Training on 1064 variables with 32264 scaffold edges and 86744 samples │\n",
"│ │\n",
"╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭────────────────────────────── cascade-reg ───────────────────────────────╮\n",
"│ │\n",
"│ Training on \u001b[1;35m1064\u001b[0m variables with \u001b[1;35m32264\u001b[0m scaffold edges and \u001b[1;35m86744\u001b[0m samples │\n",
"│ │\n",
"╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params | Mode \n",
"---------------------------------------------------\n",
"0 | scaffold | Edgewise | 129 K | train\n",
"1 | sparse | L1 | 0 | train\n",
"2 | acyc | SpecNorm | 0 | train\n",
"3 | kernel | RBF | 0 | train\n",
"4 | latent | EmbLatent | 6.3 K | train\n",
"5 | lik | NegBin | 0 | train\n",
"6 | func | Func | 18.9 M | train\n",
" | other params | n/a | 8.5 K | n/a \n",
"---------------------------------------------------\n",
"19.1 M Trainable params\n",
"0 Non-trainable params\n",
"19.1 M Total params\n",
"76.366 Total estimated model params size (MB)\n",
"16 Modules in train mode\n",
"0 Modules in eval mode\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a46d97a6a7094e5a8589d27732a45fc1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | …"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Restoring best model: log_dir/discover/lightning_logs/version_0/checkpoints/epoch=45-step=27600.ckpt.\n", "\n" ], "text/plain": [ "Restoring best model: log_dir/discover/lightning_logs/version_0/checkpoints/\u001b[33mepoch\u001b[0m=\u001b[1;36m45\u001b[0m-\u001b[33mstep\u001b[0m=\u001b[1;36m27600\u001b[0m\u001b[1;36m.\u001b[0mckpt.\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m15:45:48.805\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mmodel\u001b[0m:\u001b[36m_extrapolate_interv\u001b[0m - \u001b[1mExtrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.\u001b[0m\n" ] } ], "source": [ "cascade.discover(adata)\n", "cascade.save(\"discover.pt\")" ] }, { "cell_type": "markdown", "id": "9939c14d-e141-49b1-b2a9-093a4557ce1d", "metadata": {}, "source": [ "This runs CASCADE causal discovery with the default settings. For advanced\n", "options, see the documentation of\n", "[discover](api/cascade.model.CASCADE.discover.rst).\n", "\n", "The same can also be achieved using the\n", "[command line interface](cli.rst#causal-discovery),\n", "with the following command:\n", "\n", "```sh\n", "cascade discover -d adata.h5ad -m discover.pt \\\n", "    --scaffold-graph scaffold.gml.gz \\\n", "    --latent-data latent_emb.csv.gz [other options]\n", "```\n", "\n", "> You may use `tensorboard --logdir .` to monitor the training process.\n", "\n", "## Remove remaining cycles\n", "\n", "Due to numerical limitations, some cycles may remain in the discovered graph.\n", "We therefore apply graph utility functions to filter edges and remove the\n", "remaining cycles, because downstream inference requires directed acyclic graphs."
] }, { "cell_type": "code", "execution_count": 7, "id": "e038c332-7698-4780-8966-13122f55aba3", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:46:08.722722Z", "iopub.status.busy": "2025-03-20T07:46:08.722451Z", "iopub.status.idle": "2025-03-20T07:46:22.758086Z", "shell.execute_reply": "2025-03-20T07:46:22.757166Z", "shell.execute_reply.started": "2025-03-20T07:46:08.722695Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c3e9965ee5354338b03b8fe64d3bb563", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/32264 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "34fa3c0a9649436496fa2f9ec88f7896", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5329 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca8b1f9947ad41f99bdaa69a1b203c49", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5928 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5005b4ca41fa4412bb988867a6afa736", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/6096 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fc9991d1648c4b5dba524fad3ffe95a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5712 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "28ba8f27784b4e2ab90110572d352125", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/12294 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "graph = cascade.export_causal_graph()\n", "graph = multiplex(*[acyclify(filter_edges(g, cutoff=0.5)) for g in demultiplex(graph)])\n", "nx.write_gml(graph, \"discover.gml.gz\")" ] }, { "cell_type": "markdown", "id": "f61051d7-64eb-4a5e-bc51-e3ec36bc652f", "metadata": {}, "source": [ "The same can also be achieved using the\n", "[command line interface](cli.rst#graph-acyclification),\n", "with the following command:\n", "\n", "```sh\n", "cascade acyclify -m discover.pt -g discover.gml.gz [other options]\n", "```\n", "\n", "## Model tuning\n", "\n", "> (Estimated time: 15 min – 30 min, depending on computation device)\n", "\n", "Next, we import the acyclified graph back into the model:" ] }, { "cell_type": "code", "execution_count": 8, "id": "0cb819c0-fe99-449b-88d1-d89adbdc9a3a", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T07:46:24.632280Z", "iopub.status.busy": "2025-03-20T07:46:24.631886Z", "iopub.status.idle": "2025-03-20T07:46:24.692344Z", "shell.execute_reply": "2025-03-20T07:46:24.691774Z", "shell.execute_reply.started": "2025-03-20T07:46:24.632255Z" } }, "outputs": [], "source": [ "cascade.import_causal_graph(graph)" ] }, { "cell_type": "markdown", "id": "70736d32-790a-429d-9e99-cc5af8077b27", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Now we can fine-tune the structural equations in the model using the `tune`\n", "method to adapt to the edges removed during acyclification. We also recommend\n", "enabling the counterfactual tuning mode (`tune_ctfact=True`), which optimizes\n", "the tuning process specifically for counterfactual prediction."
] }, { "cell_type": "code", "execution_count": 9, "id": "8d2dc242-14d9-4668-9e53-7c7ba9098600", "metadata": { "editable": true, "execution": { "iopub.execute_input": "2025-03-20T07:46:28.386524Z", "iopub.status.busy": "2025-03-20T07:46:28.386273Z", "iopub.status.idle": "2025-03-20T08:11:40.070653Z", "shell.execute_reply": "2025-03-20T08:11:40.069627Z", "shell.execute_reply.started": "2025-03-20T07:46:28.386504Z" }, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m15:46:28.404\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mmodel\u001b[0m:\u001b[36mtune\u001b[0m - \u001b[1mPruning model...\u001b[0m\n" ] }, { "data": { "text/html": [ "
╭────────────────────────────── cascade-reg ───────────────────────────────╮\n",
"│ │\n",
"│ Training on 1064 variables with 12294 scaffold edges and 86744 samples │\n",
"│ │\n",
"╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭────────────────────────────── cascade-reg ───────────────────────────────╮\n",
"│ │\n",
"│ Training on \u001b[1;35m1064\u001b[0m variables with \u001b[1;35m12294\u001b[0m scaffold edges and \u001b[1;35m86744\u001b[0m samples │\n",
"│ │\n",
"╰───────────────────────────────── v0.4.0 ─────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32m15:46:29.371\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mcore\u001b[0m:\u001b[36mfit_stage\u001b[0m - \u001b[1mNumber of topological generations: [68, 88, 70, 101]\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params | Mode \n",
"---------------------------------------------------\n",
"0 | scaffold | Edgewise | 49.2 K | eval \n",
"1 | sparse | L1 | 0 | eval \n",
"2 | acyc | SpecNorm | 0 | eval \n",
"3 | kernel | RBF | 0 | eval \n",
"4 | latent | EmbLatent | 6.3 K | train\n",
"5 | lik | NegBin | 0 | eval \n",
"6 | func | Func | 7.9 M | train\n",
" | other params | n/a | 8.5 K | n/a \n",
"---------------------------------------------------\n",
"7.9 M Trainable params\n",
"49.2 K Non-trainable params\n",
"8.0 M Total params\n",
"31.921 Total estimated model params size (MB)\n",
"11 Modules in train mode\n",
"5 Modules in eval mode\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a38ffc72bd1c49abbe879f03e41858dd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | …"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Restoring best model: log_dir/tune/lightning_logs/version_0/checkpoints/epoch=2-step=1800.ckpt.\n", "\n" ], "text/plain": [ "Restoring best model: log_dir/tune/lightning_logs/version_0/checkpoints/\u001b[33mepoch\u001b[0m=\u001b[1;36m2\u001b[0m-\u001b[33mstep\u001b[0m=\u001b[1;36m1800\u001b[0m\u001b[1;36m.\u001b[0mckpt.\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m16:11:39.865\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1484074\u001b[0m:\u001b[36mmodel\u001b[0m:\u001b[36m_extrapolate_interv\u001b[0m - \u001b[1mExtrapolating scale and bias of 959 non-intervened variables from 105 intervened variables.\u001b[0m\n" ] } ], "source": [ "cascade.tune(adata, tune_ctfact=True)\n", "cascade.save(\"tune.pt\")" ] }, { "cell_type": "markdown", "id": "2022cb97-32d4-4795-a5c5-f657a72dbaa1", "metadata": {}, "source": [ "For advanced options, see the documentation of\n", "[tune](api/cascade.model.CASCADE.tune.rst).\n", "\n", "The same can also be achieved using the\n", "[command line interface](cli.rst#model-tuning),\n", "with the following command:\n", "\n", "```sh\n", "cascade tune -d adata.h5ad -g discover.gml.gz -m discover.pt -o tune.pt \\\n", "    --tune-ctfact [other options]\n", "```\n", "\n", "The tuned model is now ready for counterfactual prediction in\n", "[stage 3](counterfactual.ipynb) and intervention design in\n", "[stage 4](design.ipynb)."
] } ], "metadata": { "jupytext": { "formats": "ipynb,py:percent" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }