r"""
Miscellaneous utilities
"""
import os
import sys
import uuid
from collections import defaultdict
from collections.abc import Hashable, Iterable
from functools import lru_cache
from heapq import nlargest
from multiprocessing import Process, Queue
from os import cpu_count, environ
from random import shuffle
from sys import stderr, stdout
import numpy as np
import pandas as pd
import pynvml
import torch
from loguru import logger
from numpy.typing import ArrayLike
from rich.console import Console
from rich.theme import Theme
from scipy import stats
from scipy.cluster.hierarchy import linkage
from scipy.linalg import pinvh
from scipy.sparse import issparse, spmatrix
from scipy.spatial.distance import pdist
from scipy.stats import rankdata
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from .typing import RandomState
[docs]
def internal(obj):
    r"""
    Decorator to mark a function or class as internal and exclude it from
    autosummary.

    Parameters
    ----------
    obj
        Function or class to mark

    Returns
    -------
    The same object, with its ``_internal`` attribute set to ``True``
    """
    setattr(obj, "_internal", True)
    return obj
[docs]
class Config:
    r"""
    Global configurations

    Each option is initialized in ``__init__`` from the environment variable
    ``CASCADE_<OPTION>`` when it is set, and from a built-in default
    otherwise. Options are exposed as properties; the setters convert
    string values (as obtained from the environment) to the proper type.
    """

    def __init__(self) -> None:
        self.ANNDATA_KEY = environ.get("CASCADE_ANNDATA_KEY", "__CASCADE__")
        # ``cpu_count()`` returns None when the CPU count cannot be determined;
        # fall back to 1 so that ``min`` does not fail on ``None``.
        self.NUM_WORKERS = environ.get(
            "CASCADE_NUM_WORKERS", min(4, cpu_count() or 1)
        )
        self.PERSISTENT_WORKERS = environ.get("CASCADE_PERSISTENT_WORKERS", True)
        self.DETERMINISTIC = environ.get("CASCADE_DETERMINISTIC", False)
        self.CUDA_REMAP = environ.get("CASCADE_CUDA_REMAP", False)
        self.LOG_STEP_INTERVAL = environ.get("CASCADE_LOG_STEP_INTERVAL", 10)
        self.PBAR_REFRESH = environ.get("CASCADE_PBAR_REFRESH", 10)
        self.CKPT_SAVE_K = environ.get("CASCADE_CKPT_SAVE_K", 1)
        self.MIN_DELTA = environ.get("CASCADE_MIN_DELTA", 0.0)
        self.PATIENCE = environ.get("CASCADE_PATIENCE", 3)
        self.PRECISION = environ.get("CASCADE_PRECISION", "32-true")
        self.LOG_LEVEL = environ.get("CASCADE_LOG_LEVEL", "INFO")
        self.RUN_ID = environ.get("CASCADE_RUN_ID", str(uuid.uuid4()))
        self.DEBUG_FLAG = environ.get("CASCADE_DEBUG_FLAG", False)

    @staticmethod
    def _to_bool(value: bool | str) -> bool:
        # Strings (e.g. from the environment) are parsed, other values are
        # stored as given.
        return str_to_bool(value) if isinstance(value, str) else value

    @property
    def ANNDATA_KEY(self) -> str:
        r"""
        Key to store data configuration in :class:`~anndata.AnnData`.
        Default value is ``__CASCADE__``.
        """
        return self._ANNDATA_KEY

    @ANNDATA_KEY.setter
    def ANNDATA_KEY(self, value: str) -> None:
        self._ANNDATA_KEY = value

    @property
    def NUM_WORKERS(self) -> int:
        r"""
        Number of worker processes to use in data loader.
        Default value is ``min(4, n_cpu)``, where ``n_cpu`` is taken as 1
        if the CPU count cannot be determined.

        .. tip::
            If training hangs randomly, try setting this option to ``0``,
            or setting ``PERSISTENT_WORKERS = False``.
        """
        return self._NUM_WORKERS

    @NUM_WORKERS.setter
    def NUM_WORKERS(self, value: int | str) -> None:
        self._NUM_WORKERS = int(value)

    @property
    def PERSISTENT_WORKERS(self) -> bool:
        r"""
        Whether to use persistent workers in data loader.
        Default value is ``True``.

        .. tip::
            This option is only effective when ``NUM_WORKERS > 0``.
            If training hangs randomly, try setting this option to ``False``,
            or setting ``NUM_WORKERS = 0``.
        """
        return self._PERSISTENT_WORKERS and self.NUM_WORKERS > 0

    @PERSISTENT_WORKERS.setter
    def PERSISTENT_WORKERS(self, value: bool | str) -> None:
        self._PERSISTENT_WORKERS = self._to_bool(value)

    @property
    def DETERMINISTIC(self) -> bool:
        r"""
        Whether to use deterministic cuDNN implementations.
        Default value is ``False``.
        """
        return self._DETERMINISTIC

    @DETERMINISTIC.setter
    def DETERMINISTIC(self, value: bool | str) -> None:
        self._DETERMINISTIC = self._to_bool(value)

    @property
    def CUDA_REMAP(self) -> bool:
        r"""
        Whether to remap granted CUDA device IDs.
        Default value is ``False``.
        """
        return self._CUDA_REMAP

    @CUDA_REMAP.setter
    def CUDA_REMAP(self, value: bool | str) -> None:
        self._CUDA_REMAP = self._to_bool(value)

    @property
    def LOG_STEP_INTERVAL(self) -> int:
        r"""
        Step interval for logging (presumably the number of training steps
        between logged records; the consumer lives outside this module).
        Default value is ``10``.
        """
        return self._LOG_STEP_INTERVAL

    @LOG_STEP_INTERVAL.setter
    def LOG_STEP_INTERVAL(self, value: int | str) -> None:
        self._LOG_STEP_INTERVAL = int(value)

    @property
    def PBAR_REFRESH(self) -> int:
        r"""
        Refresh rate of the training progress bar.
        Default value is ``10``.
        """
        return self._PBAR_REFRESH

    @PBAR_REFRESH.setter
    def PBAR_REFRESH(self, value: int | str) -> None:
        self._PBAR_REFRESH = int(value)

    @property
    def CKPT_SAVE_K(self) -> int:
        r"""
        Number of top models to save as checkpoints.
        Default value is ``1``.
        """
        return self._CKPT_SAVE_K

    @CKPT_SAVE_K.setter
    def CKPT_SAVE_K(self, value: int | str) -> None:
        self._CKPT_SAVE_K = int(value)

    @property
    def MIN_DELTA(self) -> float:
        r"""
        Minimal score improvement used to call convergence in early stopping.
        Default value is ``0.0``.
        """
        return self._MIN_DELTA

    @MIN_DELTA.setter
    def MIN_DELTA(self, value: float | str) -> None:
        self._MIN_DELTA = float(value)

    @property
    def PATIENCE(self) -> int:
        r"""
        Patience to call convergence in early stopping.
        Default value is ``3``.
        """
        return self._PATIENCE

    @PATIENCE.setter
    def PATIENCE(self, value: int | str) -> None:
        self._PATIENCE = int(value)

    @property
    def PRECISION(self) -> int | str:
        r"""
        Floating point precision.
        Default value is ``32-true``.

        Accepted values are ``"bf16"``, ``"bf16-mixed"``, ``16``, ``"16"``,
        ``"16-mixed"``, ``32``, ``"32"``, ``"32-true"``, ``64``, ``"64"`` and
        ``"64-true"``; any other value raises :class:`KeyError`.

        .. note::
            Setting this option also changes the default dtype of torch via
            :func:`torch.set_default_dtype` and updates ``FALLBACK_DTYPE``.
        """
        return self._PRECISION

    @PRECISION.setter
    def PRECISION(self, value: int | str = "32-true") -> None:
        # precision spec -> (torch default dtype, fallback dtype)
        dtype_table = {
            "bf16": (torch.bfloat16, torch.float32),
            "bf16-mixed": (torch.bfloat16, torch.float32),
            16: (torch.float16, torch.float32),
            "16": (torch.float16, torch.float32),
            "16-mixed": (torch.float16, torch.float32),
            32: (torch.float32, torch.float32),
            "32": (torch.float32, torch.float32),
            "32-true": (torch.float32, torch.float32),
            64: (torch.float64, torch.float64),
            "64": (torch.float64, torch.float64),
            "64-true": (torch.float64, torch.float64),
        }
        # Look up first so that an unsupported value leaves all state untouched
        default_dtype, fallback_dtype = dtype_table[value]
        torch.set_default_dtype(default_dtype)
        self._FALLBACK_DTYPE = fallback_dtype
        self._PRECISION = value

    @property
    def FALLBACK_DTYPE(self) -> torch.dtype:
        r"""
        Fallback dtype paired with the current ``PRECISION`` (read-only):
        ``torch.float32`` for bf16, 16-bit and 32-bit settings, and
        ``torch.float64`` for 64-bit settings.
        """
        return self._FALLBACK_DTYPE

    @property
    def LOG_LEVEL(self) -> str:
        r"""
        Log level.
        Default value is ``"INFO"``.

        .. note::
            Setting this option removes all existing loguru handlers and
            installs two new ones: records with a level number below 30 go
            to stdout, the remaining records go to stderr.
        """
        return self._LOG_LEVEL

    @LOG_LEVEL.setter
    def LOG_LEVEL(self, value: str) -> None:
        log_format = (
            "<g>{time:HH:mm:ss.SSS}</g> | "
            "<lvl>{level: <8}</lvl> | "
            "<y>{process.id}</y>:<c>{module}</c>:<c>{function}</c> - "
            "<lvl>{message}</lvl>"
        )
        logger.remove()
        logger.add(
            stdout,
            filter=lambda record: record["level"].no < 30,
            level=value,
            format=log_format,
        )
        logger.add(
            stderr,
            filter=lambda record: record["level"].no >= 30,
            level=value,
            format=log_format,
        )
        self._LOG_LEVEL = value

    @property
    def RUN_ID(self) -> str:
        r"""
        A unique UUID for the running session

        .. note::
            Setting this option also writes the value to the
            ``CASCADE_RUN_ID`` environment variable.
        """
        return self._RUN_ID

    @RUN_ID.setter
    def RUN_ID(self, value: str) -> None:
        self._RUN_ID = value
        environ["CASCADE_RUN_ID"] = self._RUN_ID

    @property
    def DEBUG_FLAG(self) -> bool:
        r"""
        Convenience utility for setting conditional breakpoints without
        affecting running workflows.
        Default value is ``False``.
        """
        return self._DEBUG_FLAG

    @DEBUG_FLAG.setter
    def DEBUG_FLAG(self, value: bool | str) -> None:
        self._DEBUG_FLAG = self._to_bool(value)
