cascade.nn.IntervDesign

class cascade.nn.IntervDesign(n_vars, k, design_scale_bias, mask, interv_scale, interv_bias, target_weight)

Bases: Module

Intervention design module

Parameters:
  • n_vars (int) – Number of variables

  • k (int) – Maximum combinatorial order to consider, i.e., the largest number of variables intervened on jointly

  • design_scale_bias (bool) – Whether to optimize the intervention scale and bias

  • mask (BoolTensor) – Boolean mask that marks variables in the design candidate pool

  • interv_scale (Tensor) – Intervention scale tensor trained in the discovery phase

  • interv_bias (Tensor) – Intervention bias tensor trained in the discovery phase

  • target_weight (Tensor) – Per-variable weights used when computing the deviation from the target
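To make the roles of mask and k concrete, the sketch below enumerates the intervention combinations that a candidate pool and a maximum order would admit. It is an illustration only, assuming that orders 1 through k are considered and that each combination is an unordered set of distinct candidate variables. The helper candidate_combinations is hypothetical and not part of the cascade API, and a plain list of booleans stands in for the BoolTensor mask.

```python
from itertools import combinations
from math import comb


def candidate_combinations(mask, k):
    """Enumerate intervention combinations of order 1..k over masked variables.

    Hypothetical helper (not part of cascade): it only illustrates how a
    candidate mask and a maximum combinatorial order bound the design space.
    """
    # Indices of variables that the mask admits to the candidate pool
    pool = [i for i, flag in enumerate(mask) if flag]
    combos = []
    for order in range(1, k + 1):
        # All unordered sets of `order` distinct candidate variables
        combos.extend(combinations(pool, order))
    return combos


# Toy example: 5 variables, 3 of which are candidates, up to pairwise combinations
mask = [True, False, True, True, False]
combos = candidate_combinations(mask, k=2)
print(combos)
# [(0,), (2,), (3,), (0, 2), (0, 3), (2, 3)]

# With m candidates, the count is the sum of C(m, i) for i = 1..k
print(len(combos) == sum(comb(3, i) for i in range(1, 3)))  # True
```

The count grows quickly with k, which is one reason a small maximum order and a restricted candidate pool keep the design search tractable.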

Methods

load

Load a design module from a file

loss

rsample

save

Save the design module to a file

simplex2regime

Attributes

bias

comb_lists

scale