Source code for pyro.distributions.transforms.unit_cholesky

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

from pyro.distributions.constraints import unit_lower_cholesky

[docs]class UnitLowerCholeskyTransform(Transform): """ Transform from unconstrained matrices to lower-triangular matrices with all ones diagonals. """ domain = constraints.independent(constraints.real, 2) codomain = unit_lower_cholesky def __eq__(self, other): return isinstance(other, UnitLowerCholeskyTransform) def _call(self, x): return x.tril(-1) + torch.eye(x.size(-1), device=x.device, dtype=x.dtype) def _inverse(self, y): return y
__all__ = [ "UnitLowerCholeskyTransform", ]