r"""
Dataset processing utilities and data loaders
"""
import re
from collections.abc import Iterable, Mapping
from functools import reduce
from operator import or_
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from loguru import logger
from pytorch_lightning import LightningDataModule
from scipy.sparse import csr_matrix, issparse, lil_matrix, spmatrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader, Dataset
from .typing import RandomState
from .utils import config, get_random_state, internal
EPS: float = 1e-7
[docs]
class Targets(frozenset[str]):
r"""
Intervention targets
Parameters
----------
targets
A string of comma-separated target names
.. note::
Use empty string for no targets
"""
SEP = re.compile(r"\s*,\s*")
def __new__(cls, targets: str | Iterable[str] | None = None) -> None:
if isinstance(targets, str):
targets = (item for item in re.split(cls.SEP, targets) if item)
elif targets is None:
targets = ()
return super().__new__(cls, targets)
def __lt__(self, other: "Targets") -> bool:
if len(self) < len(other):
return True
if len(self) > len(other):
return False
for i, j in zip(sorted(self), sorted(other)):
if i < j:
return True
return False
def __gt__(self, other: "Targets") -> bool:
if len(self) > len(other):
return True
if len(self) < len(other):
return False
for i, j in zip(sorted(self), sorted(other)):
if i > j:
return True
return False
def __le__(self, other: "Targets") -> bool:
return self == other or self < other
def __ge__(self, other: "Targets") -> bool:
return self == other or self > other
def __and__(self, other: Iterable[str]) -> "Targets":
return Targets(super().__and__(set(other)))
def __or__(self, other: Iterable[str]) -> "Targets":
return Targets(super().__or__(set(other)))
def __xor__(self, other: Iterable[str]) -> "Targets":
return Targets(super().__xor__(set(other)))
def __sub__(self, other: Iterable[str]) -> "Targets":
return Targets(super().__sub__(set(other)))
def __repr__(self) -> str:
return ",".join(sorted(self))
[docs]
def get_all_targets(adata: AnnData, key: str) -> Targets:
    r"""
    Get the union of all intervention targets in a given dataset

    Parameters
    ----------
    adata
        Input dataset
    key
        Key in :attr:`~anndata.AnnData.obs` containing comma-separated
        intervention targets

    Returns
    -------
    Union of all intervention targets (empty if the column has no labels)
    """
    # Start the fold from an empty `Targets` so that a column without any
    # labels yields an empty result instead of raising
    return reduce(or_, map(Targets, adata.obs[key].unique()), Targets())
[docs]
def filter_unobserved_targets(adata: AnnData, key: str) -> AnnData:
    r"""
    Filter cells where the intervention targets are not observed (missing from
    :attr:`~anndata.AnnData.var_names`)

    Parameters
    ----------
    adata
        Input dataset
    key
        Key in :attr:`~anndata.AnnData.obs` containing comma-separated
        intervention targets

    Returns
    -------
    Filtered dataset
    """
    all_vars = set(adata.var_names)
    labels = adata.obs[key]
    # Evaluate each distinct label once, then broadcast back to the cells
    unobserved = {
        label: bool(Targets(label) - all_vars) for label in labels.unique()
    }
    # The mapped column is not guaranteed to have boolean dtype (e.g. it may
    # come back as categorical or object), so coerce it before negating
    mask = labels.map(unobserved).to_numpy(dtype=bool)
    return adata[~mask]
[docs]
def encode_regime(adata: AnnData, layer: str, key: str | None = None) -> None:
    r"""
    Encode intervention regime

    Parameters
    ----------
    adata
        Input dataset
    layer
        Layer in :attr:`~anndata.AnnData.layers` to write the encoded regime
    key
        Column in :attr:`~anndata.AnnData.obs` containing comma-separated
        intervention targets following :class:`Targets` format. If ``None``,
        an all-``False`` regime is written.

    Raises
    ------
    ValueError
        If any intervention target is missing from
        :attr:`~anndata.AnnData.var_names`
    """
    regime = lil_matrix(adata.shape, dtype=bool)
    # `np.concatenate` rejects an empty list, so skip the fill entirely when
    # there are no observations to encode
    if key is not None and adata.n_obs:
        targets = adata.obs[key].map(Targets)
        # One (row, col) pair per (cell, targeted variable)
        rows = np.concatenate([np.repeat(i, len(t)) for i, t in enumerate(targets)])
        cols = np.concatenate([adata.var_names.get_indexer(t) for t in targets])
        # `get_indexer` reports unknown names as -1
        if cols.size and cols.min() < 0:
            raise ValueError("Invalid intervention target")
        regime[rows, cols] = True
    if layer in adata.layers:
        logger.warning(f'Overwriting existing regime "{layer}".')
    adata.layers[layer] = regime.tocsr().astype(bool)