[docs]
class MissingDependencyError(Exception):
    r"""
    Exception whose message asks the user to install a named dependency

    Parameters
    ----------
    name
        Name of the dependency to mention in the message
    """

    def __init__(self, name: str) -> None:
        message = f"Please install {name} first."
        super().__init__(message)  # pragma: no cover
[docs]
def is_notebook() -> bool:  # pragma: no cover
    r"""
    Check if the code is running in a Jupyter notebook

    Returns
    -------
    Whether the code is running in a Jupyter notebook
    """
    try:
        # ``get_ipython`` is not defined outside IPython, hence the NameError
        shell_name = type(get_ipython()).__name__  # type: ignore
    except NameError:
        return False
    return shell_name == "ZMQInteractiveShell"
[docs]
def str_to_bool(x: str) -> bool:
    r"""
    Interpret string as bool

    Matching ignores letter case and surrounding whitespace, so that values
    such as ``"TRUE"`` or ``" false "`` are accepted as well.

    Parameters
    ----------
    x
        String to interpret, one of ``T``, ``True``, ``1`` (true) or
        ``F``, ``False``, ``0`` (false)

    Returns
    -------
    Interpreted bool

    Raises
    ------
    ValueError
        If the input cannot be interpreted as bool
    """
    # Non-string input is left as is and ends up in the ValueError below
    token = x.strip().casefold() if isinstance(x, str) else x
    if token in ("t", "true", "1"):
        return True
    if token in ("f", "false", "0"):
        return False
    raise ValueError(f"Cannot interpret {x} as bool")
