cascade.model.CASCADE.discover

CASCADE.discover(adata, lam=0.1, alpha=0.5, gamma=1.0, cyc_tol=1e-4, prefit=False, opt='AdamW', lr=5e-3, weight_decay=0.01, accumulate_grad_batches=1, log_adj=LogAdj.mean, batch_size=128, val_check_interval=300, val_frac=0.1, max_epochs=1000, n_devices=1, log_subdir='discover', verbose=False, **kwargs)

Perform causal discovery on the input dataset.

Parameters:
  • adata (AnnData) – Input dataset

  • lam (float) – Sparse penalty rate (\(\eta_\lambda\) in paper)

  • alpha (float) – Acyclic penalty rate (\(\eta_\alpha\) in paper)

  • gamma (float) – Kernel gradient rate (\(\eta_\gamma\) in paper)

  • cyc_tol (float) – Tolerance for acyclicity violation

  • prefit (bool) – Whether to prefit the model on covariates

  • opt (str) – Optimizer name

  • lr (float) – Learning rate

  • weight_decay (float) – Weight decay

  • accumulate_grad_batches (int) – Number of batches over which to accumulate gradients before each optimizer step

  • log_adj (LogAdj) – Adjacency matrix logging mode (see LogAdj)

  • batch_size (int) – Batch size

  • val_check_interval (int) – Validation check interval

  • val_frac (float) – Fraction of data held out for validation

  • max_epochs (int) – Maximum number of epochs

  • n_devices (int) – Number of GPU devices to use

  • log_subdir (PathLike) – TensorBoard log subdirectory (under the model's log_dir)

  • verbose (bool) – Whether to print verbose logs

  • **kwargs – Additional keyword arguments are passed to Trainer

Return type:

None
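This page describes `lam` and `alpha` as penalty rates for sparsity and acyclicity, and `cyc_tol` as a tolerance on acyclicity violation, but it does not define the underlying penalty functions. The sketch below is an illustration under stated assumptions, not CASCADE's implementation: it assumes an L1 norm as the sparsity measure and the polynomial acyclicity measure h(W) = tr[(I + W ∘ W / d)^d] - d (a common variant of the NOTEARS constraint), and it treats `cyc_tol` as a threshold on h(W). The helpers `sparsity`, `acyclicity`, and `within_cyc_tol` are hypothetical and are not part of the cascade API.

```python
import numpy as np


def sparsity(adj: np.ndarray) -> float:
    """L1 norm of the adjacency matrix (assumed sparsity measure)."""
    return float(np.abs(adj).sum())


def acyclicity(adj: np.ndarray) -> float:
    """Polynomial acyclicity measure h(W) = tr[(I + W * W / d)^d] - d.

    Returns 0 for a DAG and a positive value if any directed cycle exists.
    """
    d = adj.shape[0]
    # Element-wise square makes every entry non-negative.
    m = np.eye(d) + (adj * adj) / d
    return float(np.trace(np.linalg.matrix_power(m, d)) - d)


def within_cyc_tol(adj: np.ndarray, cyc_tol: float = 1e-4) -> bool:
    """Hypothetical check: treat the graph as acyclic if h(W) <= cyc_tol."""
    return acyclicity(adj) <= cyc_tol


# Example: a 3-gene chain (a DAG) versus the same chain with a feedback edge.
dag = np.array([[0.0, 0.8, 0.0],
                [0.0, 0.0, 0.5],
                [0.0, 0.0, 0.0]])
cyclic = dag.copy()
cyclic[2, 0] = 0.3  # adds the cycle 0 -> 1 -> 2 -> 0

print(sparsity(dag), within_cyc_tol(dag), within_cyc_tol(cyclic))
```

The measure is zero exactly when the graph has no directed cycle: tr[(W ∘ W)^k] is positive only if a cycle of length k exists, and any cycle in a d-node graph has length at most d. In the example, the feedback edge gives h(W) = 0.0016, which exceeds the default tolerance of 1e-4.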