[docs]
def get_configuration(adata: AnnData) -> dict[str, str]:
    r"""
    Retrieve the configuration by :func:`configure_dataset`

    Parameters
    ----------
    adata
        Input dataset

    Returns
    -------
    Configuration dictionary stored under ``config.ANNDATA_KEY`` in
    :attr:`~anndata.AnnData.uns`

    Raises
    ------
    KeyError
        If ``config.ANNDATA_KEY`` is missing from ``adata.uns``
    """
    uns = adata.uns
    if config.ANNDATA_KEY in uns:
        return uns[config.ANNDATA_KEY]
    raise KeyError(
        "Dataset not configured yet, please call `configure_dataset` first."
    )
def _get_X(adata: AnnData) -> csr_matrix | np.ndarray:
    # Fetch the data matrix from the configured layer, falling back to
    # `adata.X`; sparse input is returned in CSR format.
    layer = get_configuration(adata).get("layer", None)
    if layer is not None:
        logger.debug(f"Using configured `layer`: {layer}.")
        X = adata.layers[layer]
    else:
        logger.debug("Dataset not configured with `layer`, using `adata.X`.")
        X = adata.X
    if issparse(X):
        return X.tocsr()
    return X
def _set_X(adata: AnnData, X: spmatrix | np.ndarray) -> None:
    # Store the data matrix in the configured layer, falling back to
    # `adata.X` when no layer is configured.
    layer = get_configuration(adata).get("layer", None)
    if layer is not None:
        logger.debug(f"Using configured `layer`: {layer}.")
        adata.layers[layer] = X
        return
    logger.debug("Dataset not configured with `layer`, using `adata.X`.")
    adata.X = X
def _get_regime(adata: AnnData) -> csr_matrix:
    # Fetch the configured regime layer as CSR with explicit zeros removed;
    # without a configured regime an all-False matrix is returned.
    layer = get_configuration(adata).get("regime", None)
    if layer is not None:
        logger.debug(f"Using configured `regime`: {layer}.")
        regime = adata.layers[layer].tocsr()
        regime.eliminate_zeros()
        return regime
    logger.debug("Dataset not configured with `regime`, assuming observational.")
    return csr_matrix(adata.shape, dtype=bool)
def _set_regime(adata: AnnData, regime: spmatrix) -> None:
    # Store the regime in the configured layer; a regime must be configured.
    layer = get_configuration(adata).get("regime", None)
    if layer is not None:
        logger.debug(f"Using configured `regime`: {layer}.")
        adata.layers[layer] = regime
        return
    raise ValueError("Dataset not configured with `regime`")
def _get_covariate(adata: AnnData) -> np.ndarray:
    # Fetch the configured covariate from `adata.obsm`; without one, return
    # an array with `n_obs` rows and zero columns.
    obsm_key = get_configuration(adata).get("covariate", None)
    if obsm_key is not None:
        logger.debug(f"Using configured `covariate`: {obsm_key}.")
        return np.asarray(adata.obsm[obsm_key])
    logger.debug("Dataset not configured with `covariate`, ignoring.")
    return np.empty((adata.n_obs, 0))
def _set_covariate(adata: AnnData, covariate: np.ndarray) -> None:
    # Store the covariate under the configured `adata.obsm` key; a covariate
    # must be configured.
    obsm_key = get_configuration(adata).get("covariate", None)
    if obsm_key is not None:
        logger.debug(f"Using configured `covariate`: {obsm_key}.")
        adata.obsm[obsm_key] = covariate
        return
    raise ValueError("Dataset not configured with `covariate`")
