cascade.model.CASCADE.discover
- CASCADE.discover(adata, lam=0.1, alpha=0.5, gamma=1.0, cyc_tol=1e-4, prefit=False, opt='AdamW', lr=5e-3, weight_decay=0.01, accumulate_grad_batches=1, log_adj=LogAdj.mean, batch_size=128, val_check_interval=300, val_frac=0.1, max_epochs=1000, n_devices=1, log_subdir='discover', verbose=False, **kwargs)
Causal discovery
- Parameters:
  - adata (AnnData) – Input dataset
  - lam (float) – Sparse penalty rate (\(\eta_\lambda\) in paper)
  - alpha (float) – Acyclic penalty rate (\(\eta_\alpha\) in paper)
  - gamma (float) – Kernel gradient rate (\(\eta_\gamma\))
  - cyc_tol (float) – Acyclic violation tolerance
  - prefit (bool) – Whether to prefit the model on covariates
  - opt (str) – Name of the optimizer (default 'AdamW')
  - lr (float) – Learning rate
  - weight_decay (float) – Weight decay
  - accumulate_grad_batches (int) – Number of batches to accumulate before each optimizer step
  - log_adj (LogAdj) – Adjacency matrix logging mode (see LogAdj)
  - batch_size (int) – Batch size
  - val_check_interval (int) – Validation check interval
  - val_frac (float) – Fraction of the data held out for validation
  - max_epochs (int) – Maximum number of epochs
  - n_devices (int) – Number of GPU devices to use
  - log_subdir (PathLike) – TensorBoard log subdirectory (under the model-wise log_dir)
  - verbose (bool) – Whether to print verbose logs
  - **kwargs – Additional keyword arguments passed to Trainer
- Return type:
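The parameters above do not spell out how batch_size, accumulate_grad_batches, n_devices and val_frac combine during training. Because extra keyword arguments are forwarded to Trainer, the sketch below assumes PyTorch Lightning semantics: each device draws batch_size samples from its own shard of the training split, gradients are accumulated over accumulate_grad_batches batches, and an integer val_check_interval is counted in training batches. The helper name training_schedule and the rounding of the validation split are hypothetical, not part of the CASCADE API.

```python
import math


def training_schedule(n_obs, val_frac=0.1, batch_size=128,
                      accumulate_grad_batches=1, n_devices=1):
    """Hypothetical helper: per-epoch batch arithmetic for discover().

    Assumes PyTorch Lightning semantics (evenly padded per-device shards,
    last partial batch kept); not part of the CASCADE API.
    """
    # Assumption: the validation split size is rounded to the nearest integer.
    n_val = round(n_obs * val_frac)
    n_train = n_obs - n_val
    # Each device sees an (evenly padded) shard of the training set.
    per_device = math.ceil(n_train / n_devices)
    # Training batches per device per epoch; under Lightning semantics an
    # integer val_check_interval is counted in these batches.
    batches_per_epoch = math.ceil(per_device / batch_size)
    # One optimizer step per accumulate_grad_batches batches; a trailing
    # partial accumulation is assumed to still trigger a step.
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    # Samples whose gradients contribute to a single optimizer step.
    effective_batch_size = batch_size * accumulate_grad_batches * n_devices
    return {
        "n_train": n_train,
        "n_val": n_val,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_steps_per_epoch": steps_per_epoch,
        "effective_batch_size": effective_batch_size,
    }


# 10,000 observations with the default arguments of discover()
print(training_schedule(10_000))
```

With the defaults and 10,000 observations this gives 71 training batches per epoch, so under these assumptions the default val_check_interval=300 would trigger validation roughly every four epochs.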
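The docstring does not state how acyclicity is measured or how cyc_tol is applied. As an illustration only, the sketch below uses the polynomial acyclicity function from DAG-GNN, h(A) = tr[(I + A∘A/d)^d] - d, which is zero exactly when the weighted adjacency matrix A has no directed cycles, and treats any value up to cyc_tol as acyclic. Whether CASCADE uses this function, the trace-exponential measure of NOTEARS, or something else is an assumption, as are the helper names acyclicity and is_acyclic.

```python
import numpy as np


def acyclicity(adj: np.ndarray) -> float:
    """Polynomial acyclicity measure h(A) = tr[(I + A*A/d)^d] - d (DAG-GNN).

    A*A is the elementwise square, so h(A) >= 0, and h(A) == 0 exactly
    when the graph with weighted adjacency A has no directed cycles.
    """
    d = adj.shape[0]
    m = np.eye(d) + (adj * adj) / d
    return float(np.trace(np.linalg.matrix_power(m, d)) - d)


def is_acyclic(adj: np.ndarray, cyc_tol: float = 1e-4) -> bool:
    """Hypothetical check: treat violations up to cyc_tol as acyclic."""
    return acyclicity(adj) <= cyc_tol


chain = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [0.0, 0.0, 0.0]])  # 0 -> 1 -> 2, no cycle
loop = np.array([[0.0, 1.0],
                 [1.0, 0.0]])        # 0 <-> 1, a 2-cycle

print(acyclicity(chain), is_acyclic(chain))  # 0.0 True
print(acyclicity(loop), is_acyclic(loop))    # 0.5 False
```

A tolerance is useful here because a continuously optimized adjacency matrix rarely drives h(A) to exactly zero; small residual cycle weight is accepted instead.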