# Source code for pyro.distributions.transforms.softplus

# Copyright Contributors to the Pyro project.

from torch.distributions import constraints
from torch.distributions.transforms import Transform
from torch.nn.functional import softplus

def softplus_inv(y):
    r"""
    Inverse of softplus, :math:`\log(\exp(y) - 1)`, applied elementwise.

    Evaluated as :math:`y + \log(1 - \exp(-y))`. The factor
    :math:`1 - \exp(-y)` is formed with ``expm1`` so it stays accurate for
    ``y`` near zero, and ``exp(y)`` itself is never computed, so large ``y``
    does not overflow. ``y == 0`` maps to ``-inf``; ``y`` is expected to be
    a positive tensor (negative values yield NaN).
    """
    one_minus_exp_neg_y = -(-y).expm1()
    return y + one_minus_exp_neg_y.log()

# Backport of https://github.com/pytorch/pytorch/pull/52300
class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.

    Maps the real line (``domain``) onto the positive reals (``codomain``).
    The inverse is :math:`\log(\exp(y) - 1)`, computed by ``softplus_inv``.
    The derivative is the logistic sigmoid :math:`\sigma(x)`, so the log
    absolute Jacobian determinant is
    :math:`\log \sigma(x) = -\text{Softplus}(-x)`.
    """

    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    # Softplus is strictly increasing: its derivative sigmoid(x) is > 0.
    sign = +1

    def __eq__(self, other):
        # Only the type is compared, so any two instances (including subclass
        # instances) compare equal regardless of per-instance settings such
        # as caching.
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        """Apply softplus elementwise."""
        return softplus(x)

    def _inverse(self, y):
        """Apply the inverse softplus elementwise; ``y`` is expected to be positive."""
        return softplus_inv(y)

    def log_abs_det_jacobian(self, x, y):
        """
        Return ``log |d softplus(x) / dx|`` elementwise.

        ``y`` is accepted for interface compatibility but is not used.
        """
        # log(sigmoid(x)) is written as -softplus(-x). Computing sigmoid(x)
        # first would underflow to 0 (and its log to -inf) for very negative x.
        return -softplus(-x)

class SoftplusLowerCholeskyTransform(Transform):
    """
    Map unconstrained square matrices to lower-triangular matrices whose
    diagonal has been passed through softplus. The diagonal is therefore
    positive, up to floating-point underflow for very negative inputs. This
    is useful for parameterizing positive definite matrices through their
    Cholesky factor.

    The strictly lower triangle is kept unchanged and the strictly upper
    triangle is discarded. As a result, ``_inverse`` can only recover a
    matrix whose upper triangle is zero. The last two dimensions are the
    matrix dimensions; any leading dimensions are carried through as batch
    dimensions.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        # Equality is by type only.
        return isinstance(other, SoftplusLowerCholeskyTransform)

    @staticmethod
    def _map_diagonal(matrix, fn):
        # Rebuild ``matrix`` from its strictly lower triangle plus ``fn``
        # applied to its diagonal; the strictly upper triangle is dropped.
        new_diagonal = fn(matrix.diagonal(dim1=-2, dim2=-1))
        return matrix.tril(-1) + new_diagonal.diag_embed()

    def _call(self, x):
        return self._map_diagonal(x, softplus)

    def _inverse(self, y):
        return self._map_diagonal(y, softplus_inv)

# Names exported by ``from ... import *``; ``softplus_inv`` is not listed.
__all__ = ["SoftplusTransform", "SoftplusLowerCholeskyTransform"]