def _get_size(adata: AnnData) -> np.ndarray:
    # Fetch the configured size column from `adata.obs` as a single-column
    # 2D array; without one, return an array with zero columns.
    obs_key = get_configuration(adata).get("size", None)
    if obs_key is not None:
        logger.debug(f"Using configured `size`: {obs_key}.")
        return np.asarray(adata.obs[[obs_key]])
    logger.debug("Dataset not configured with `size`, ignoring.")
    return np.empty((adata.n_obs, 0))
def _set_size(adata: AnnData, covariate: np.ndarray) -> None:
    # Store the (flattened) size values in the configured `adata.obs`
    # column; a size column must be configured.
    obs_key = get_configuration(adata).get("size", None)
    if obs_key is not None:
        logger.debug(f"Using configured `size`: {obs_key}.")
        adata.obs[obs_key] = covariate.ravel()
        return
    raise ValueError("Dataset not configured with `size`")
def _get_weight(adata: AnnData) -> np.ndarray:
    # Fetch the configured weight column from `adata.obs`, rescaled so the
    # weights sum to the number of entries; without one, use all ones.
    obs_key = get_configuration(adata).get("weight", None)
    if obs_key is not None:
        logger.debug(f"Using configured `weight`: {obs_key}.")
        raw = np.asarray(adata.obs[obs_key])
        return (raw.size / raw.sum()) * raw
    logger.debug("Dataset not configured with `weight`, using unitary.")
    return np.ones(adata.n_obs)
def _set_weight(adata: AnnData, weight: np.ndarray) -> None:
    # Store the (flattened) weights in the configured `adata.obs` column; a
    # weight column must be configured.
    obs_key = get_configuration(adata).get("weight", None)
    if obs_key is not None:
        logger.debug(f"Using configured `weight`: {obs_key}.")
        adata.obs[obs_key] = weight.ravel()
        return
    raise ValueError("Dataset not configured with `weight`")
