{ "cells": [ { "cell_type": "markdown", "id": "4cc6761d-dc6e-4356-90bb-0ab585c06dc6", "metadata": {}, "source": [ "# Stage 3: Counterfactual prediction\n", "\n", "In this tutorial, we will walk through how to use the CASCADE model trained\n", "in [stage 2](training.ipynb) to conduct counterfactual inference." ] }, { "cell_type": "code", "execution_count": 1, "id": "4ce5fe21-bb6d-4f1f-913c-154ef60f767e", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:32.257466Z", "iopub.status.busy": "2025-03-20T14:30:32.257018Z", "iopub.status.idle": "2025-03-20T14:30:37.078870Z", "shell.execute_reply": "2025-03-20T14:30:37.077980Z", "shell.execute_reply.started": "2025-03-20T14:30:32.257420Z" } }, "outputs": [], "source": [ "import anndata as ad\n", "import networkx as nx\n", "import numpy as np\n", "import scanpy as sc\n", "import seaborn as sns\n", "\n", "from cascade.data import configure_dataset, encode_regime, get_configuration\n", "from cascade.graph import annotate_explanation, core_explanation_graph, prep_cytoscape\n", "from cascade.model import CASCADE\n", "from cascade.plot import set_figure_params" ] }, { "cell_type": "code", "execution_count": 2, "id": "1d698607-eb9b-444e-8839-5b9fd6a3cba3", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:37.080190Z", "iopub.status.busy": "2025-03-20T14:30:37.079728Z", "iopub.status.idle": "2025-03-20T14:30:37.089966Z", "shell.execute_reply": "2025-03-20T14:30:37.089229Z", "shell.execute_reply.started": "2025-03-20T14:30:37.080165Z" } }, "outputs": [], "source": [ "set_figure_params()" ] }, { "cell_type": "markdown", "id": "e096b394-a954-4f3a-bbaf-8aae48262871", "metadata": {}, "source": [ "## Read data and model" ] }, { "cell_type": "code", "execution_count": 3, "id": "6e965b63-6dca-45e2-83a1-a3a011d6cae4", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:37.090969Z", "iopub.status.busy": "2025-03-20T14:30:37.090758Z", "iopub.status.idle": "2025-03-20T14:30:42.044911Z", 
"shell.execute_reply": "2025-03-20T14:30:42.043894Z", "shell.execute_reply.started": "2025-03-20T14:30:37.090950Z" } }, "outputs": [], "source": [ "adata = sc.read_h5ad(\"adata.h5ad\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "b0fddc6e-96fb-4df8-9d1b-85a45a242116", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:42.046333Z", "iopub.status.busy": "2025-03-20T14:30:42.046114Z", "iopub.status.idle": "2025-03-20T14:30:42.576659Z", "shell.execute_reply": "2025-03-20T14:30:42.575741Z", "shell.execute_reply.started": "2025-03-20T14:30:42.046314Z" } }, "outputs": [], "source": [ "cascade = CASCADE.load(\"tune.pt\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "90e869da-4c66-4956-90b7-e53e22b160d0", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:42.578453Z", "iopub.status.busy": "2025-03-20T14:30:42.578007Z", "iopub.status.idle": "2025-03-20T14:30:46.737700Z", "shell.execute_reply": "2025-03-20T14:30:46.736849Z", "shell.execute_reply.started": "2025-03-20T14:30:42.578426Z" } }, "outputs": [], "source": [ "scaffold = nx.read_gml(\"scaffold.gml.gz\")\n", "graph = nx.read_gml(\"discover.gml.gz\")" ] }, { "cell_type": "markdown", "id": "66c8a476-c921-4c19-b5f6-5e3bdd81eb80", "metadata": {}, "source": [ "## Specify counterfactual condition\n", "\n", "Suppose we want to predict the counterfactual effect of triple gene perturbation\n", "`\"CEBPB,KLF1,MAPK1\"` for the negative control cells, we'll need to first extract\n", "some control cells, and then specify the perturbation in a column in `adata.obs`\n", "(e.g., `\"my_pert\"`), in the same comma-separated format as the `\"knockup\"` column:" ] }, { "cell_type": "code", "execution_count": 6, "id": "5a15621c-2984-46b8-8f0c-ea528ea2b92f", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:46.740134Z", "iopub.status.busy": "2025-03-20T14:30:46.739917Z", "iopub.status.idle": "2025-03-20T14:30:46.780473Z", "shell.execute_reply": 
"2025-03-20T14:30:46.779677Z", "shell.execute_reply.started": "2025-03-20T14:30:46.740115Z" } }, "outputs": [], "source": [ "ctrl = adata[adata.obs[\"knockup\"] == \"\"]\n", "sc.pp.subsample(ctrl, n_obs=1000)\n", "ctrl.obs[\"my_pert\"] = \"CEBPB,KLF1,MAPK1\"" ] }, { "cell_type": "markdown", "id": "15f77fe6-671f-476e-8a66-1568bd51176b", "metadata": {}, "source": [ "Then we call [encode_regime](api/cascade.data.encode_regime.rst) again to\n", "encode this counterfactual perturbation into a binary regime matrix,\n", "here in a new layer called `\"ctfact\"`:" ] }, { "cell_type": "code", "execution_count": 7, "id": "02c85109-5c4d-49e9-95fd-3e40a9f2aa89", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:46.781654Z", "iopub.status.busy": "2025-03-20T14:30:46.781336Z", "iopub.status.idle": "2025-03-20T14:30:46.856585Z", "shell.execute_reply": "2025-03-20T14:30:46.855846Z", "shell.execute_reply.started": "2025-03-20T14:30:46.781635Z" } }, "outputs": [], "source": [ "encode_regime(ctrl, \"ctfact\", key=\"my_pert\")" ] }, { "cell_type": "markdown", "id": "07489735-504f-40e2-aaa3-21247abbccd4", "metadata": {}, "source": [ "We'd also need to call [configure_dataset](api/cascade.data.configure_dataset.rst)\n", "again to let the model use this new regime:" ] }, { "cell_type": "code", "execution_count": 8, "id": "98bacafd-8794-4525-b6dc-e7a647c05323", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:46.857602Z", "iopub.status.busy": "2025-03-20T14:30:46.857397Z", "iopub.status.idle": "2025-03-20T14:30:46.864782Z", "shell.execute_reply": "2025-03-20T14:30:46.864106Z", "shell.execute_reply.started": "2025-03-20T14:30:46.857582Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m22:30:46.858\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[33m1633490\u001b[0m:\u001b[36mdata\u001b[0m:\u001b[36mconfigure_dataset\u001b[0m - \u001b[33m\u001b[1mOverwriting existing `regime` = \"interv\".\u001b[0m\n" ] }, { 
"data": { "text/plain": [ "{'covariate': 'covariate',\n", " 'layer': 'counts',\n", " 'regime': 'ctfact',\n", " 'size': 'ncounts'}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "configure_dataset(ctrl, use_regime=\"ctfact\")\n", "get_configuration(ctrl)" ] }, { "cell_type": "markdown", "id": "73347548-94b4-42ae-a037-f13f2f32152f", "metadata": {}, "source": [ "## Run counterfactual prediction\n", "\n", "Now we use the `counterfactual` method to perform counterfactual prediction\n", "with this newly specified perturbation:" ] }, { "cell_type": "code", "execution_count": 9, "id": "6023871a-1e3a-4174-b5f8-c248623be190", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:46.865572Z", "iopub.status.busy": "2025-03-20T14:30:46.865396Z", "iopub.status.idle": "2025-03-20T14:30:49.578851Z", "shell.execute_reply": "2025-03-20T14:30:49.578063Z", "shell.execute_reply.started": "2025-03-20T14:30:46.865556Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m22:30:46.885\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1633490\u001b[0m:\u001b[36mutils\u001b[0m:\u001b[36mautodevice\u001b[0m - \u001b[1mUsing GPU [3] as computation device.\u001b[0m\n", "\u001b[32m22:30:47.051\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[33m1633490\u001b[0m:\u001b[36mcore\u001b[0m:\u001b[36mpredict_mode\u001b[0m - \u001b[1mNumber of topological generations: [68, 88, 70, 101]\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "371da51767e44dfb817024b79cdc7ab7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ctfact = cascade.counterfactual(ctrl, sample=True)" ] }, { "cell_type": "markdown", "id": "f525a369-0fd1-4104-ae27-663f213604fd", "metadata": {}, "source": [ "Here we specified `sample=True` to make the model output random samples from the\n", "counterfactual negative 
binomial distribution, so that the output reflects the spread\n", "of the predicted distribution rather than only its mean.\n", "\n", "The prediction is stored in both `ctfact.X` and `ctfact.layers[\"X_ctfact\"]`,\n", "where `ctfact.X` is the average prediction across SVGD particles, and\n", "`ctfact.layers[\"X_ctfact\"]` contains the per-particle predictions with shape\n", "`(n_obs, n_vars, n_particles)`. Note that both are on the raw count scale." ] }, { "cell_type": "code", "execution_count": 10, "id": "4309b965-b1cd-48fc-acad-2751e9b92e46", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:49.580781Z", "iopub.status.busy": "2025-03-20T14:30:49.580156Z", "iopub.status.idle": "2025-03-20T14:30:49.587641Z", "shell.execute_reply": "2025-03-20T14:30:49.587044Z", "shell.execute_reply.started": "2025-03-20T14:30:49.580746Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[ 0. , 6.5 , 16.5 , ..., 0. , 0. , 72.25],\n", " [ 1. , 10.5 , 23.25, ..., 0. , 0. , 140. ],\n", " [ 0. , 11.75, 32.75, ..., 0. , 0. , 157.5 ],\n", " ...,\n", " [ 0.25, 6. , 19.5 , ..., 0. , 0. , 59.25],\n", " [ 0.25, 9. , 19.5 , ..., 0. , 0. , 90.75],\n", " [ 0.5 , 4.5 , 18. , ..., 0. , 0. 
, 77.25]],\n", " dtype=float32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ctfact.X" ] }, { "cell_type": "code", "execution_count": 11, "id": "83835396-0e73-499c-9c74-4210f343b8af", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:49.588578Z", "iopub.status.busy": "2025-03-20T14:30:49.588346Z", "iopub.status.idle": "2025-03-20T14:30:49.594830Z", "shell.execute_reply": "2025-03-20T14:30:49.593951Z", "shell.execute_reply.started": "2025-03-20T14:30:49.588555Z" } }, "outputs": [ { "data": { "text/plain": [ "(1000, 1064, 4)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ctfact.layers[\"X_ctfact\"].shape" ] }, { "attachments": {}, "cell_type": "markdown", "id": "fb8c7a2a-4503-460d-9d06-85f9bdfea91b", "metadata": {}, "source": [ "> For counterfactual prediction of [CASCADE designs](design.ipynb), you also\n", "> need to specify the `design` argument to the `counterfactual` method.\n", "\n", "Please visit the documentation of\n", "[counterfactual](api/cascade.model.CASCADE.counterfactual.rst)\n", "for more details.\n", "\n", "The same can also be achieved using the\n", "[command line interface](cli.rst#counterfactual-deduction),\n", "with the following command:\n", "\n", "```sh\n", "cascade counterfactual -d ctrl.h5ad -m tune.pt -p ctfact.h5ad [other options]\n", "```\n", "\n", "## Counterfactual differential expression comparison\n", "\n", "To check for counterfactual effects, the straightforward approach would be to compare the predicted\n", "dataset (`ctfact`) with the input dataset (`ctrl`). 
However, to avoid artifacts\n", "caused by model prediction biases, it is recommended to compare the predicted\n", "dataset (`ctfact`) with a \"nil prediction\", i.e., the model's prediction under the\n", "original perturbation labels, so that biases shared by both predictions largely cancel out.\n", "\n", "Here, we switch back to the `\"interv\"` regime to obtain the \"nil prediction\":" ] }, { "cell_type": "code", "execution_count": 12, "id": "a23a0494-93fd-495f-9c8a-279baf5775ea", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:49.595984Z", "iopub.status.busy": "2025-03-20T14:30:49.595741Z", "iopub.status.idle": "2025-03-20T14:30:49.603594Z", "shell.execute_reply": "2025-03-20T14:30:49.602734Z", "shell.execute_reply.started": "2025-03-20T14:30:49.595961Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m22:30:49.596\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[33m1633490\u001b[0m:\u001b[36mdata\u001b[0m:\u001b[36mconfigure_dataset\u001b[0m - \u001b[33m\u001b[1mOverwriting existing `regime` = \"ctfact\".\u001b[0m\n" ] }, { "data": { "text/plain": [ "{'covariate': 'covariate',\n", " 'layer': 'counts',\n", " 'regime': 'interv',\n", " 'size': 'ncounts'}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "configure_dataset(ctrl, use_regime=\"interv\")\n", "get_configuration(ctrl)" ] }, { "cell_type": "code", "execution_count": 13, "id": "7a5fff11-85c0-4da5-8c77-2f8d87eff28d", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:49.604760Z", "iopub.status.busy": "2025-03-20T14:30:49.604533Z", "iopub.status.idle": "2025-03-20T14:30:51.809545Z", "shell.execute_reply": "2025-03-20T14:30:51.808543Z", "shell.execute_reply.started": "2025-03-20T14:30:49.604739Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32m22:30:49.790\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[33m1633490\u001b[0m:\u001b[36mcore\u001b[0m:\u001b[36mpredict_mode\u001b[0m - \u001b[1mNumber of topological 
generations: [68, 88, 70, 101]\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7a20c015a98d4bf4abd61a34c37275f3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nil = cascade.counterfactual(ctrl, sample=True)" ] }, { "cell_type": "markdown", "id": "a9267e9c-bf1c-48b7-86c1-aa441754c549", "metadata": {}, "source": [ "Now we combine both predictions and log-normalize them (each cell is multiplied\n", "by `1e4 / ncounts`, i.e., normalized by its observed library size, and then\n", "`log1p`-transformed) before differential expression analysis:" ] }, { "cell_type": "code", "execution_count": 14, "id": "dd55957f-c075-4900-8d3a-133faab1c4c9", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:51.811022Z", "iopub.status.busy": "2025-03-20T14:30:51.810783Z", "iopub.status.idle": "2025-03-20T14:30:51.868684Z", "shell.execute_reply": "2025-03-20T14:30:51.867795Z", "shell.execute_reply.started": "2025-03-20T14:30:51.810999Z" } }, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 2000 × 1064\n", " obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup', 'my_pert', 'role'\n", " obsm: 'X_pca', 'covariate'\n", " layers: 'counts', 'interv', 'ctfact', 'X_ctfact'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "combined = ad.concat({\"nil\": nil, \"ctfact\": ctfact}, label=\"role\", index_unique=\"-\")\n", "combined.X = np.log1p(combined.X * (1e4 / combined.obs[[\"ncounts\"]].to_numpy()))\n", "combined" ] }, { "cell_type": "code", "execution_count": 15, "id": "6b471dd7-d1ab-442b-91e3-70889a766c6a", "metadata": { "execution": { "iopub.execute_input": "2025-03-20T14:30:51.869927Z", "iopub.status.busy": "2025-03-20T14:30:51.869701Z", "iopub.status.idle": "2025-03-20T14:30:51.986242Z", "shell.execute_reply": "2025-03-20T14:30:51.985352Z", 
"shell.execute_reply.started": "2025-03-20T14:30:51.869907Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/rd1/user/caozj/CASCADE/conda/lib/python3.11/site-packages/pandas/core/arraylike.py:399: RuntimeWarning: divide by zero encountered in log10\n", " result = getattr(ufunc, method)(*inputs, **kwargs)\n" ] }, { "data": { "text/html": [ "
| \n", " | names | \n", "scores | \n", "logfoldchanges | \n", "pvals | \n", "pvals_adj | \n", "pct_nz_group | \n", "-logfdr | \n", "
|---|---|---|---|---|---|---|---|
| 0 | \n", "KLF1 | \n", "207.035004 | \n", "1.951955 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "350.0 | \n", "
| 1 | \n", "PNMT | \n", "171.383087 | \n", "2.364175 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "350.0 | \n", "
| 2 | \n", "CEBPB | \n", "152.714493 | \n", "2.131128 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "350.0 | \n", "
| 3 | \n", "TMSB10 | \n", "144.004883 | \n", "1.232379 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "350.0 | \n", "
| 4 | \n", "MAPK1 | \n", "134.214584 | \n", "1.047204 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "350.0 | \n", "