# Source code for pyro.distributions.transforms.basic

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
import torch.nn.functional as F
from torch.distributions.transforms import TanhTransform, Transform

from .. import constraints

# TODO: Move upstream


class ELUTransform(Transform):
    r"""
    Bijective transform via the mapping :math:`y = \text{ELU}(x)`.

    The forward map is :math:`y = x` for :math:`x > 0` and
    :math:`y = \exp(x) - 1` for :math:`x \le 0`. The image of the real line
    is therefore :math:`(-1, \infty)`, not the positive reals.
    """

    domain = constraints.real
    # Negative inputs map into (-1, 0], so outputs can be negative. The lower
    # bound is closed because in floating point exp(x) - 1 rounds to exactly
    # -1 for very negative x (the same reason torch's TanhTransform uses the
    # closed interval [-1, 1]).
    codomain = constraints.greater_than_eq(-1.0)
    bijective = True
    sign = +1

    def __eq__(self, other):
        # Any two ELUTransform instances compare equal.
        return isinstance(other, ELUTransform)

    def _call(self, x):
        return F.elu(x)

    def _inverse(self, y, eps=1e-8):
        """
        Invert the ELU.

        The result is ``y`` where ``y > 0`` and ``log1p(y + eps)`` where
        ``y <= 0``.

        :param y: values in the codomain.
        :param eps: small offset added to ``y`` before ``log1p``, intended to
            keep the argument away from ``log(0)`` near ``y = -1``.
            NOTE(review): at float32 precision ``1e-8`` is below the spacing
            of representable values near -1, so it may have no effect there.
        """
        # max(y, 0) keeps the positive branch, and min(log1p(y), 0) keeps the
        # negative branch. For y > 0 the second term is clipped to 0.
        return torch.max(y, torch.zeros_like(y)) + torch.min(
            torch.log1p(y + eps), torch.zeros_like(y)
        )

    def log_abs_det_jacobian(self, x, y):
        # dy/dx is 1 for x > 0 and exp(x) for x <= 0, so the log-derivative
        # is min(x, 0), written here as -relu(-x).
        return -F.relu(-x)
def elu():
    """
    A helper function to create an
    :class:`~pyro.distributions.transforms.ELUTransform` object for
    consistency with other helpers.
    """
    return ELUTransform()
# TODO: Move upstream
class LeakyReLUTransform(Transform):
    r"""
    Bijective transform via the mapping :math:`y = \text{LeakyReLU}(x)`.

    The transform uses a fixed negative slope of ``0.01``: :math:`y = x` for
    :math:`x \ge 0` and :math:`y = 0.01 x` otherwise. The inverse applies the
    reciprocal slope ``100`` to negative values.
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __eq__(self, other):
        # Any two LeakyReLUTransform instances compare equal.
        return isinstance(other, LeakyReLUTransform)

    def _call(self, x):
        # 0.01 is F.leaky_relu's default negative slope. It is spelled out
        # here so it visibly matches the inverse and the Jacobian below.
        return F.leaky_relu(x, negative_slope=0.01)

    def _inverse(self, y):
        # Undo the 0.01 slope on the negative half-line using slope 1 / 0.01.
        return F.leaky_relu(y, 100.0)

    def log_abs_det_jacobian(self, x, y):
        # The derivative is 1 where x >= 0 (log-derivative 0) and 0.01
        # elsewhere.
        log_slope = math.log(0.01)
        negative_branch = log_slope * torch.ones_like(x)
        return torch.where(x >= 0.0, torch.zeros_like(x), negative_branch)
def leaky_relu():
    """
    Return a new :class:`~pyro.distributions.transforms.LeakyReLUTransform`.

    This helper exists so the transform can be created through a function,
    in the same way as the other transforms.
    """
    transform = LeakyReLUTransform()
    return transform
def tanh():
    """
    Return a new :class:`~pyro.distributions.transforms.TanhTransform`.

    This helper exists so the tanh bijection can be created through the same
    helper-function interface as the other transforms.
    """
    transform = TanhTransform()
    return transform