[docs]
def neighbor_impute(
    adata: AnnData,
    k: int,
    use_rep: str,
    use_batch: str | None = None,
    X_agg: str | None = "sum",
    obs_agg: Mapping[str, str] | None = None,
    obsm_agg: Mapping[str, str] | None = None,
    layers_agg: Mapping[str, str] | None = None,
) -> AnnData:
    r"""
    Impute data by aggregating nearest neighbors

    Parameters
    ----------
    adata
        Dataset to be imputed
    k
        Number of nearest neighbors (capped at the group size)
    use_rep
        Key in :attr:`~anndata.AnnData.obsm` containing the representation to be
        used for nearest neighbor search
    use_batch
        Key in :attr:`~anndata.AnnData.obs` used to group cells for nearest
        neighbor search (e.g., intervention label). Neighbors are only
        searched within the same group.
    X_agg
        Aggregation function for :attr:`~anndata.AnnData.X`, must be one of
        ``{"sum", "mean", None}``. Setting to ``None`` discards the
        :attr:`~anndata.AnnData.X` matrix.
    obs_agg
        Aggregation methods for :attr:`~anndata.AnnData.obs`, indexed by obs
        columns, must be one of ``{"sum", "mean"}``. Columns not specified
        are carried over unchanged.
    obsm_agg
        Aggregation methods for :attr:`~anndata.AnnData.obsm`, indexed by obsm
        keys, must be one of ``{"sum", "mean"}``. Keys not specified are
        carried over unchanged.
    layers_agg
        Aggregation methods for :attr:`~anndata.AnnData.layers`, indexed by
        layer keys, must be one of ``{"sum", "mean"}``. Layers not specified
        are carried over unchanged.

    Returns
    -------
    Imputed dataset
    """
    obs_agg = obs_agg or {}
    obsm_agg = obsm_agg or {}
    layers_agg = layers_agg or {}

    # Collect (cell, neighbor) index pairs group by group
    grouping = adata.obs.groupby(
        use_batch or np.zeros(adata.n_obs), observed=True
    )
    row_chunks, col_chunks = [], []
    for members in grouping.indices.values():
        rep = adata[members].obsm[use_rep]
        n_neighbors = min(k, members.size)
        neighbors = (
            NearestNeighbors()
            .fit(rep)
            .kneighbors(rep, n_neighbors, return_distance=False)
        )
        row_chunks.append(np.repeat(members, neighbors.shape[1]))
        # Translate within-group positions back to dataset-wide positions
        col_chunks.append(members[neighbors].ravel())
    rows = np.concatenate(row_chunks)
    cols = np.concatenate(col_chunks)

    # Cell-by-cell neighbor indicator matrix and its row-normalized version
    agg_sum = csr_matrix(
        (np.ones(rows.size), (rows, cols)), shape=(adata.n_obs, adata.n_obs)
    )
    agg_mean = agg_sum.multiply(1 / agg_sum.sum(axis=1))

    def _aggregate(weights, x):
        # Flatten trailing axes, aggregate over cells, restore trailing axes
        flat = x.reshape(x.shape[0], -1)
        return (weights @ flat).reshape(-1, *x.shape[1:])

    agg_method = {
        "sum": lambda x: _aggregate(agg_sum, x),
        "mean": lambda x: _aggregate(agg_mean, x),
    }

    if X_agg and adata.X is not None:
        X = agg_method[X_agg](adata.X)
    else:
        X = None
    obs = pd.DataFrame(
        {
            col: agg_method[obs_agg[col]](val.to_numpy()) if col in obs_agg else val
            for col, val in adata.obs.items()
        },
        index=adata.obs_names,
    )
    obsm = {
        name: agg_method[obsm_agg[name]](val) if name in obsm_agg else val
        for name, val in adata.obsm.items()
    }
    layers = {
        name: agg_method[layers_agg[name]](val) if name in layers_agg else val
        for name, val in adata.layers.items()
    }
    return AnnData(
        X=X,
        obs=obs,
        var=adata.var,
        uns=adata.uns,
        obsm=obsm,
        varm=adata.varm,
        layers=layers,
    )
[docs]
def aggregate_obs(
    adata: AnnData,
    by: str,
    X_agg: str | None = None,
    obs_agg: Mapping[str, str] | None = None,
    obsm_agg: Mapping[str, str] | None = None,
    layers_agg: Mapping[str, str] | None = None,
) -> AnnData:
    r"""
    Aggregate obs in a given dataset by certain categories

    Parameters
    ----------
    adata
        Dataset to be aggregated
    by
        Specify a column in :attr:`~anndata.AnnData.obs` used for aggregation,
        must be discrete.
    X_agg
        Aggregation function for :attr:`~anndata.AnnData.X`, must be one of
        ``{"sum", "mean", None}``. Setting to ``None`` discards the
        ``adata.X`` matrix.
    obs_agg
        Aggregation methods for :attr:`~anndata.AnnData.obs`, indexed by obs
        columns, must be one of ``{"sum", "mean", "majority"}``, where ``"sum"``
        and ``"mean"`` are for continuous data, and ``"majority"`` is for
        discrete data. Fields not specified will be discarded.
    obsm_agg
        Aggregation methods for :attr:`~anndata.AnnData.obsm`, indexed by obsm
        keys, must be one of ``{"sum", "mean"}``. Fields not specified will be
        discarded.
    layers_agg
        Aggregation methods for :attr:`~anndata.AnnData.layers`, indexed by
        layer keys, must be one of ``{"sum", "mean"}``. Fields not specified
        will be discarded.

    Returns
    -------
    Aggregated dataset, with one observation per distinct value of ``by``
    (in order of first appearance)
    """
    obs_agg = obs_agg or {}
    obsm_agg = obsm_agg or {}
    layers_agg = layers_agg or {}

    by = adata.obs[by]
    agg_idx = pd.Index(by.unique())
    # Group-by-cell indicator matrix. The shape is given explicitly because
    # it cannot be inferred from zero-sized index arrays.
    agg_sum = csr_matrix(
        (np.ones(adata.n_obs), (agg_idx.get_indexer(by), np.arange(adata.n_obs))),
        shape=(agg_idx.size, adata.n_obs),
    )
    agg_mean = agg_sum.multiply(1 / agg_sum.sum(axis=1))

    def _aggregate(weights, x):
        # Flatten trailing axes, aggregate over cells, restore trailing axes
        flat = x.reshape(x.shape[0], -1)
        return (weights @ flat).reshape(-1, *x.shape[1:])

    agg_method = {
        "sum": lambda x: _aggregate(agg_sum, x),
        "mean": lambda x: _aggregate(agg_mean, x),
        # Most frequent value within each group, reported in `agg_idx` order
        "majority": lambda x: pd.crosstab(by, x).idxmax(axis=1).loc[agg_idx].to_numpy(),
    }

    X = agg_method[X_agg](adata.X) if X_agg and adata.X is not None else None
    obs = pd.DataFrame(
        {k: agg_method[v](adata.obs[k].to_numpy()) for k, v in obs_agg.items()},
        index=agg_idx.astype(str),
    )
    obsm = {k: agg_method[v](adata.obsm[k]) for k, v in obsm_agg.items()}
    layers = {k: agg_method[v](adata.layers[k]) for k, v in layers_agg.items()}
    # Restore the original category sets of categorical columns
    for c in obs:
        if isinstance(adata.obs[c].dtype, pd.CategoricalDtype):
            obs[c] = pd.Categorical(obs[c], categories=adata.obs[c].cat.categories)
    return AnnData(
        X=X, obs=obs, var=adata.var, obsm=obsm, varm=adata.varm, layers=layers
    )
