cascade.core.CausalNetwork.predict_step

CausalNetwork.predict_step(batch, batch_idx)

Prediction step for a minibatch.

Return type:

tuple[Tensor, Tensor, Tensor] | Tensor
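Because the return type is a union, code that gathers outputs across many minibatches has to handle both forms. The sketch below is a hypothetical helper, `collate_predictions`, and is not part of the cascade API. It concatenates per-minibatch outputs along the first (batch) dimension. It assumes that every minibatch returns the same form and that dimension 0 is the batch axis. It makes no assumption about what the three tuple elements mean.

```python
from __future__ import annotations

import torch
from torch import Tensor


def collate_predictions(
    outputs: list[tuple[Tensor, Tensor, Tensor] | Tensor],
) -> tuple[Tensor, ...] | Tensor:
    """Concatenate per-minibatch predict_step outputs along dim 0.

    Hypothetical helper (not part of cascade). Assumes every element of
    `outputs` has the same form and that dim 0 is the batch dimension.
    """
    if not outputs:
        raise ValueError("no minibatch outputs to collate")
    if isinstance(outputs[0], Tensor):
        # Single-tensor form: join all minibatches along the batch axis.
        return torch.cat(outputs, dim=0)
    # Tuple form: concatenate each tuple position across minibatches separately.
    return tuple(torch.cat(parts, dim=0) for parts in zip(*outputs))
```

The `predict_step(batch, batch_idx)` signature suggests that `CausalNetwork` follows the PyTorch Lightning `LightningModule` convention, but this is an assumption. If it holds, the list passed to this helper could be the per-batch result of `trainer.predict(model, dataloaders=loader)`, where `trainer`, `model`, and `loader` are placeholder names. Dispatching on the first element keeps the helper agnostic to which form a particular configuration returns.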