@internal
def non_unitary_index(
index: (
int | slice | range | list[int] | tuple[int | slice | range | list[int], ...]
),
) -> slice | range | list[int] | tuple[slice | range | list[int], ...]:
if isinstance(index, tuple):
return tuple(non_unitary_index(i) for i in index)
return [index] if isinstance(index, int) else index
@internal
def index_len(
index: (
int | slice | range | list[int] | tuple[int | slice | range | list[int], ...]
),
total: int | tuple[int],
) -> int | tuple[int]:
if not isinstance(index, tuple) and not isinstance(total, tuple):
if isinstance(index, int):
return 1
if isinstance(index, slice):
index = range(total)[index]
return len(index)
if isinstance(index, tuple) and not isinstance(total, tuple):
raise ValueError("Inconsistent total")
if not isinstance(index, tuple) and isinstance(total, tuple):
index = (index,)
# Now both are tuples
if len(index) > len(total):
raise IndexError("Too many indices")
index = index + (slice(None),) * (len(total) - len(index))
return tuple(index_len(i, t) for i, t in zip(index, total))
[docs]
def get_random_state(random_state: RandomState = None) -> np.random.RandomState:
    r"""
    Get a random state object

    Parameters
    ----------
    random_state
        Integer seed, existing :class:`~numpy.random.RandomState` object, or
        None

    Returns
    -------
    Random state object, which is the input itself if it already is a
    :class:`~numpy.random.RandomState`
    """
    already_built = isinstance(random_state, np.random.RandomState)
    return random_state if already_built else np.random.RandomState(random_state)