[docs]
def simple_design(
    interv: AnnData, target: AnnData, key: str, target_weight: str | None = None
) -> pd.DataFrame:
    r"""
    Perform simple intervention design by directly comparing the outcome of
    seen interventions with the target

    Parameters
    ----------
    interv
        Interventional data
    target
        Target data
    key
        Column in ``interv.obs`` containing comma-separated intervention targets
        following :class:`Targets` format
    target_weight
        Column in ``target.var`` containing variable weights for computing
        "mse" deviation with target. Uniform weights are used if ``None``.

    Returns
    -------
    Simple design with an "mse" column sorted by ascending order

    Raises
    ------
    ValueError
        If ``target.obs`` already contains ``key``
    """
    if key in target.obs:
        raise ValueError(f"`target.obs` must not contain '{key}'")  # pragma: no cover
    # Temporarily label every target cell identically so that the target
    # collapses into a single aggregated profile
    target.obs[key] = "target"
    try:
        interv_agg = aggregate_obs(interv, by=key, X_agg="mean").to_df()
        target_agg = aggregate_obs(target, by=key, X_agg="mean").to_df().iloc[0]
    finally:
        # Always remove the temporary column, even if aggregation fails
        del target.obs[key]
    if target_weight is not None:
        weight = target.var[target_weight]
        # Rescale so that the weights sum to the number of variables
        weight = weight.size * weight / weight.sum()
    else:
        weight = pd.Series(1, index=target.var_names)
    mse = interv_agg.sub(target_agg).pow(2).mul(weight).mean(axis=1)
    return pd.DataFrame({"mse": mse}).sort_values("mse", kind="stable")
[docs]
class SimpleDataset(Dataset):
    r"""
    A single interventional dataset

    Each item is a list of five tensors ``[x, r, s, l, w]`` holding the data
    row, regime row, covariate, size and weight of one observation, as
    configured on the dataset.

    Parameters
    ----------
    adata
        Interventional dataset
    """

    def __init__(self, adata: AnnData) -> None:
        self.n = adata.n_obs
        self.x = _get_X(adata)
        self.r = _get_regime(adata)
        self.s = _get_covariate(adata)
        self.l = _get_size(adata)
        self.w = _get_weight(adata)

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> list[torch.Tensor]:
        dtype = torch.get_default_dtype()
        # The data matrix may be sparse or dense; the regime is always sparse
        if issparse(self.x):
            x_row = self.x[index].toarray().squeeze(0)
        else:
            x_row = self.x[index]
        r_row = self.r[index].toarray().squeeze(0)
        fields = (x_row, r_row, self.s[index], self.l[index], self.w[index])
        return [torch.as_tensor(field, dtype=dtype) for field in fields]
