cascade.model.CASCADE.tune
- CASCADE.tune(adata, tune_ctfact=False, stratify=None, opt='AdamW', lr=5e-3, weight_decay=0.01, accumulate_grad_batches=1, log_adj=LogAdj.mean, batch_size=128, val_check_interval=300, val_frac=0.1, max_epochs=1000, n_devices=1, log_subdir='tune', verbose=False, **kwargs)
Fine-tune the structural equations while keeping the causal structure fixed.
- Parameters:
adata (AnnData) – Input dataset.
tune_ctfact (bool) – Whether to tune in counterfactual mode, i.e., to use randomly paired samples as counterfactual pairs during tuning.
stratify (str | None) – Column name in obs used for stratified random pairing (only relevant when tune_ctfact=True).
opt (str) – Optimizer.
lr (float) – Learning rate.
weight_decay (float) – Weight decay.
accumulate_grad_batches (int) – Number of batches to accumulate before each optimizer step.
log_adj (LogAdj) – Adjacency matrix logging mode (see LogAdj).
batch_size (int) – Batch size.
val_check_interval (int) – Validation check interval.
val_frac (float) – Fraction of the data used for validation.
max_epochs (int) – Maximum number of epochs.
n_devices (int) – Number of GPU devices to use.
log_subdir (PathLike) – TensorBoard log subdirectory (under the model-wise log_dir).
verbose (bool) – Whether to print verbose logs.
**kwargs – Additional keyword arguments are passed to Trainer.
- Return type:
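To illustrate what "stratified random pairing" (tune_ctfact=True with stratify set) can mean, here is a minimal sketch under stated assumptions. It is not CASCADE's implementation: the function random_pairing and its arguments are hypothetical. Each cell receives a randomly chosen partner from the same stratum (for example, the same value of a hypothetical obs column "batch"); with no labels, all cells form a single stratum.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Optional, Sequence


def random_pairing(
    n_cells: int,
    labels: Optional[Sequence[Hashable]] = None,
    seed: int = 0,
) -> List[int]:
    """Hypothetical sketch: return, for each cell i, the index of a random partner.

    Partners are drawn by shuffling cell indices separately within each stratum,
    so every partner shares its cell's label. This is an illustration only,
    not the pairing code used by CASCADE.
    """
    rng = random.Random(seed)
    if labels is None:
        labels = [None] * n_cells  # no stratification: one stratum with every cell
    if len(labels) != n_cells:
        raise ValueError("labels must have exactly one entry per cell")

    # Group cell indices by their stratum label.
    groups = defaultdict(list)
    for i, label in enumerate(labels):
        groups[label].append(i)

    partner = [0] * n_cells
    for members in groups.values():
        shuffled = members[:]  # copy, then permute within the stratum
        rng.shuffle(shuffled)
        for i, j in zip(members, shuffled):
            partner[i] = j
    return partner
```

With an AnnData object, a call might look like random_pairing(adata.n_obs, adata.obs["batch"].tolist()), where "batch" is an assumed column name. Because this sketch uses a plain permutation, a cell can occasionally be paired with itself. The documentation above does not say whether CASCADE excludes self-pairs.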