[docs]
def count_occurrence(x: Iterable[Hashable]) -> list[int]:
    r"""
    Count occurrence number of list elements

    Parameters
    ----------
    x
        List of hashable elements

    Returns
    -------
    List of occurrence counts, where each entry is the number of times the
    corresponding element has appeared earlier in the list
    """
    seen = {}
    result = []
    for item in x:
        n_previous = seen.get(item, 0)
        result.append(n_previous)
        seen[item] = n_previous + 1
    return result
[docs]
def densify(x: np.ndarray | spmatrix) -> np.ndarray:
r"""
Convert a matrix to dense format
Parameters
----------
x
Input matrix
Returns
-------
Dense matrix
"""
return x.toarray() if issparse(x) else x
[docs]
def variance(x: ArrayLike | spmatrix, bias: bool = False) -> np.ndarray:
r"""
Compute variance vector where each column of the input matrix is treated as
a variable
Parameters
----------
x
Input matrix
bias
Whether to compute biased variance
Returns
-------
Variance vector
"""
if issparse(x):
mean = x.mean(axis=0).A1
var = x.power(2).mean(axis=0).A1 - mean**2
if not bias:
n = x.shape[0]
var = var * n / (n - 1)
else:
var = np.var(np.asarray(x), axis=0, ddof=int(not bias))
return var
[docs]
def covariance(x: ArrayLike | spmatrix, bias: bool = False) -> np.ndarray:
r"""
Compute covariance matrix where each column of the input matrix is treated
as a variable
Parameters
----------
x
Input matrix
bias
Whether to compute biased covariance
Returns
-------
Covariance matrix
"""
if issparse(x):
n = x.shape[0]
mean = x.mean(axis=0).A1
cov = (x.T @ x).toarray() / n - np.outer(mean, mean)
if not bias:
cov = cov * n / (n - 1)
else:
cov = np.cov(np.asarray(x), rowvar=False, bias=bias)
return cov
[docs]
def pearson_correlation(x: ArrayLike | spmatrix) -> np.ndarray:
    r"""
    Compute Pearson correlation matrix

    Parameters
    ----------
    x
        Input matrix, where each column is treated as a variable

    Returns
    -------
    Pearson correlation matrix
    """
    cov = covariance(x)
    std = np.sqrt(np.diagonal(cov))
    # Scale columns first, then rows, by the standard deviations
    column_scaled = cov / std.reshape(1, -1)
    return column_scaled / std.reshape(-1, 1)