[docs]
class PairedDataset(Dataset):
    r"""
    A pair of interventional datasets that are paired cell-by-cell

    Each item concatenates the tensors of the primary item with those of the
    secondary item at the same index.

    Parameters
    ----------
    pri
        Primary interventional dataset
    sec
        Secondary interventional dataset
    """

    def __init__(self, pri: SimpleDataset, sec: SimpleDataset) -> None:
        if len(pri) != len(sec):
            raise ValueError("Datasets must have the same size")
        self.pri, self.sec = pri, sec

    def __len__(self) -> int:
        return len(self.pri)

    def __getitem__(self, index: int) -> list[torch.Tensor]:
        primary = list(self.pri[index])
        secondary = list(self.sec[index])
        return primary + secondary
[docs]
class DynamicPairedDataset(Dataset):
    r"""
    A pair of interventional datasets that are not paired but fetches randomly
    paired cells on-the-fly

    For a primary cell, the partner is drawn at random from the secondary
    cells sharing the same stratification value.

    Parameters
    ----------
    pri
        Primary interventional dataset
    sec
        Secondary interventional dataset
    pri_strat
        Stratification of the primary dataset
    sec_strat
        Stratification of the secondary dataset
    random_state
        Random state
    """

    def __init__(
        self,
        pri: SimpleDataset,
        sec: SimpleDataset,
        pri_strat: np.ndarray,
        sec_strat: np.ndarray,
        random_state: RandomState,
    ) -> None:
        self.pri = pri
        self.sec = sec
        pri_levels = set(pri_strat)
        sec_levels = set(sec_strat)
        if pri_levels != sec_levels:
            raise ValueError("Primary and secondary stratifications do not match")
        # Secondary positions available for each stratification value
        candidates = {}
        for level in pri_levels:
            candidates[level] = np.where(sec_strat == level)[0]
        # Per primary position, the secondary positions it may pair with
        self.lut = {
            position: candidates[level] for position, level in enumerate(pri_strat)
        }
        self.rnd = get_random_state(random_state)

    def __len__(self) -> int:
        return len(self.pri)

    def __getitem__(self, index: int) -> list[torch.Tensor]:
        partner = self.rnd.choice(self.lut[index])
        return [*self.pri[index], *self.sec[partner]]
[docs]
class DataModule(LightningDataModule):
    r"""
    Abstract data module

    Subclasses are expected to implement ``__len__`` and ``setup``, the
    latter populating ``self.train``, ``self.val`` and ``self.predict``.

    Parameters
    ----------
    batch_size
        Batch size
    pin_memory
        Whether to use pin memory
    val_frac
        Fraction of validation data
    random_state
        Random state
    """

    def __init__(
        self,
        batch_size: int,
        pin_memory: bool,
        val_frac: float,
        random_state: RandomState,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.val_frac = val_frac
        self.rnd = get_random_state(random_state)
        # Datasets are created lazily in `setup`
        self.train = None
        self.val = None
        self.predict = None

    def __len__(self) -> int:
        raise NotImplementedError  # pragma: no cover

    @internal
    def setup(self, stage: str) -> None:
        raise NotImplementedError  # pragma: no cover

    def _make_dataloader(
        self, dataset: Dataset, *, shuffle: bool, drop_last: bool
    ) -> DataLoader:
        # Shared loader construction; only shuffling and handling of the
        # last incomplete batch differ between stages
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=config.NUM_WORKERS,
            pin_memory=self.pin_memory,
            drop_last=drop_last,
            persistent_workers=config.PERSISTENT_WORKERS,
        )

    @internal
    def train_dataloader(self):
        # Training shuffles and drops the last incomplete batch
        return self._make_dataloader(self.train, shuffle=True, drop_last=True)

    @internal
    def val_dataloader(self):
        return self._make_dataloader(self.val, shuffle=False, drop_last=False)

    @internal
    def predict_dataloader(self):
        return self._make_dataloader(self.predict, shuffle=False, drop_last=False)
