# Source code for cascade.metrics

r"""
Metrics for evaluating inferred causal structures, counterfactual predictions
and intervention designs
"""

from collections.abc import Callable
from functools import lru_cache, reduce, wraps
from operator import or_
from statistics import mean

import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.stats
from anndata import AnnData
from loguru import logger
from scipy.sparse import issparse
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from tqdm.auto import tqdm

from .data import Targets, aggregate_obs
from .graph import acyclify, demultiplex, filter_edges, multiplex_num
from .typing import SimpleGraph
from .utils import densify

# ------------------------- Causal discovery metrics ---------------------------


# Metric comparing a true causal graph against a predicted one, called as
# ``fn(true, pred, **kwargs) -> np.floating``. ``Callable`` cannot express
# "two fixed positional arguments followed by anything" (an ellipsis inside
# an explicit argument list is not valid typing), so the argument list is
# left open with ``...``.
TrueDiscMetricFn = Callable[..., np.floating]
# Metric computed from a predicted causal graph alone, called as
# ``fn(pred, **kwargs) -> np.floating``.
RespDiscMetricFn = Callable[..., np.floating]


def _multiplex_compat_true_metric(f: TrueDiscMetricFn) -> TrueDiscMetricFn:
    r"""
    Wraps a metric function to be compatible with multiplexed graphs, in which
    case the returned value is the average over all multiplexed graphs.

    The wrapped function is called as ``wrapped(true, pred, *args, **kwargs)``.
    Extra positional and keyword arguments are forwarded unchanged to ``f``.
    When ``multiplex_num(pred)`` is truthy, ``f`` is evaluated against every
    graph yielded by ``demultiplex(pred)`` and the mean is returned. Otherwise
    ``f(true, pred, ...)`` is returned directly.
    """

    @wraps(f)
    def wrapped(
        true: nx.DiGraph, pred: nx.DiGraph, *args, **kwargs
    ) -> np.floating:
        if multiplex_num(pred):
            # np.mean keeps the result a NumPy floating scalar (as annotated)
            # regardless of the numeric type each per-graph metric returns.
            return np.mean(
                [f(true, p, *args, **kwargs) for p in demultiplex(pred)]
            )
        return f(true, pred, *args, **kwargs)

    return wrapped


def _multiplex_compat_resp_metric(f: RespDiscMetricFn) -> RespDiscMetricFn:
    r"""
    Wraps a metric function to be compatible with multiplexed graphs, in which
    case the returned value is the average over all multiplexed graphs.

    The wrapped function is called as ``wrapped(pred, *args, **kwargs)``.
    Extra positional and keyword arguments (e.g. ``cutoff``, ``sig``) are
    forwarded unchanged to ``f``. When ``multiplex_num(pred)`` is truthy,
    ``f`` is evaluated on every graph yielded by ``demultiplex(pred)`` and the
    mean is returned.
    """

    @wraps(f)
    def wrapped(pred: nx.DiGraph, *args, **kwargs) -> np.floating:
        if multiplex_num(pred):
            # np.mean keeps the result a NumPy floating scalar (as annotated).
            return np.mean([f(p, *args, **kwargs) for p in demultiplex(pred)])
        return f(pred, *args, **kwargs)

    return wrapped


