Stage 1: Data preprocessing

In this tutorial, we’ll first walk through how to prepare a dataset for use in CASCADE, using the Norman et al. (2019) dataset as an example. This dataset contains single- and double-gene CRISPRa perturbations.

[1]:
import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from cascade.data import (
    configure_dataset,
    encode_regime,
    filter_unobserved_targets,
    get_all_targets,
    get_configuration,
    neighbor_impute,
)
from cascade.graph import assemble_scaffolds

Read data

First, we need to load the dataset as an AnnData object. If you are unfamiliar with AnnData, see its documentation for details, including how to construct AnnData objects from scratch and how to read data in other formats (csv, mtx, loom, etc.).

Here we simply load an existing h5ad file, the native file format for AnnData. The h5ad file used in this tutorial can be downloaded from here:

[2]:
adata = sc.read_h5ad("Norman-2019.h5ad")
adata
[2]:
AnnData object with n_obs × n_vars = 86744 × 22881
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed'

Data format requirements

CASCADE requires the following data format:

  • Raw counts in adata.X;

  • Total counts per cell in adata.obs, which are used when fitting the data with a negative binomial distribution;

  • HGNC gene symbols as adata.var_names;

  • Perturbation label in adata.obs that specifies which genes are perturbed in each cell:

    • For control cells with no perturbation, the value MUST be an empty string "";

    • For cells with multiple perturbations, the perturbed genes should be comma-separated, e.g., "CEBPB,KLF1";

    • Names of perturbed genes must match those in adata.var_names.
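The label rules above can be checked with a small helper before going further. The sketch below defines a hypothetical `check_perturbation_labels` function using pandas. It is not part of CASCADE, and CASCADE's own parsing may treat edge cases such as whitespace differently.

```python
import pandas as pd


def check_perturbation_labels(labels: pd.Series, var_names) -> list:
    """Return a list of problems found in a perturbation label column.

    Hypothetical helper: controls must be "" and multiple targets are comma-separated.
    """
    problems = []
    genes = set(var_names)
    if labels.isna().any():
        # Missing values are not allowed; control cells must use the empty string.
        problems.append(f"{int(labels.isna().sum())} labels are missing; use '' for controls")
    for label in labels.dropna().astype(str).unique():
        if label == "":
            continue  # control cells have no targets
        for target in label.split(","):
            if target != target.strip():
                problems.append(f"whitespace around target in {label!r}")
            elif target not in genes:
                problems.append(f"target {target!r} in {label!r} not found in var_names")
    return problems


labels = pd.Series(["", "KLF1", "CEBPB,KLF1", "CEBPB, FOO"])
print(check_perturbation_labels(labels, ["CEBPB", "KLF1", "SET"]))
```

On this dataset, the equivalent call would be `check_perturbation_labels(adata.obs["knockup"], adata.var_names)`. An empty list means no problems were detected.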

In this case, we can verify that the expression matrix contains raw counts:

[3]:
adata.X, adata.X.data
[3]:
(<Compressed Sparse Row sparse matrix of dtype 'float32'
        with 268281595 stored elements and shape (86744, 22881)>,
 array([ 1.,  1.,  1., ..., 12.,  3., 16.], dtype=float32))
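The same check can be written as a simple heuristic: raw counts should be non-negative and integral. The sketch below assumes the stored values are available as a NumPy array, which for a SciPy sparse matrix is its `.data` attribute as shown above. The helper name `looks_like_raw_counts` is hypothetical.

```python
import numpy as np


def looks_like_raw_counts(values) -> bool:
    """Heuristic: True if all values are non-negative integers (hypothetical helper)."""
    values = np.asarray(values)
    nonnegative = bool(np.all(values >= 0))
    integral = bool(np.all(values == np.round(values)))
    return nonnegative and integral


print(looks_like_raw_counts(np.array([1.0, 1.0, 12.0, 3.0, 16.0], dtype=np.float32)))  # True
print(looks_like_raw_counts(np.log1p(np.array([1.0, 12.0]))))  # False
```

For this dataset the call would be `looks_like_raw_counts(adata.X.data)`. A normalized or log-transformed matrix would fail the integrality test.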

The total counts are stored as "ncounts" in adata.obs:

[4]:
adata.obs["ncounts"]
[4]:
TTGAACGAGACTCGGA      15097.0
CGTTGGGGTGTTTGTG       8551.0
GAACCTAAGTGTTAGA      10999.0
CCTTCCCTCCGTCATC      38454.0
TCCCGATGTCTCTTAT      21433.0
                       ...
TTTCCTCGTACGCACC      11991.0
TTTCCTCTCTTGCCGT      16561.0
TTTGCGCAGTCATGCT       5192.0
TTTGCGCCAGGACCCT      15704.0
TTTGCGCTCTCGCATC-1     6825.0
Name: ncounts, Length: 86744, dtype: float32
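The sketch below illustrates how a per-cell total count typically enters a negative binomial likelihood. It uses the common mean parameterization mu = s * rho, where s is the cell's total count, rho is a gene's expected fraction of that total, and theta is the inverse dispersion. This generic formulation appears in many scRNA-seq models. It is an assumption here, not necessarily CASCADE's exact likelihood, and the numeric values are made up.

```python
import math


def nb_log_pmf(x: int, mu: float, theta: float) -> float:
    """Log-probability of count x under a negative binomial with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


# Hypothetical values: a cell with 15097 total counts and a gene expected to
# account for 0.1% of them, giving an expected count of about 15.1.
size = 15097.0
rate = 1e-3
mu = size * rate
theta = 5.0
print(nb_log_pmf(12, mu, theta))
```

Because the expected count scales linearly with s, the total counts must stay on the raw count scale for this kind of likelihood to make sense.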

The perturbation labels are stored as "knockup" in adata.obs. Control cells have an empty label:

[5]:
adata.obs["knockup"]
[5]:
TTGAACGAGACTCGGA            ARID1A
CGTTGGGGTGTTTGTG            BCORL1
GAACCTAAGTGTTAGA              FOSB
CCTTCCCTCCGTCATC          KLF1,SET
TCCCGATGTCTCTTAT         BAK1,KLF1
                          ...
TTTCCTCGTACGCACC
TTTCCTCTCTTGCCGT               HK2
TTTGCGCAGTCATGCT            RHOXF2
TTTGCGCCAGGACCCT      BAK1,BCL2L11
TTTGCGCTCTCGCATC-1      CEBPB,OSR2
Name: knockup, Length: 86744, dtype: category
Categories (237, object): ['', 'AHR', 'AHR,FEV', 'AHR,KLF1', ..., 'ZBTB10', 'ZBTB25', 'ZC3HAV1', 'ZNF318']

Before any further processing, we back up the raw UMI counts in a layer called "counts", which will be used later during model training.

[6]:
adata.layers["counts"] = adata.X.copy()

Cell and gene selection

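Since CASCADE can only model perturbations of measured genes, we first filter out cells whose perturbation targets are missing from the readout. A utility function called filter_unobserved_targets is provided for this purpose. As the output below shows, it returns a filtered view of the data.

In this case no cell was filtered, so we only inspect the result. For datasets where cells are removed, assign the result back, e.g., adata = filter_unobserved_targets(adata, "knockup").copy():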

[7]:
filter_unobserved_targets(adata, "knockup")
[7]:
View of AnnData object with n_obs × n_vars = 86744 × 22881
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed'
    layers: 'counts'
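The filtering rule can be expressed as a boolean mask over cells. A cell is kept if every one of its targets appears among the measured genes, and control cells (empty label) are always kept. The pandas sketch below illustrates this rule using a hypothetical helper, `observed_target_mask`. Note that `filter_unobserved_targets` may handle edge cases differently.

```python
import pandas as pd


def observed_target_mask(labels: pd.Series, var_names) -> pd.Series:
    """True for cells whose perturbation targets are all measured genes."""
    genes = set(var_names)

    def all_observed(label: str) -> bool:
        if label == "":
            return True  # control cells have no targets
        return all(target in genes for target in label.split(","))

    return labels.astype(str).map(all_observed)


labels = pd.Series(["", "KLF1", "KLF1,SET", "KLF1,NOTMEASURED"], index=["c1", "c2", "c3", "c4"])
mask = observed_target_mask(labels, ["KLF1", "SET", "CEBPB"])
print(mask.tolist())  # [True, True, True, False]
```

On real data, the mask could be applied as `adata[mask.to_numpy()]` to select the retained cells.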

Next, we identify the top 1,000 highly variable genes using the "seurat_v3" flavor, which expects raw counts in adata.X. This lets the model focus on informative genes:

[8]:
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="seurat_v3")

Again, because CASCADE can only model perturbations of measured genes, we expand this highly variable gene set to include all perturbed genes (obtained via the utility function get_all_targets). This ensures no perturbation information is discarded:

[9]:
all_targets = get_all_targets(adata, key="knockup")
all_targets
[9]:
AHR,ARID1A,ARRDC3,ATL1,BAK1,BCL2L11,BCORL1,BPGM,CBARP,CBFA2T3,CBL,CDKN1A,CDKN1B,CDKN1C,CEBPA,CEBPB,CEBPE,CELF2,CITED1,CKS1B,CLDN6,CNN1,CNNM4,COL1A1,COL2A1,CSRNP1,DLX2,DUSP9,EGR1,ETS2,FEV,FOSB,FOXA1,FOXA3,FOXF1,FOXL2,FOXL2NB,FOXO4,GLB1L2,HES7,HK2,HNF4A,HOXA13,HOXB9,HOXC13,IER5L,IGDCC3,IKZF3,IRF1,ISL2,JUN,KIF18B,KIF2C,KLF1,KMT2A,LHX1,LYL1,MAML2,MAP2K3,MAP2K6,MAP3K21,MAP4K3,MAP4K5,MAP7D1,MAPK1,MEIS1,MIDEAS,MIDN,NCL,NIT1,OSR2,PLK4,POU3F2,PRDM1,PRTG,PTPN1,PTPN12,PTPN13,PTPN9,RHOXF2,RREB1,RUNX1T1,S1PR2,SAMD1,SET,SGK1,SLC38A2,SLC4A1,SLC6A9,SNAI1,SPI1,STIL,TBX2,TBX3,TGFBR2,TMSB4X,TP73,TSC22D1,UBASH3A,UBASH3B,ZBTB1,ZBTB10,ZBTB25,ZC3HAV1,ZNF318
[10]:
adata.var["selected"] = adata.var["highly_variable"] | adata.var_names.isin(all_targets)
adata.var["selected"].sum()
[10]:
1064
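The two steps above can be sketched without AnnData: collect every perturbed gene, then take its union with the highly variable genes. In the sketch, `all_perturbed_genes` is a hypothetical stand-in for `get_all_targets`, and the toy `highly_variable` flags stand in for `adata.var["highly_variable"]`.

```python
import pandas as pd


def all_perturbed_genes(labels: pd.Series) -> list:
    """Sorted unique gene names across comma-separated perturbation labels."""
    targets = set()
    for label in labels.astype(str).unique():
        if label:  # skip control cells labelled ""
            targets.update(label.split(","))
    return sorted(targets)


labels = pd.Series(["", "KLF1", "CEBPB,KLF1", "SET"])
var = pd.DataFrame(
    {"highly_variable": [True, False, False, True, False]},
    index=["HBG1", "KLF1", "CEBPB", "GATA1", "ACTB"],
)
targets = all_perturbed_genes(labels)
# Union of highly variable genes and perturbation targets present in var.
var["selected"] = var["highly_variable"] | var.index.isin(targets)
print(targets)  # ['CEBPB', 'KLF1', 'SET']
print(var["selected"].sum())  # 4
```

A target absent from `var` (SET in this toy example) adds no gene. This is why unobserved targets are filtered out beforehand.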

Encode intervention regime

CASCADE represents genetic perturbations as a cell-by-gene binary matrix, which can be encoded from the adata.obs["knockup"] column using the encode_regime function. The function stores the encoded regime matrix in a layer with a user-specified name, here "interv".

[11]:
encode_regime(adata, "interv", key="knockup")
adata.layers["interv"]
[11]:
<Compressed Sparse Row sparse matrix of dtype 'bool'
        with 108992 stored elements and shape (86744, 22881)>
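Conceptually, entry (i, j) of the regime matrix is True when cell i carries a perturbation targeting gene j. The sketch below builds a dense NumPy version of such a matrix using a hypothetical helper, `encode_regime_dense`. By contrast, `encode_regime` stores a sparse matrix in a layer, as shown above.

```python
import numpy as np
import pandas as pd


def encode_regime_dense(labels: pd.Series, var_names) -> np.ndarray:
    """Cell-by-gene boolean matrix marking the perturbed genes in each cell."""
    gene_index = {gene: j for j, gene in enumerate(var_names)}
    regime = np.zeros((len(labels), len(gene_index)), dtype=bool)
    for i, label in enumerate(labels.astype(str)):
        if label:  # control cells ("") stay all-False
            for target in label.split(","):
                regime[i, gene_index[target]] = True
    return regime


labels = pd.Series(["", "KLF1", "CEBPB,KLF1"])
regime = encode_regime_dense(labels, ["CEBPB", "KLF1", "SET"])
print(regime.astype(int))
```

Each row sums to the number of targets in that cell. The number of stored elements in the sparse layer should therefore match the total number of cell-target pairs.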

Encode technical covariates

To minimize the effect of technical confounding on causal discovery, we recommend collecting all potential confounding factors into a covariate matrix stored in adata.obsm.

Here we use the one-hot encoded batch label ("gemgroup") and the standardized log10 total counts as covariates:

[12]:
batch = OneHotEncoder().fit_transform(adata.obs[["gemgroup"]]).toarray()
batch
[12]:
array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.]])
[13]:
log_ncounts = StandardScaler().fit_transform(np.log10(adata.obs[["ncounts"]]))
log_ncounts
[13]:
array([[ 0.46084866],
       [-0.8991263 ],
       [-0.29681763],
       ...,
       [-2.092782  ],
       [ 0.5551557 ],
       [-1.4385142 ]], dtype=float32)
[14]:
adata.obsm["covariate"] = np.concatenate([batch, log_ncounts], axis=1)
adata.obsm["covariate"].shape
[14]:
(86744, 9)
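The same covariate matrix can be sketched with NumPy alone, which makes explicit what the two scikit-learn transformers compute. The result has one indicator column per batch, plus log10 total counts standardized to zero mean and unit variance. StandardScaler uses the population standard deviation (ddof=0). The toy `gemgroup` and `ncounts` arrays are made up for illustration.

```python
import numpy as np

gemgroup = np.array([2, 7, 6, 2, 3])
ncounts = np.array([15097.0, 8551.0, 10999.0, 38454.0, 21433.0])

# One-hot encode batches: one column per unique batch, in sorted order
# (OneHotEncoder also sorts categories by default).
categories, codes = np.unique(gemgroup, return_inverse=True)
batch = np.eye(len(categories))[codes]

# Standardize log10 total counts using the population std, like StandardScaler.
log_counts = np.log10(ncounts)
log_ncounts = ((log_counts - log_counts.mean()) / log_counts.std())[:, None]

covariate = np.concatenate([batch, log_ncounts], axis=1)
print(covariate.shape)  # (5, 5)
```

In this dataset, the 9 columns shown above correspond to 8 batch indicators plus 1 column for the standardized log total counts.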

Data normalization

Next, we follow the standard scanpy workflow to normalize the expression matrix in adata.X: we scale each cell to a total of 10,000 counts, then apply a log transformation. See the scanpy documentation if you are unfamiliar with these steps.

[15]:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
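For reference, the two scanpy calls above amount to the following on a dense matrix: each cell is scaled to a total of 10,000, then log(1 + x) is applied using the natural logarithm. This sketch ignores optional `normalize_total` behaviors, such as excluding highly expressed genes.

```python
import numpy as np

counts = np.array([[1.0, 3.0, 0.0, 6.0],
                   [0.0, 2.0, 2.0, 0.0]])

# Scale each row (cell) so that it sums to target_sum.
target_sum = 1e4
normalized = counts / counts.sum(axis=1, keepdims=True) * target_sum

# Natural-log transform with a pseudocount of 1.
logged = np.log1p(normalized)
print(normalized.sum(axis=1))  # [10000. 10000.]
```

The raw counts remain untouched in adata.layers["counts"], which is why they were backed up earlier.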

Now we can subset the dataset to retain the selected genes only:

[16]:
adata = adata[:, adata.var["selected"]].copy()
adata
[16]:
AnnData object with n_obs × n_vars = 86744 × 1064
    obs: 'guide_id', 'gemgroup', 'ncounts', 'knockup'
    var: 'perturbed', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'selected'
    uns: 'hvg', 'log1p'
    obsm: 'covariate'
    layers: 'counts', 'interv'

Neighbor-based imputation

Because scRNA-seq data are sparse, we recommend a lightweight neighbor-based imputation before model training. This aggregates each cell with its nearest neighbors in PCA space, restricted to cells carrying the same perturbation. We provide a utility function called neighbor_impute for this purpose:

[17]:
sc.pp.pca(adata)
[18]:
adata = neighbor_impute(
    adata,
    k=20,
    use_rep="X_pca",
    use_batch="knockup",
    X_agg="mean",
    obs_agg={"ncounts": "sum"},
    obsm_agg={"covariate": "mean"},
    layers_agg={"counts": "sum"},
)

Note that we use "sum" aggregation for adata.obs["ncounts"] and adata.layers["counts"], which preserves their count-based nature. The normalized expression in adata.X and the covariates are averaged instead.
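The NumPy sketch below implements one plausible version of this aggregation. Within each perturbation group, every cell is replaced by the mean (for normalized expression) or sum (for raw counts) over its k nearest neighbors in PCA space, with the cell itself counted as a neighbor. The helper `impute_within_groups` is hypothetical and uses brute-force Euclidean distances. `neighbor_impute` may differ in its neighbor search, distance metric, or whether a cell counts as its own neighbor.

```python
import numpy as np


def impute_within_groups(rep, groups, values, k, agg="mean"):
    """Aggregate `values` over each cell's k nearest neighbors within its group."""
    out = np.empty_like(values, dtype=float)
    for group in np.unique(groups):
        idx = np.flatnonzero(groups == group)
        sub = rep[idx]
        # Pairwise Euclidean distances among cells of this group only.
        diff = sub[:, None, :] - sub[None, :, :]
        dist = np.sqrt((diff ** 2).sum(axis=-1))
        k_eff = min(k, len(idx))  # small groups may have fewer than k cells
        # Nearest first; each cell is at distance 0 from itself.
        neighbors = np.argsort(dist, axis=1)[:, :k_eff]
        gathered = values[idx][neighbors]  # shape: (n_group, k_eff, n_features)
        out[idx] = gathered.mean(axis=1) if agg == "mean" else gathered.sum(axis=1)
    return out


rng = np.random.default_rng(0)
pca = rng.normal(size=(6, 2))
groups = np.array(["", "", "", "KLF1", "KLF1", "KLF1"])
counts = rng.poisson(3.0, size=(6, 4)).astype(float)
summed = impute_within_groups(pca, groups, counts, k=2, agg="sum")
```

Summing both the raw counts and the total counts keeps them on a consistent scale. As a result, each gene's count relative to its cell's total remains meaningful after imputation.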

Configure dataset

Now we can use the function configure_dataset to tell CASCADE where the expression matrix, intervention regime, covariates and total counts are stored:

[19]:
configure_dataset(
    adata,
    use_regime="interv",
    use_covariate="covariate",
    use_size="ncounts",
    use_layer="counts",
)
get_configuration(adata)
[19]:
{'regime': 'interv',
 'covariate': 'covariate',
 'size': 'ncounts',
 'weight': None,
 'layer': 'counts'}

Construct scaffold graph

Next, we need to construct a scaffold graph to guide the causal discovery process.

The following 4 pre-built human gene scaffolds are available for download:

[20]:
kegg = nx.read_gml("inferred_kegg_gene_only.gml.gz")
tf_target = nx.read_gml("TF-target.gml.gz")
biogrid = nx.read_gml("biogrid.gml.gz")
corr = nx.read_gml("corr.gml.gz")

The individual scaffold components can be assembled into a hybrid scaffold using the assemble_scaffolds function, which also marginalizes each component with respect to the genes being modeled here:

[21]:
scaffold = assemble_scaffolds(corr, biogrid, tf_target, kegg, nodes=adata.var_names)
scaffold.number_of_nodes(), scaffold.number_of_edges()
14:47:12.452 | WARNING  | 1470173:graph:marginalize - 187 nodes are missing from the input graph and will be ignored.
14:47:23.102 | WARNING  | 1470173:graph:marginalize - 182 nodes are missing from the input graph and will be ignored.
14:47:28.593 | WARNING  | 1470173:graph:marginalize - 189 nodes are missing from the input graph and will be ignored.
14:47:29.730 | WARNING  | 1470173:graph:marginalize - 647 nodes are missing from the input graph and will be ignored.
[21]:
(1064, 32264)
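As a rough mental model only, the sketch below merges several edge lists and keeps only edges whose endpoints are both modeled genes, i.e., an induced subgraph. This is a simplified, hypothetical stand-in. `assemble_scaffolds` marginalizes each component, which may also account for paths through dropped genes, so its edge counts can differ from this sketch. The toy edges are made up, and treating them as directed (source, target) pairs is an assumption.

```python
def union_on_nodes(edge_lists, nodes):
    """Union of directed edges from several scaffolds, keeping only edges among `nodes`."""
    keep = set(nodes)
    edges = set()
    for edge_list in edge_lists:
        for source, target in edge_list:
            if source in keep and target in keep:
                edges.add((source, target))
    return edges


corr = [("KLF1", "HBG1"), ("HBG1", "KLF1")]
tf_target = [("KLF1", "HBG1"), ("CEBPB", "CEBPA"), ("GATA1", "KLF1")]
scaffold_edges = union_on_nodes([corr, tf_target], nodes=["KLF1", "HBG1", "CEBPB", "CEBPA"])
print(len(scaffold_edges))  # 3
```

The duplicate KLF1 to HBG1 edge is counted once, and the edge from GATA1 is dropped because GATA1 is not among the modeled genes.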

Prepare gene function embeddings

Lastly, we fetch the relevant entries from gene embeddings precomputed by latent semantic indexing (LSI) of Gene Ontology (GO) annotations, which can be downloaded here:

These embeddings serve as the input to the interventional latent variable in CASCADE:

[22]:
latent_emb = pd.read_csv("gene2gos_lsi.csv.gz", index_col=0)
latent_emb = latent_emb.reindex(adata.var_names).dropna()
latent_emb.shape
[22]:
(866, 32)
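The reindex-then-dropna idiom above keeps only modeled genes that have an embedding. `reindex` inserts all-NaN rows for genes absent from the embedding table and orders rows like `adata.var_names`, and `dropna` then removes those rows. This is why only 866 of the 1,064 selected genes remain. The toy illustration below uses made-up values:

```python
import numpy as np
import pandas as pd

emb = pd.DataFrame(
    np.arange(6, dtype=float).reshape(3, 2),
    index=["KLF1", "CEBPB", "GATA1"],
    columns=["lsi_0", "lsi_1"],
)
var_names = pd.Index(["CEBPB", "KLF1", "HBG1"])  # HBG1 has no embedding here

aligned = emb.reindex(var_names)  # rows follow var_names; HBG1 becomes all NaN
aligned = aligned.dropna()        # the all-NaN HBG1 row is dropped
print(aligned.index.tolist())  # ['CEBPB', 'KLF1']
```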

Save processed data files

Finally, we save the preprocessed data files for use in Stage 2.

[23]:
adata.write("adata.h5ad", compression="gzip")
[24]:
nx.write_gml(scaffold, "scaffold.gml.gz")
[25]:
latent_emb.to_csv("latent_emb.csv.gz")

Afterword

The steps above constitute the minimal preprocessing required to run CASCADE. Additional steps, such as using mixscape to filter out cells in which the perturbation did not take effect, may also be useful depending on the data at hand.