[docs]
class SimpleDataModule(DataModule):
    r"""
    Simple data module using :class:`SimpleDataset`

    Parameters
    ----------
    adata
        Interventional dataset
    batch_size
        Batch size
    pin_memory
        Whether to use pin memory
    val_frac
        Fraction of validation data, must be in ``[0, 1)``
    random_state
        Random state
    """

    def __init__(
        self,
        adata: AnnData,
        batch_size: int,
        pin_memory: bool,
        val_frac: float,
        random_state: RandomState,
    ) -> None:
        super().__init__(batch_size, pin_memory, val_frac, random_state)
        self.adata = adata
        if self.val_frac == 0:
            # No validation: train on everything
            self.train_idx = np.arange(self.adata.n_obs)
            self.val_idx = np.empty(0, dtype=self.train_idx.dtype)
        elif 0 < self.val_frac < 1:
            if adata.n_obs < 1:
                raise ValueError("Dataset cannot be empty")
            if adata.n_obs == 1:
                # A single observation cannot be split; reuse it for both
                self.train_idx = self.val_idx = np.arange(adata.n_obs)
            else:
                self.train_idx, self.val_idx = train_test_split(
                    np.arange(adata.n_obs),
                    test_size=self.val_frac,
                    random_state=self.rnd,
                )
        else:
            raise ValueError("Invalid validation fraction")  # pragma: no cover

    def __len__(self) -> int:
        return self.adata.n_obs

    @internal
    def setup(self, stage: str) -> None:
        # Datasets are built once and reused on repeated calls
        if stage == "fit":
            if self.train is None:
                self.train = SimpleDataset(self.adata[self.train_idx])
            if self.val is None:
                self.val = SimpleDataset(self.adata[self.val_idx])
        elif stage == "predict":
            if self.predict is None:
                self.predict = SimpleDataset(self.adata)
[docs]
class PairedDataModule(DataModule):
    r"""
    Paired data module using :class:`PairedDataset`

    The same train/validation split is applied to both datasets so that the
    cell-by-cell pairing is preserved.

    Parameters
    ----------
    pri
        Primary interventional dataset
    sec
        Secondary interventional dataset
    batch_size
        Batch size
    pin_memory
        Whether to use pin memory
    val_frac
        Fraction of validation data, must be in ``[0, 1)``
    random_state
        Random state
    """

    def __init__(
        self,
        pri: AnnData,
        sec: AnnData,
        batch_size: int,
        pin_memory: bool,
        val_frac: float,
        random_state: RandomState,
    ) -> None:
        super().__init__(batch_size, pin_memory, val_frac, random_state)
        if pri.n_obs != sec.n_obs:
            raise ValueError("Datasets must have the same size")
        self.pri = pri
        self.sec = sec
        n_obs = self.pri.n_obs
        if self.val_frac == 0:
            # No validation: train on everything
            self.train_idx = np.arange(n_obs)
            self.val_idx = np.empty(0, dtype=self.train_idx.dtype)
        elif 0 < self.val_frac < 1:
            if n_obs < 1:
                raise ValueError("Datasets cannot be empty")
            if n_obs == 1:
                # A single pair cannot be split; reuse it for both
                self.train_idx = self.val_idx = np.arange(n_obs)
            else:
                self.train_idx, self.val_idx = train_test_split(
                    np.arange(n_obs),
                    test_size=self.val_frac,
                    random_state=self.rnd,
                )
        else:
            raise ValueError("Invalid validation fraction")  # pragma: no cover

    def __len__(self) -> int:
        return self.pri.n_obs

    def _paired_subset(self, idx: np.ndarray) -> PairedDataset:
        # Apply the same observation index to both datasets
        return PairedDataset(
            SimpleDataset(self.pri[idx]),
            SimpleDataset(self.sec[idx]),
        )

    @internal
    def setup(self, stage: str) -> None:
        # Datasets are built once and reused on repeated calls
        if stage == "fit":
            if self.train is None:
                self.train = self._paired_subset(self.train_idx)
            if self.val is None:
                self.val = self._paired_subset(self.val_idx)
        elif stage == "predict":
            if self.predict is None:
                self.predict = PairedDataset(
                    SimpleDataset(self.pri), SimpleDataset(self.sec)
                )