@lru_cache
def cmp_true_pred(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str | None = None,
    scaffold: SimpleGraph | None = None,
) -> pd.DataFrame:
    r"""
    Compare the true and predicted causal graphs in a long-form data frame

    Parameters
    ----------
    true
        True causal graph
    pred
        Predicted causal graph
    edge_attr
        Prediction edge attribute (edges are taken as binary if None)
    scaffold
        Scaffold graph whose directed edges define the node pairs compared;
        defaults to the complete graph over the nodes of ``true`` and ``pred``

    Returns
    -------
    Long-form comparison data frame with one row per scaffold edge: the
    "source" and "target" columns, a boolean "true" column, and a "pred"
    column (boolean if ``edge_attr`` is None, otherwise the ``edge_attr``
    value, with 0.0 for edges absent from ``pred``)

    .. caution::

        The cache is **UNSAFE** from inplace graph manipulations.
    """
    if scaffold is None:
        scaffold = nx.complete_graph(true.nodes | pred.nodes)
    if not nx.is_directed(scaffold):
        # Converting to DiGraph keeps both directions of each undirected edge
        scaffold = nx.DiGraph(scaffold)

    df = nx.to_pandas_edgelist(scaffold)
    pairs = list(zip(df["source"], df["target"]))
    df["true"] = [true.has_edge(*pair) for pair in pairs]
    if edge_attr is None:
        df["pred"] = [pred.has_edge(*pair) for pair in pairs]
    else:
        weights = nx.get_edge_attributes(pred, edge_attr)
        df["pred"] = [weights.get(pair, 0.0) for pair in pairs]
    return df
@lru_cache
def optimal_cutoff(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
) -> np.floating:
    r"""
    Obtain the optimal binary classification cutoff

    The ROC point closest (in Euclidean distance) to the ideal corner
    ``(FPR=0, TPR=1)`` is located, and the threshold one position after it
    (clamped to the last available threshold) is returned.

    Parameters
    ----------
    true
        True causal graph
    pred
        Predicted causal graph
    edge_attr
        Prediction edge attribute
    scaffold
        Scaffold graph

    Returns
    -------
    Optimal binary classification cutoff
    """
    ref = cmp_true_pred(true, pred, edge_attr=edge_attr, scaffold=scaffold)
    fpr, tpr, thresholds = roc_curve(ref["true"], ref["pred"])
    corner_dist = np.sqrt(fpr**2 + (tpr - 1) ** 2)
    closest = np.argmin(corner_dist)
    # NOTE(review): the threshold one step past the closest ROC point is used,
    # presumably to match the comparison performed by ``filter_edges`` —
    # confirm against its implementation.
    last = corner_dist.size - 1
    return thresholds[min(closest + 1, last)]
@_multiplex_compat_true_metric
def disc_acc(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    Accuracy of the predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`) before comparison.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    threshold = (
        optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
        if cutoff is None
        else cutoff
    )
    binarized = filter_edges(pred, cutoff=threshold)
    table = cmp_true_pred(true, binarized, scaffold=scaffold)
    # Wrapped in np.float64 because accuracy_score may return a plain float
    # (sklearn #27339)
    return np.float64(accuracy_score(table["true"], table["pred"]))
@_multiplex_compat_true_metric
def disc_prec(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    Precision of the predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`). Precision is reported as 0.0 when
    no edge is predicted.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    threshold = (
        optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
        if cutoff is None
        else cutoff
    )
    binarized = filter_edges(pred, cutoff=threshold)
    table = cmp_true_pred(true, binarized, scaffold=scaffold)
    return precision_score(table["true"], table["pred"], zero_division=0.0)
@_multiplex_compat_true_metric
def disc_recall(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    Recall of the predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`) before comparison.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    threshold = (
        optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
        if cutoff is None
        else cutoff
    )
    binarized = filter_edges(pred, cutoff=threshold)
    table = cmp_true_pred(true, binarized, scaffold=scaffold)
    return recall_score(table["true"], table["pred"])
@_multiplex_compat_true_metric
def disc_f1(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    F1 score of the predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`) before comparison.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    threshold = (
        optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
        if cutoff is None
        else cutoff
    )
    binarized = filter_edges(pred, cutoff=threshold)
    table = cmp_true_pred(true, binarized, scaffold=scaffold)
    return f1_score(table["true"], table["pred"])
