# Source code for pyro.distributions.transforms.cholesky

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings

import torch
from torch.distributions.transforms import CorrCholeskyTransform, Transform

from .. import constraints


class CorrLCholeskyTransform(CorrCholeskyTransform):  # DEPRECATED
    """
    Deprecated alias of ``CorrCholeskyTransform``.

    Constructing an instance emits a :class:`FutureWarning` and then defers
    entirely to the parent class constructor.

    :param int cache_size: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, cache_size=0):
        warnings.warn(
            "class CorrLCholeskyTransform is deprecated in favor of CorrCholeskyTransform.",
            FutureWarning,
            # Attribute the warning to the code that instantiates the class,
            # not to this constructor, so users can locate the usage to fix.
            stacklevel=2,
        )
        super().__init__(cache_size=cache_size)


class CholeskyTransform(Transform):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    positive definite matrix.

    The forward map is computed with :func:`torch.linalg.cholesky`; the
    inverse reconstructs :math:`x = y y^\top`.
    """

    bijective = True
    domain = constraints.positive_definite
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        # Equality depends only on the type of ``other``, not on instance state.
        return isinstance(other, CholeskyTransform)

    def _call(self, x):
        return torch.linalg.cholesky(x)

    def _inverse(self, y):
        # Multiply the factor by its transpose over the last two dimensions.
        return torch.matmul(y, torch.transpose(y, -2, -1))

    def log_abs_det_jacobian(self, x, y):
        r"""
        Compute the log absolute determinant of the Jacobian of the forward map.

        With ``n = x.shape[-1]`` the returned value is

        .. math::
            -n \log 2 - \sum_{i=1}^{n} (n - i + 1) \log y_{ii}

        :param torch.Tensor x: Input of the forward map; only its last
            dimension size, dtype and device are used.
        :param torch.Tensor y: Output of the forward map; its diagonal over
            the last two dimensions is used.
        :return: Tensor with the last two dimensions of ``y`` reduced away.
        :rtype: torch.Tensor
        """
        # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
        n = x.shape[-1]
        # Exponents n, n-1, ..., 1 applied to the diagonal entries in order.
        order = torch.arange(n, 0, -1, dtype=x.dtype, device=x.device)
        return -n * math.log(2) - (
            order * torch.diagonal(y, dim1=-2, dim2=-1).log()
        ).sum(-1)
class CorrMatrixCholeskyTransform(CholeskyTransform):
    r"""
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    correlation matrix.

    The forward and inverse maps are inherited from
    :class:`CholeskyTransform`; only the constraints, equality check and
    Jacobian term are overridden here.
    """

    bijective = True
    domain = constraints.corr_matrix
    # TODO: change corr_cholesky_constraint to corr_cholesky when the latter is available
    codomain = constraints.corr_cholesky_constraint

    def __eq__(self, other):
        # Equality depends only on the type of ``other``, not on instance state.
        return isinstance(other, CorrMatrixCholeskyTransform)

    def log_abs_det_jacobian(self, x, y):
        r"""
        Compute the log absolute determinant of the Jacobian of the forward map.

        With ``n = x.shape[-1]`` the returned value is

        .. math::
            -\sum_{i=1}^{n} (n - i) \log y_{ii}

        :param torch.Tensor x: Input of the forward map; only its last
            dimension size, dtype and device are used.
        :param torch.Tensor y: Output of the forward map; its diagonal over
            the last two dimensions is used.
        :return: Tensor with the last two dimensions of ``y`` reduced away.
        :rtype: torch.Tensor
        """
        # NB: see derivation in LKJCholesky implementation
        n = x.shape[-1]
        # Exponents n-1, n-2, ..., 0 applied to the diagonal entries in order.
        order = torch.arange(n - 1, -1, -1, dtype=x.dtype, device=x.device)
        return -(order * torch.diagonal(y, dim1=-2, dim2=-1).log()).sum(-1)