[docs]
class DynamicPairedDataModule(DataModule):
    r"""
    Dynamic paired data module using :class:`DynamicPairedDataset`

    The primary and secondary datasets are split into training and
    validation parts independently, each stratified by ``stratify``.

    Parameters
    ----------
    pri
        Primary interventional dataset
    sec
        Secondary interventional dataset
    stratify
        Column in :attr:`~anndata.AnnData.obs` used for stratification. If
        not given, all cells are treated as a single stratum.
    batch_size
        Batch size
    pin_memory
        Whether to use pin memory
    val_frac
        Fraction of validation data, must be in ``[0, 1)``
    random_state
        Random state
    """

    def __init__(
        self,
        pri: AnnData,
        sec: AnnData,
        stratify: str | None,
        batch_size: int,
        pin_memory: bool,
        val_frac: float,
        random_state: RandomState,
    ) -> None:
        super().__init__(batch_size, pin_memory, val_frac, random_state)
        self.pri = pri
        self.sec = sec
        if stratify:
            self.pri_strat = pri.obs[stratify].to_numpy()
            self.sec_strat = sec.obs[stratify].to_numpy()
        else:
            self.pri_strat = np.zeros(pri.n_obs)
            self.sec_strat = np.zeros(sec.n_obs)
        if self.val_frac == 0:
            # No validation: train on everything
            self.pri_train_idx = np.arange(self.pri.n_obs)
            self.sec_train_idx = np.arange(self.sec.n_obs)
            self.pri_val_idx = np.empty(0, dtype=self.pri_train_idx.dtype)
            self.sec_val_idx = np.empty(0, dtype=self.sec_train_idx.dtype)
        elif 0 < self.val_frac < 1:
            # Primary is split before secondary; both draw from `self.rnd`
            self.pri_train_idx, self.pri_val_idx = self._split(
                self.pri.n_obs, self.pri_strat, "Primary dataset cannot be empty"
            )
            self.sec_train_idx, self.sec_val_idx = self._split(
                self.sec.n_obs, self.sec_strat, "Secondary dataset cannot be empty"
            )
        else:
            raise ValueError("Invalid validation fraction")  # pragma: no cover

    def _split(
        self, n_obs: int, strat: np.ndarray, empty_message: str
    ) -> tuple[np.ndarray, np.ndarray]:
        # Stratified train/validation split of `n_obs` positions
        if n_obs < 1:
            raise ValueError(empty_message)
        if n_obs == 1:
            # A single observation cannot be split; reuse it for both
            idx = np.arange(n_obs)
            return idx, idx
        train_idx, val_idx = train_test_split(
            np.arange(n_obs),
            test_size=self.val_frac,
            random_state=self.rnd,
            stratify=strat,
        )
        return train_idx, val_idx

    def __len__(self) -> int:
        return self.pri.n_obs

    def _dynamic_subset(
        self, pri_idx: np.ndarray, sec_idx: np.ndarray
    ) -> DynamicPairedDataset:
        # Subset both datasets together with their stratification labels
        return DynamicPairedDataset(
            SimpleDataset(self.pri[pri_idx]),
            SimpleDataset(self.sec[sec_idx]),
            self.pri_strat[pri_idx],
            self.sec_strat[sec_idx],
            random_state=self.rnd,
        )

    @internal
    def setup(self, stage: str) -> None:
        # Datasets are built once and reused on repeated calls
        if stage == "fit":
            if self.train is None:
                self.train = self._dynamic_subset(
                    self.pri_train_idx, self.sec_train_idx
                )
            if self.val is None:
                self.val = self._dynamic_subset(self.pri_val_idx, self.sec_val_idx)
        elif stage == "predict":
            if self.predict is None:
                self.predict = DynamicPairedDataset(
                    SimpleDataset(self.pri),
                    SimpleDataset(self.sec),
                    self.pri_strat,
                    self.sec_strat,
                    random_state=self.rnd,
                )