cascade.model.CASCADE.tune

CASCADE.tune(adata, tune_ctfact=False, stratify=None, opt='AdamW', lr=5e-3, weight_decay=0.01, accumulate_grad_batches=1, log_adj=LogAdj.mean, batch_size=128, val_check_interval=300, val_frac=0.1, max_epochs=1000, n_devices=1, log_subdir='tune', verbose=False, **kwargs)

Fine-tune the structural equations while keeping the causal structure fixed.
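To illustrate what "fixed causal structure" means during fine-tuning, here is a minimal toy sketch. It assumes *linear* structural equations and plain gradient descent, whereas CASCADE's own structural equations and optimizer (AdamW by default) differ. The adjacency matrix `adj` acts as a mask, so only the weights of existing edges are updated and non-edges stay exactly zero. The function `fit_masked_linear_sem` is hypothetical and not part of the CASCADE API.

```python
import numpy as np


def fit_masked_linear_sem(X, adj, lr=0.05, n_steps=500, seed=0):
    """Hypothetical toy: fit X[:, j] ~ X @ (adj * W)[:, j] with adj held fixed.

    Only edge weights W are learned; entries where adj == 0 never change.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    mask = adj.astype(float)
    W = rng.normal(scale=0.1, size=(d, d)) * mask  # non-edges start at zero
    losses = []
    for _ in range(n_steps):
        resid = X @ W - X  # residual of every variable given its parents
        losses.append(float(np.mean(resid ** 2)))
        grad = 2.0 / (n * d) * X.T @ resid  # gradient of the mean squared error
        W -= lr * grad * mask  # masked update keeps the causal structure fixed
    return W, losses


# Toy data from the chain 0 -> 1 -> 2 with coefficients 2.0 and -1.0
rng = np.random.default_rng(42)
x0 = rng.normal(size=500)
x1 = 2.0 * x0 + 0.1 * rng.normal(size=500)
x2 = -1.0 * x1 + 0.1 * rng.normal(size=500)
X = np.column_stack([x0, x1, x2])
adj = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # adj[i, j] = 1 means i -> j

W, losses = fit_masked_linear_sem(X, adj)
print(np.round(W, 2))
```

The learned `W[0, 1]` and `W[1, 2]` approach the true coefficients, while every entry outside the graph remains zero.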

Parameters:
  • adata (AnnData) – Input dataset

  • tune_ctfact (bool) – Whether to tune in counterfactual mode, i.e., to use randomly paired samples for counterfactual pairs for tuning.

  • stratify (str | None) – Column name in adata.obs used for stratified random pairing (only relevant when tune_ctfact=True)

  • opt (str) – Name of the optimizer

  • lr (float) – Learning rate

  • weight_decay (float) – Weight decay

  • accumulate_grad_batches (int) – Number of batches to accumulate before optimizer step

  • log_adj (LogAdj) – Adjacency matrix logging mode (see LogAdj)

  • batch_size (int) – Batch size

  • val_check_interval (int) – Validation check interval

  • val_frac (float) – Fraction of the data held out for validation

  • max_epochs (int) – Maximum number of epochs

  • n_devices (int) – Number of GPU devices to use

  • log_subdir (PathLike) – TensorBoard log subdirectory (relative to the model's log_dir)

  • verbose (bool) – Whether to print verbose logs

  • **kwargs – Additional keyword arguments are passed to Trainer
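The tune_ctfact and stratify options pair each sample with a random partner, optionally within the same group. The sketch below shows one way such stratified random pairing could work. It is an assumption for illustration, not CASCADE's actual implementation. In practice the labels would presumably come from adata.obs[stratify]; here a plain array stands in so the example runs without AnnData. The function `stratified_random_pairing` is hypothetical.

```python
import numpy as np


def stratified_random_pairing(n_obs, labels=None, seed=0):
    """Hypothetical sketch: return `partner` so that sample i is paired with
    sample partner[i], drawn at random from the same stratum as i.

    With labels=None, all samples form one stratum (unstratified pairing).
    A sample may occasionally be paired with itself.
    """
    labels = np.zeros(n_obs, dtype=int) if labels is None else np.asarray(labels)
    rng = np.random.default_rng(seed)
    partner = np.empty(n_obs, dtype=int)
    for stratum in np.unique(labels):
        idx = np.flatnonzero(labels == stratum)  # samples in this stratum
        partner[idx] = rng.permutation(idx)  # shuffle within the stratum only
    return partner


cell_type = np.array(["B", "B", "T", "T", "T", "NK"])
partner = stratified_random_pairing(len(cell_type), cell_type, seed=1)
print(partner)
print(cell_type[partner])  # every partner shares the cell type of its sample

partner_all = stratified_random_pairing(len(cell_type), seed=1)  # no stratification
```

Shuffling within each stratum guarantees that each sample is used exactly once as a partner and that pairs never cross group boundaries.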

Return type:

CausalNetwork