cascade.nn.multi_trace

cascade.nn.multi_trace(m)

Compute the matrix trace, with support for leading multiplex dimensions.

Parameters:

m (Tensor) – Square matrices of shape (*m, n_vars, n_vars), where *m denotes any number of leading multiplex dimensions

Return type:

Tensor

Returns:

Matrix traces of shape (*m,), one per (n_vars, n_vars) matrix in m
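For illustration, the sketch below shows one way to compute a trace over the last two dimensions while keeping any leading multiplex dimensions intact. It assumes m is a PyTorch tensor; the name multi_trace_sketch is hypothetical, and this is not necessarily how cascade implements multi_trace. For every leading index b, it computes tr(m)[b] = sum over i of m[b, i, i].

```python
import torch


def multi_trace_sketch(m: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: trace over the last two dims of m.

    m has shape (*m, n_vars, n_vars); the result has shape (*m,).
    """
    if m.dim() < 2 or m.shape[-1] != m.shape[-2]:
        raise ValueError("expected shape (*m, n_vars, n_vars)")
    # Extract the diagonal of every trailing (n_vars, n_vars) matrix,
    # giving shape (*m, n_vars), then sum it to get one trace per matrix.
    return torch.diagonal(m, dim1=-2, dim2=-1).sum(dim=-1)


# Usage: two multiplex dims (2, 3) over 4 x 4 matrices
x = torch.randn(2, 3, 4, 4)
tr = multi_trace_sketch(x)
print(tr.shape)  # torch.Size([2, 3])
```

Using torch.diagonal with dim1=-2 and dim2=-1 means the function works for any number of leading dimensions, including none, without explicit loops or reshaping.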