Source code for pyro.distributions.transforms.basic
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import math

import torch
import torch.nn.functional as F
from torch.distributions.transforms import TanhTransform, Transform

from .. import constraints
# TODO: Move upstream
class ELUTransform(Transform):
    r"""
    Bijective transform via the mapping :math:`y = \text{ELU}(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ELUTransform)

    def _call(self, x):
        return F.elu(x)

    def _inverse(self, y, eps=1e-8):
        # ELU is the identity for x >= 0 and exp(x) - 1 for x < 0, so its inverse
        # is y for y >= 0 and log1p(y) for y < 0; eps guards log1p near y = -1.
        return torch.max(y, torch.zeros_like(y)) + torch.min(
            torch.log1p(y + eps), torch.zeros_like(y)
        )

    def log_abs_det_jacobian(self, x, y):
        # d/dx ELU(x) is 1 for x >= 0 and exp(x) for x < 0, hence
        # log|det J| = min(x, 0) = -relu(-x).
        return -F.relu(-x)
def elu():
    """
    A helper function to create an
    :class:`~pyro.distributions.transforms.ELUTransform` object for consistency
    with other helpers.
    """
    return ELUTransform()
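# A minimal usage sketch (illustrative, not part of this module): ELUTransform maps
# the real line onto the positive reals, so composing it with a Normal base via
# torch.distributions.TransformedDistribution yields a positively supported
# distribution. The names ``base`` and ``flow`` below are assumptions.
#
#     >>> import torch
#     >>> from torch.distributions import Normal, TransformedDistribution
#     >>> base = Normal(torch.zeros(3), torch.ones(3))
#     >>> flow = TransformedDistribution(base, [elu()])
#     >>> x = flow.sample()   # every entry of x is positive
#     >>> flow.log_prob(x)    # uses ELUTransform's inverse and log-det-Jacobian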
# TODO: Move upstream
class LeakyReLUTransform(Transform):
    r"""
    Bijective transform via the mapping :math:`y = \text{LeakyReLU}(x)`.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, LeakyReLUTransform)

    def _call(self, x):
        return F.leaky_relu(x)

    def _inverse(self, y):
        # The forward pass uses the default negative_slope of 0.01, so the inverse
        # applies the reciprocal slope 1 / 0.01 = 100.0 on the negative half-line.
        return F.leaky_relu(y, negative_slope=100.0)

    def log_abs_det_jacobian(self, x, y):
        # The slope is 1 for x >= 0 and 0.01 for x < 0.
        return torch.where(
            x >= 0.0, torch.zeros_like(x), torch.ones_like(x) * math.log(0.01)
        )
def leaky_relu():
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.LeakyReLUTransform` object for
    consistency with other helpers.
    """
    return LeakyReLUTransform()
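# A minimal usage sketch (illustrative, not part of this module): checking the
# round trip and the log-det-Jacobian of LeakyReLUTransform. The names ``t`` and
# ``x`` are assumptions.
#
#     >>> import torch
#     >>> t = leaky_relu()
#     >>> x = torch.randn(5)
#     >>> torch.allclose(t.inv(t(x)), x)        # slope 100.0 undoes the 0.01 slope
#     >>> t.log_abs_det_jacobian(x, t(x))       # 0 where x >= 0, log(0.01) elsewhere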
def tanh():
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.TanhTransform` object for consistency
    with other helpers.
    """
    return TanhTransform()
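# A minimal usage sketch (illustrative, not part of this module): tanh() returns
# torch.distributions.transforms.TanhTransform, which maps the real line onto
# (-1, 1). The names ``base`` and ``d`` below are assumptions.
#
#     >>> import torch
#     >>> from torch.distributions import Normal, TransformedDistribution
#     >>> base = Normal(0.0, 1.0)
#     >>> d = TransformedDistribution(base, [tanh()])
#     >>> y = d.sample()      # y lies in (-1, 1)
#     >>> d.log_prob(y)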