@_multiplex_compat_true_metric
def disc_auroc(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
) -> np.floating:
    r"""
    Area under ROC curve of the predicted causal graph

    Scores are the raw ``edge_attr`` values (0.0 for absent edges); no
    thresholding is applied.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    scores = cmp_true_pred(true, pred, edge_attr=edge_attr, scaffold=scaffold)
    labels, predicted = scores["true"], scores["pred"]
    return roc_auc_score(labels, predicted)
@_multiplex_compat_true_metric
def disc_ap(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
) -> np.floating:
    r"""
    Average precision of the predicted causal graph

    Scores are the raw ``edge_attr`` values (0.0 for absent edges); no
    thresholding is applied.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    scores = cmp_true_pred(true, pred, edge_attr=edge_attr, scaffold=scaffold)
    labels, predicted = scores["true"], scores["pred"]
    return average_precision_score(labels, predicted)
@_multiplex_compat_true_metric
def disc_shd(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    Structural hamming distance between the true and predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`). The distance is the number of
    scaffold edges whose presence differs between the true graph and the
    binarized prediction.

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    if cutoff is None:
        cutoff = optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
    pred = filter_edges(pred, cutoff=cutoff)
    # Compare edge *presence* (binary), as the other thresholded metrics do.
    # Passing ``edge_attr`` here would fill "pred" with raw edge weights,
    # turning the distance into a fractional sum instead of an edge count.
    cmp = cmp_true_pred(true, pred, scaffold=scaffold)
    # ``!=`` rather than subtraction: boolean subtraction is unsupported.
    return np.float64((cmp["true"] != cmp["pred"]).sum())
@_multiplex_compat_true_metric
def disc_sid(
    true: nx.DiGraph,
    pred: nx.DiGraph,
    edge_attr: str = "weight",
    scaffold: SimpleGraph | None = None,
    cutoff: float | None = None,
) -> np.floating:
    r"""
    Structural interventional distance between the true and predicted causal graph

    The predicted graph is binarized with :func:`filter_edges` at ``cutoff``
    (defaulting to :func:`optimal_cutoff`), acyclified if necessary, and both
    graphs are passed as unweighted adjacency matrices over the same sorted
    node list to ``structIntervDist``.

    Raises
    ------
    ValueError
        If the true causal graph is not a DAG

    .. note::

        See :func:`cmp_true_pred` for argument descriptions.
    """
    from .ri import structIntervDist

    threshold = (
        cutoff
        if cutoff is not None
        else optimal_cutoff(true, pred, edge_attr=edge_attr, scaffold=scaffold)
    )
    pred = filter_edges(pred, cutoff=threshold)
    if not nx.is_directed_acyclic_graph(true):
        raise ValueError("The true causal graph must be a DAG")
    if not nx.is_directed_acyclic_graph(pred):
        logger.warning("Acyclifying the predicted causal graph.")
        pred = acyclify(pred)

    nodelist = sorted(true.nodes | pred.nodes)

    def to_adjacency(graph):
        # Unweighted (0/1) adjacency in a shared node order
        return nx.to_scipy_sparse_array(
            graph, nodelist=nodelist, weight=None, format="coo"
        )

    return np.float64(structIntervDist(to_adjacency(true), to_adjacency(pred)))
def annot_resp(pred: nx.DiGraph, adata: AnnData, interv_key: str) -> None:
    r"""
    Annotate interventional responsiveness for a predicted causal graph

    For every edge ``x -> y``, the edge data of ``pred`` is updated in place
    with:

    - "fwd_pval", "fwd_diff", "fwd_dist": two-sample Kolmogorov-Smirnov
      p-value, Cohen's d, and absolute Cohen's d of ``y`` between control
      cells and cells in which ``x`` is intervened on
    - "rev_pval", "rev_diff", "rev_dist": the same statistics for ``x``
      between control cells and cells in which ``y`` is intervened on

    The statistics are NaN when the intervened variable never appears among
    the intervention targets.

    Parameters
    ----------
    pred
        Predicted causal graph
    adata
        Interventional dataset
    interv_key
        Key in :attr:`~anndata.AnnData.obs` for the intervention variable
    """

    def as_rows(mat):
        # Row-oriented sparse layout for selecting cells
        return mat.tocsr() if issparse(mat) else mat

    def as_cols(mat):
        # Column-oriented sparse layout for extracting single variables
        return mat.tocsc() if issparse(mat) else mat

    def effect_size(ref, alt):
        # Cohen's d of ``alt`` relative to ``ref`` with pooled population variance
        pooled_var = (np.var(ref) * ref.size + np.var(alt) * alt.size) / (
            ref.size + alt.size
        )
        return (np.mean(alt) - np.mean(ref)) / np.sqrt(pooled_var)

    var_names = adata.var_names
    targets = adata.obs[interv_key].map(Targets)
    all_targets = reduce(or_, targets.unique())
    X = as_rows(adata.X)
    # Control cells are those with an empty target set
    ctrl = as_cols(X[targets.map(len) == 0])
    interv = {
        t: as_cols(X[targets.map(lambda tgt: t in tgt)]) for t in tqdm(all_targets)
    }

    def annotate(attr, keys, source, response):
        # Write the three statistics for ``response`` under intervention on
        # ``source`` into ``attr`` using the given key names
        pval_key, diff_key, dist_key = keys
        if source not in interv:
            attr[pval_key] = attr[diff_key] = attr[dist_key] = np.nan
            return
        col = var_names.get_loc(response)
        ref = densify(ctrl[:, col]).ravel()
        alt = densify(interv[source][:, col]).ravel()
        attr[pval_key] = scipy.stats.ks_2samp(ref, alt).pvalue
        attr[diff_key] = effect_size(ref, alt)
        attr[dist_key] = abs(attr[diff_key])

    for (x, y), attr in tqdm(dict(pred.edges).items(), total=pred.number_of_edges()):
        annotate(attr, ("fwd_pval", "fwd_diff", "fwd_dist"), x, y)
        annotate(attr, ("rev_pval", "rev_diff", "rev_dist"), y, x)
@_multiplex_compat_resp_metric
def disc_resp_dist(pred: nx.DiGraph, cutoff: float = 0.5) -> np.floating:
    r"""
    Responsiveness distance of the predicted causal graph

    NaN-ignoring mean of the "fwd_dist" edge attribute (as written by
    :func:`annot_resp`) over edges kept by :func:`filter_edges` at ``cutoff``.

    Parameters
    ----------
    pred
        Predicted causal graph
    cutoff
        Binary classification cutoff

    Returns
    -------
    Responsiveness distance
    """
    kept = filter_edges(pred, cutoff=cutoff)
    fwd = [data["fwd_dist"] for _, _, data in kept.edges(data=True)]
    return np.nanmean(fwd)
@_multiplex_compat_resp_metric
def disc_resp_dist_diff(pred: nx.DiGraph, cutoff: float = 0.5) -> np.floating:
    r"""
    Responsiveness distance difference of the predicted causal graph

    NaN-ignoring mean of ``fwd_dist - rev_dist`` (edge attributes written by
    :func:`annot_resp`) over edges kept by :func:`filter_edges` at ``cutoff``.

    Parameters
    ----------
    pred
        Predicted causal graph
    cutoff
        Binary classification cutoff

    Returns
    -------
    Responsiveness distance difference
    """
    kept = filter_edges(pred, cutoff=cutoff)
    diffs = [data["fwd_dist"] - data["rev_dist"] for _, _, data in kept.edges(data=True)]
    return np.nanmean(diffs)
@_multiplex_compat_resp_metric
def disc_resp_acc(
    pred: nx.DiGraph, cutoff: float = 0.5, sig: float = 0.1
) -> np.floating:
    r"""
    Responsiveness accuracy of the predicted causal graph

    Among edges kept by :func:`filter_edges` at ``cutoff``, an edge counts as
    forward-supported when only its forward p-value is below ``sig``, and as
    reverse-supported when only its reverse p-value is. The result is the
    forward fraction ``n_fwd / (n_fwd + n_rev)``, which is NaN when neither
    kind occurs. NaN p-values satisfy neither condition.

    Parameters
    ----------
    pred
        Predicted causal graph
    cutoff
        Binary classification cutoff
    sig
        Significance level

    Returns
    -------
    Responsiveness accuracy
    """
    kept = filter_edges(pred, cutoff=cutoff)
    n_fwd = n_rev = 0
    for attr in kept.edges.values():
        if attr["fwd_pval"] < sig and attr["rev_pval"] > sig:
            n_fwd += 1
        elif attr["fwd_pval"] > sig and attr["rev_pval"] < sig:
            n_rev += 1
    # 0 / 0 yields NaN; silence the corresponding floating-point warning
    with np.errstate(invalid="ignore"):
        return np.divide(n_fwd, n_fwd + n_rev)
# --------------------- Counterfactual prediction metrics ----------------------


def _ctfact_prep(
    ctrl: AnnData, true: AnnData, pred: AnnData, by: str
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    r"""
    Prepare datasets for counterfactual metric computation

    Each dataset is mean-aggregated per value of ``by`` and returned as a data
    frame indexed by that value.

    Parameters
    ----------
    ctrl
        Control dataset
    true
        True interventional effect dataset
    pred
        Predicted counterfactual effect dataset
    by
        Intervention variable to group by in the :attr:`~anndata.AnnData.obs` slot

    Returns
    -------
    Control dataset aggregated by intervention
    True interventional effect dataset aggregated by intervention
    Predicted counterfactual effect dataset aggregated by intervention

    Raises
    ------
    ValueError
        If the control dataset does not aggregate to a single row labelled
        ``""``, or if ``true`` and ``pred`` cover different interventions
    """
    frames = []
    for dataset in (ctrl, true, pred):
        agg = aggregate_obs(dataset, by, X_agg="mean", obs_agg={by: "majority"})
        # Re-index rows by intervention label so frames align on it
        agg.obs.set_index(by, inplace=True)
        frames.append(agg.to_df())
    ctrl_df, true_df, pred_df = frames

    if ctrl_df.shape[0] != 1 or ctrl_df.index[0] != "":
        raise ValueError("Invalid control dataset")
    if set(true_df.index) != set(pred_df.index):
        raise ValueError("Interventions in `pred` does not match `true`")
    return ctrl_df, true_df, pred_df
def ctfact_mse(
    ctrl: AnnData,
    true: AnnData,
    pred: AnnData,
    by: str,
    top_de: int | None = None,
    exclude_self: bool = False,
    de_key: str = "rank_genes_groups",
) -> pd.DataFrame:
    r"""
    Mean squared errors of counterfactual prediction

    For each intervention group that has differential expression results in
    ``true.uns[de_key]``, squared errors are averaged over the selected
    differentially expressed genes (in DE rank order).

    Parameters
    ----------
    ctrl
        Control dataset
    true
        True interventional effect dataset
    pred
        Predicted counterfactual effect dataset
    by
        Intervention variable to group by in the :attr:`~anndata.AnnData.obs` slot
    top_de
        Number of top differentially expressed genes to consider
        (all genes are used if None or 0)
    exclude_self
        Whether to exclude the perturbed genes themselves
    de_key
        Key to the differential expression results

    Returns
    -------
    Counterfactual metric data frame indexed by intervention group,
    consisting of columns:

    - "true_mse": error of the control mean relative to the true effect
    - "pred_mse": error of the prediction relative to the true effect
    - "normalized_mse": ``pred_mse / true_mse``
    """
    ctrl_df, true_df, pred_df = _ctfact_prep(ctrl, true, pred, by)
    true_se = (true_df - ctrl_df.iloc[0]).pow(2)
    pred_se = (pred_df - true_df).pow(2)
    all_vars = set(true.var_names)
    true_mse, pred_mse = [], []
    # Only groups present both in the DE results and among the interventions
    de_groups = [i for i in true.uns[de_key]["names"].dtype.names if i in true_se.index]
    for g in de_groups:
        # Group labels are split on "," to recover the perturbed genes
        exclude_vars = set(g.split(",")) if exclude_self else set()
        degs = sc.get.rank_genes_groups_df(true, g, key=de_key)["names"]
        degs = degs[degs.isin(all_vars - exclude_vars)]
        degs = degs.head(n=top_de or degs.size)
        true_mse.append(true_se.loc[g, degs].mean())
        pred_mse.append(pred_se.loc[g, degs].mean())
    true_mse = pd.Series(true_mse, index=de_groups)
    pred_mse = pd.Series(pred_mse, index=de_groups)
    normalized_mse = pred_mse / true_mse
    return pd.DataFrame(
        {
            "true_mse": true_mse,
            "pred_mse": pred_mse,
            "normalized_mse": normalized_mse,
        }
    )
def ctfact_delta_pcc(
    ctrl: AnnData,
    true: AnnData,
    pred: AnnData,
    by: str,
    top_de: int | None = None,
    exclude_self: bool = False,
    de_key: str = "rank_genes_groups",
) -> pd.DataFrame:
    r"""
    Pearson correlation coefficient of counterfactual delta

    For each intervention group that has differential expression results in
    ``true.uns[de_key]``, the true and predicted changes relative to the
    control mean are correlated over the selected differentially expressed
    genes.

    Parameters
    ----------
    ctrl
        Control dataset
    true
        True interventional effect dataset
    pred
        Predicted counterfactual effect dataset
    by
        Intervention variable to group by in the :attr:`~anndata.AnnData.obs` slot
    top_de
        Number of top differentially expressed genes to consider
        (all genes are used if None or 0)
    exclude_self
        Whether to exclude the perturbed genes themselves
    de_key
        Key to the differential expression results

    Returns
    -------
    Counterfactual metric data frame indexed by intervention group,
    containing column "delta_pcc"
    """
    ctrl_df, true_df, pred_df = _ctfact_prep(ctrl, true, pred, by)
    true_delta = true_df - ctrl_df.iloc[0]
    pred_delta = pred_df - ctrl_df.iloc[0]
    all_vars = set(true.var_names)
    # Only groups present both in the DE results and among the interventions
    de_groups = [
        i for i in true.uns[de_key]["names"].dtype.names if i in true_delta.index
    ]
    pcc = []
    for g in de_groups:
        # Group labels are split on "," to recover the perturbed genes
        exclude_vars = set(g.split(",")) if exclude_self else set()
        degs = sc.get.rank_genes_groups_df(true, g, key=de_key)["names"]
        degs = degs[degs.isin(all_vars - exclude_vars)]
        degs = degs.head(n=top_de or degs.size)
        pcc.append(true_delta.loc[g, degs].corr(pred_delta.loc[g, degs]))
    pcc = pd.Series(pcc, index=de_groups)
    return pd.DataFrame({"delta_pcc": pcc})
def ctfact_dir_acc(
    ctrl: AnnData,
    true: AnnData,
    pred: AnnData,
    by: str,
    top_de: int | None = None,
    exclude_self: bool = False,
    de_key: str = "rank_genes_groups",
) -> pd.DataFrame:
    r"""
    Directional accuracy of counterfactual predictions

    For each intervention group that has differential expression results in
    ``true.uns[de_key]``, this is the fraction of selected differentially
    expressed genes whose predicted change relative to the control mean has
    the same sign as the true change.

    Parameters
    ----------
    ctrl
        Control dataset
    true
        True interventional effect dataset
    pred
        Predicted counterfactual effect dataset
    by
        Intervention variable to group by in the :attr:`~anndata.AnnData.obs` slot
    top_de
        Number of top differentially expressed genes to consider
        (all genes are used if None or 0)
    exclude_self
        Whether to exclude the perturbed genes themselves
    de_key
        Key to the differential expression results

    Returns
    -------
    Counterfactual metric data frame indexed by intervention group,
    containing column "dir_acc"
    """
    ctrl_df, true_df, pred_df = _ctfact_prep(ctrl, true, pred, by)
    true_sign = np.sign(true_df - ctrl_df.iloc[0])
    pred_sign = np.sign(pred_df - ctrl_df.iloc[0])
    sign_match = true_sign.eq(pred_sign)
    all_vars = set(true.var_names)
    # Only groups present both in the DE results and among the interventions
    de_groups = [
        i for i in true.uns[de_key]["names"].dtype.names if i in sign_match.index
    ]
    acc = []
    for g in de_groups:
        # Group labels are split on "," to recover the perturbed genes
        exclude_vars = set(g.split(",")) if exclude_self else set()
        degs = sc.get.rank_genes_groups_df(true, g, key=de_key)["names"]
        degs = degs[degs.isin(all_vars - exclude_vars)]
        degs = degs.head(n=top_de or degs.size)
        acc.append(sign_match.loc[g, degs].mean())
    acc = pd.Series(acc, index=de_groups)
    return pd.DataFrame({"dir_acc": acc})
# ------------------------ Intervention design metrics -------------------------
def dsgn_hrc_exact(designs: dict[str, pd.Series]) -> tuple[np.ndarray, np.ndarray]:
    r"""
    Exact hit-rate curve for intervention design

    For each true intervention, its rank (1 = highest score) within its
    design, divided by the design size, is taken as the quantile at which it
    is hit. Interventions absent from their design count as quantile 1.0.

    Parameters
    ----------
    designs
        Mapping from true interventions to designed intervention scores

    Returns
    -------
    Quantiles of the designed intervention scores
    Exact hit rates at the quantiles
    """
    fractions = []
    for target, scores in designs.items():
        ranks = scores.rank(ascending=False)
        try:
            hit = ranks.loc[target] / ranks.size
        except KeyError:
            logger.warning(f"Exact intervention {target} not found in its design!")
            hit = 1.0
        fractions.append(hit)
    n = len(fractions)
    qtl = np.asarray([0.0, *sorted(fractions), 1.0])
    hr = np.asarray([*(k / n for k in range(n + 1)), 1.0])
    return qtl, hr
def dsgn_hrc_partial(designs: dict[str, pd.Series]) -> tuple[np.ndarray, np.ndarray]:
    r"""
    Partial hit-rate curve for intervention design

    For each true intervention, the best (lowest) rank, divided by the design
    size, among designed interventions sharing at least one target with it
    is taken as its hit quantile. Interventions without any overlapping design
    entry count as quantile 1.0.

    Parameters
    ----------
    designs
        Mapping from true interventions to designed intervention scores

    Returns
    -------
    Quantiles of the designed intervention scores
    Partial hit rates at the quantiles
    """
    fractions = []
    for target, scores in designs.items():
        ranks = scores.rank(ascending=False)
        wanted = Targets(target)
        overlaps = ranks.index.map(lambda name: bool(Targets(name) & wanted))
        if not overlaps.any():
            logger.warning(f"No partial match for {wanted} found in its design!")
            fractions.append(1.0)
            continue
        fractions.append(ranks.loc[overlaps].min() / ranks.size)
    n = len(fractions)
    qtl = np.asarray([0.0, *sorted(fractions), 1.0])
    hr = np.asarray([*(k / n for k in range(n + 1)), 1.0])
    return qtl, hr
def dsgn_auhrc_exact(designs: dict[str, pd.Series]) -> np.floating:
    r"""
    Area under the exact hit-rate curve for intervention design, see
    :func:`dsgn_hrc_exact`

    Parameters
    ----------
    designs
        Mapping from true interventions to designed intervention scores

    Returns
    -------
    Area under the exact hit-rate curve (trapezoidal rule via sklearn ``auc``)
    """
    return auc(*dsgn_hrc_exact(designs))
def dsgn_auhrc_partial(designs: dict[str, pd.Series]) -> np.floating:
    r"""
    Area under the partial hit-rate curve for intervention design, see
    :func:`dsgn_hrc_partial`

    Parameters
    ----------
    designs
        Mapping from true interventions to designed intervention scores

    Returns
    -------
    Area under the partial hit-rate curve (trapezoidal rule via sklearn ``auc``)
    """
    return auc(*dsgn_hrc_partial(designs))