[docs]
def partial_correlation(x: ArrayLike | spmatrix) -> np.ndarray:
    r"""
    Compute partial correlation matrix

    The matrix is obtained by negating the pseudo-inverse of the covariance
    matrix (precision matrix) and scaling it by the square roots of the
    precision diagonal.

    .. note::
        As the negation applies to the whole matrix, diagonal entries come
        out as ``-1`` (given non-zero precision diagonal).

    Parameters
    ----------
    x
        Input matrix, where each column is treated as a variable

    Returns
    -------
    Partial correlation matrix
    """
    precision = pinvh(covariance(x))
    scale = np.sqrt(np.diagonal(precision))
    negated = -precision
    # Scale columns first, then rows
    column_scaled = negated / scale.reshape(1, -1)
    return column_scaled / scale.reshape(-1, 1)
[docs]
def spearman_correlation(x: ArrayLike | spmatrix) -> np.ndarray:
    r"""
    Compute Spearman correlation matrix

    Each column is converted to ranks, then the Pearson correlation of the
    ranks is computed. Sparse input is densified first.

    Parameters
    ----------
    x
        Input matrix, where each column is treated as a variable

    Returns
    -------
    Spearman correlation matrix
    """
    dense = x.toarray() if issparse(x) else x
    ranks = np.column_stack([rankdata(column) for column in dense.T])
    return pearson_correlation(ranks)
[docs]
def hclust(
X: pd.DataFrame,
metric: str = "euclidean",
method: str = "complete",
cut: bool = True,
**kwargs,
) -> tuple[np.ndarray, pd.Series]:
r"""
Hierarchical clustering followed by optional tree cutting
Parameters
----------
X
Input data
metric
Distance metric
method
Clustering method
cut
Whether to cut the tree
**kwargs
Additional keyword arguments for tree cutting passed to
:func:`~dynamicTreeCut.cutreeHybrid`
Returns
-------
Linkage matrix
Cluster labels
"""
D = pdist(X, metric=metric)
D[np.isnan(D)] = np.nanmax(D)
L = linkage(D, method=method)
if cut:
try:
from dynamicTreeCut import cutreeHybrid
C = cutreeHybrid(L, D, **kwargs)
C = pd.Series(C["labels"], index=X.index)
except ImportError: # pragma: no cover
logger.warning("dynamicTreeCut not found, skipping tree cut...")
C = None
else:
C = None
return L, C
def _search_right2left(x: ArrayLike) -> tuple[int, int]:
    r"""
    Search a boolean 1D array from right to left and return the index of the
    first True value :math:`i` and the index of the last False value :math:`j`.

    - If the rightmost value is True, both :math:`i` and :math:`j` are the
      rightmost index.
    - If all values are False, both :math:`i` and :math:`j` are the
      leftmost index.

    Raises
    ------
    ValueError
        If the input is not one-dimensional or is empty
    """
    flags = np.asarray(x)
    if flags.ndim != 1 or flags.size == 0:
        raise ValueError("Invalid array")
    rightmost = flags.size - 1
    pos = rightmost
    while pos >= 0:
        if flags[pos]:
            return pos, min(pos + 1, rightmost)
        pos -= 1
    # No True value found
    return 0, 0
[docs]
def gp_regression_with_ci(
    data: pd.DataFrame, x: str, y: str, alpha: float = 0.95
) -> tuple[pd.DataFrame, float]:
    r"""
    Gaussian process regression with confidence interval

    The model is fitted on rows of ``data`` without missing values in any
    column, and is evaluated on all rows of ``data``. The cutoff is
    determined over 1000 evenly spaced values spanning the range of ``x``,
    by searching from right to left for the position where the lower
    confidence bound exceeds the minimal observed ``y``.

    Parameters
    ----------
    data
        Input data frame
    x
        Input variable
    y
        Output variable
    alpha
        Confidence level

    Returns
    -------
    Data frame with three additional columns of mean, lower and upper bounds
    Cutoff of input variable that covers minimal output value in CI
    """
    observed = data.dropna()
    gp = GaussianProcessRegressor(kernel=RBF() + WhiteKernel())
    gp.fit(np.asarray(observed[[x]]), np.asarray(observed[y]))

    # Standard normal quantiles delimiting the two-sided interval
    z_lower = stats.norm.ppf((1 - alpha) / 2)
    z_upper = stats.norm.ppf(1 - (1 - alpha) / 2)

    y_mean, y_std = gp.predict(np.asarray(data[[x]]), return_std=True)
    fitted = data.assign(
        **{
            f"{y}_mean": y_mean,
            f"{y}_lower": y_mean + z_lower * y_std,
            f"{y}_upper": y_mean + z_upper * y_std,
        }
    )

    grid = np.linspace(data[x].min(), data[x].max(), 1000)
    grid_mean, grid_std = gp.predict(grid[:, np.newaxis], return_std=True)
    grid_lower = grid_mean + z_lower * grid_std
    cut_left, cut_right = _search_right2left(grid_lower > data[y].min())
    # Midpoint between the two neighboring grid values
    cutoff = (grid[cut_left] + grid[cut_right]) / 2
    return fitted, cutoff
