cascade.nn.gumbel_sigmoid

cascade.nn.gumbel_sigmoid(x, tau=1.0)

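Straight-through Gumbel-sigmoid sampler. The forward pass returns hard samples, while gradients are taken through the relaxed (soft) sigmoid.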

Parameters:
  • x (Tensor) – Logit tensor

  • tau (float) – Temperature of the sigmoid relaxation; lower values make the relaxed samples closer to binary. Defaults to 1.0.

Return type:

Tensor

Returns:

Hard reparameterized samples
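The following is a minimal sketch of what a straight-through Gumbel-sigmoid sampler typically computes. It assumes PyTorch tensors, and `gumbel_sigmoid_sketch` is a hypothetical name; this is not the `cascade.nn` source and may differ from it (for example in how the noise is clamped or how ties at 0.5 are broken). Logistic noise is added to the logits, a temperature-scaled sigmoid gives the relaxed sample, and the straight-through trick returns the thresholded value while keeping the relaxed sample's gradient.

```python
import torch


def gumbel_sigmoid_sketch(x: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not the cascade.nn source."""
    # Logistic(0, 1) noise (equivalently, the difference of two Gumbel(0, 1)
    # samples), drawn by inverse CDF: logit(u) for u ~ Uniform(0, 1).
    u = torch.rand_like(x).clamp(1e-6, 1.0 - 1e-6)
    noise = torch.log(u) - torch.log1p(-u)
    # Relaxed sample in (0, 1); a smaller tau pushes it toward 0 or 1.
    soft = torch.sigmoid((x + noise) / tau)
    # Hard 0/1 sample used as the forward value.
    hard = (soft > 0.5).to(x.dtype)
    # Straight-through estimator: the value equals `hard` (the added term is
    # exactly zero), while the gradient w.r.t. x is that of `soft`.
    return hard + (soft - soft.detach())


torch.manual_seed(0)
logits = torch.zeros(8, requires_grad=True)
samples = gumbel_sigmoid_sketch(logits, tau=0.5)
samples.sum().backward()
print(samples)      # entries are 0. or 1.
print(logits.grad)  # gradients flow through the soft relaxation
```

Writing the output as `hard + (soft - soft.detach())` rather than `hard - soft.detach() + soft` keeps the forward value exactly 0 or 1, because the bracketed term is an exact zero in floating point.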