# Source code for pyro.distributions.transforms

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from torch.distributions import biject_to, transform_to
from torch.distributions.transforms import *  # noqa F403
from torch.distributions.transforms import __all__ as torch_transforms

from pyro.distributions.torch_transform import ComposeTransformModule
from pyro.distributions.transforms.affine_autoregressive import AffineAutoregressive, affine_autoregressive
from pyro.distributions.transforms.affine_coupling import (AffineCoupling, ConditionalAffineCoupling,
                                                           affine_coupling, conditional_affine_coupling)
from pyro.distributions.transforms.batchnorm import BatchNorm, batchnorm
from pyro.distributions.transforms.block_autoregressive import BlockAutoregressive, block_autoregressive
from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform
from pyro.distributions.transforms.generalized_channel_permute import (GeneralizedChannelPermute,
                                                                       generalized_channel_permute)
from pyro.distributions.transforms.householder import Householder, householder
from pyro.distributions.transforms.lower_cholesky_affine import LowerCholeskyAffine
from pyro.distributions.transforms.neural_autoregressive import (ELUTransform, LeakyReLUTransform,
                                                                 TanhTransform, NeuralAutoregressive,
                                                                 neural_autoregressive,
                                                                 elu, leaky_relu, tanh)
from pyro.distributions.transforms.permute import Permute, permute
from pyro.distributions.transforms.polynomial import Polynomial, polynomial
from pyro.distributions.transforms.planar import Planar, ConditionalPlanar, planar, conditional_planar
from pyro.distributions.transforms.radial import Radial, ConditionalRadial, radial, conditional_radial
from pyro.distributions.transforms.spline import Spline, spline
from pyro.distributions.transforms.sylvester import Sylvester, sylvester
from pyro.distributions.constraints import IndependentConstraint, corr_cholesky_constraint
from pyro.distributions.transforms.cholesky import CorrLCholeskyTransform

########################################
# register transforms

@biject_to.register(IndependentConstraint)
def _biject_to_independent(constraint):
    """
    Bijection for an :class:`IndependentConstraint`: delegate to whatever
    ``biject_to`` has registered for the wrapped ``base_constraint``.
    """
    # NOTE(review): the base transform is returned unwrapped (nothing here
    # accounts for reinterpreted batch dims) — confirm callers handle that.
    return biject_to(constraint.base_constraint)


@transform_to.register(IndependentConstraint)
def _transform_to_independent(constraint):
    """
    Transform for an :class:`IndependentConstraint`: delegate to whatever
    ``transform_to`` has registered for the wrapped ``base_constraint``.
    """
    return transform_to(constraint.base_constraint)


@biject_to.register(corr_cholesky_constraint)
@transform_to.register(corr_cholesky_constraint)
def _transform_to_corr_cholesky(constraint):
    """
    Factory registered with both ``biject_to`` and ``transform_to`` for
    :data:`corr_cholesky_constraint`.

    The ``constraint`` argument is not inspected; each call builds and returns
    a new :class:`CorrLCholeskyTransform` with default arguments.
    """
    transform = CorrLCholeskyTransform()
    return transform


def iterated(repeats, base_fn, *args, **kwargs):
    """
    Helper function to compose a sequence of bijective transforms with potentially
    learnable parameters using :class:`~pyro.distributions.ComposeTransformModule`.

    :param repeats: number of repeated transforms.
    :type repeats: int
    :param base_fn: function to construct the bijective transform.
    :param args: arguments taken by `base_fn`.
    :param kwargs: keyword arguments taken by `base_fn`.
    :return: instance of :class:`~pyro.distributions.TransformModule`.
    :raises TypeError: if ``repeats`` is not an ``int``.
    :raises ValueError: if ``repeats`` is less than 1.
    """
    # Validate with explicit raises rather than ``assert``: asserts are stripped
    # under ``python -O``, which would let ``repeats <= 0`` silently produce an
    # empty composition instead of failing.
    if not isinstance(repeats, int):
        raise TypeError("repeats must be an int, but got {}".format(type(repeats).__name__))
    if repeats < 1:
        raise ValueError("repeats must be >= 1, but got {}".format(repeats))
    # ``base_fn`` is called once per repeat rather than reusing a single instance.
    return ComposeTransformModule([base_fn(*args, **kwargs) for _ in range(repeats)])
# Public names defined or re-exported by this module, in the original order.
__all__ = [
    'iterated',
    'AffineAutoregressive',
    'AffineCoupling',
    'BatchNorm',
    'BlockAutoregressive',
    'ComposeTransformModule',
    'ConditionalAffineCoupling',
    'ConditionalPlanar',
    'ConditionalRadial',
    'CorrLCholeskyTransform',
    'DiscreteCosineTransform',
    'ELUTransform',
    'GeneralizedChannelPermute',
    'Householder',
    'LeakyReLUTransform',
    'LowerCholeskyAffine',
    'NeuralAutoregressive',
    'Permute',
    'Planar',
    'Polynomial',
    'Radial',
    'Spline',
    'Sylvester',
    'TanhTransform',
    'affine_autoregressive',
    'affine_coupling',
    'batchnorm',
    'block_autoregressive',
    'conditional_affine_coupling',
    'conditional_planar',
    'conditional_radial',
    'elu',
    'generalized_channel_permute',
    'householder',
    'leaky_relu',
    'neural_autoregressive',
    'permute',
    'planar',
    'polynomial',
    'radial',
    'spline',
    'sylvester',
    'tanh',
]
# Also export every name listed in torch.distributions.transforms.__all__.
# NOTE(review): a name present both here and in torch's list would appear
# twice; that is harmless for ``import *`` but worth confirming for docs tooling.
__all__ += torch_transforms
# Drop the temporary alias so it does not linger as a module attribute.
del torch_transforms