[docs]
@lru_cache
def autodevice(n: int = 1) -> tuple[str, list[int] | str]:
    r"""
    Get torch computation device automatically based on GPU availability and
    memory usage

    Devices granted previously (recorded in the ``CASCADE_GRANTED``
    environment variable) are reused if their number matches the request.
    Otherwise, candidate devices are taken from ``CUDA_VISIBLE_DEVICES`` if
    set, or all devices reported by NVML, and those with the most free memory
    are granted and recorded in ``CASCADE_GRANTED``. CPU is used if an NVML
    error occurs or no GPU can be granted.

    .. note::
        The result is cached per value of ``n`` via
        :func:`~functools.lru_cache`.

    Parameters
    ----------
    n
        Number of GPUs to request

    Returns
    -------
    Accelerator type
    List of granted devices
    """
    inherited = environ.get("CASCADE_GRANTED", None)
    if inherited is not None:
        granted = [int(token) for token in inherited.split(",")]
        if len(granted) == n:  # DDP member
            logger.info("Using GPU {} as computation device.", granted)
            return "gpu", granted
    try:
        pynvml.nvmlInit()
        visible = environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible is None:
            devices = list(range(pynvml.nvmlDeviceGetCount()))
        else:
            devices = [int(d.strip()) for d in visible.split(",") if d != ""]
        # Shuffle before ranking, presumably to randomize the choice among
        # devices with equal free memory
        shuffle(devices)
        free_mems = {}
        for i in devices:
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            free_mems[i] = pynvml.nvmlDeviceGetMemoryInfo(handle).free
        granted = nlargest(n, free_mems, key=free_mems.get)
        if not granted:
            raise pynvml.NVMLError("GPU disabled.")
        if config.CUDA_REMAP:
            logger.info("Remapping CUDA device IDs.")
            granted = list(range(len(granted)))
        logger.info("Using GPU {} as computation device.", granted)
        environ["CASCADE_GRANTED"] = ",".join(map(str, granted))
        return "gpu", granted
    except pynvml.NVMLError:  # pragma: no cover
        logger.info("Using CPU as computation device.")
        return "cpu", "auto"
def _subprocess_affinity(q: Queue) -> None:
    r"""
    Put the CPU affinity of the calling process into a queue

    Parameters
    ----------
    q
        Queue to receive the affinity
    """
    affinity = os.sched_getaffinity(0)
    q.put(affinity)
[docs]
def check_affinity_inheritance() -> None:
    r"""
    Check whether affinity inheritance is broken

    A subprocess is launched to report its CPU affinity, and a warning is
    logged if it differs from the affinity of the current process.
    """
    if sys.platform != "linux" or not hasattr(os, "sched_getaffinity"):
        return  # no-op on macOS/Windows
    q = Queue()
    p = Process(target=_subprocess_affinity, args=(q,))
    p.start()
    # Fetch the result before joining: a process that has put items on a
    # queue may not terminate until they are consumed, so joining first
    # risks a deadlock.
    child_affinity = q.get()
    p.join()
    if child_affinity != os.sched_getaffinity(0):
        logger.warning(
            "Affinity inheritance is broken! This might be related to "
            "https://github.com/pytorch/pytorch/issues/99625. "
            "Consider setting environment variable KMP_AFFINITY=disabled. "
            "Otherwise, performance would be compromised."
        )
# Module-level configuration instance. Constructing it reads the CASCADE_*
# environment variables and, via the property setters, also sets up the
# loguru handlers and the default dtype of torch.
config = Config()
# Module-level rich console with a custom "hl" style (bold magenta)
console = Console(theme=Theme({"hl": "bold magenta"}))