Source code for pyro.infer.reparam.transform

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import pyro
import pyro.distributions as dist

from .reparam import Reparam

class TransformReparam(Reparam):
    """
    Reparameterizer for latent variables distributed according to a
    :class:`pyro.distributions.torch.TransformedDistribution`.

    Rather than sampling the transformed variable directly, an auxiliary
    site is drawn from ``base_dist`` and pushed through the distribution's
    transforms. This is helpful when the transforms are complex and
    geometry-changing, yet the posterior has a simple shape in the space of
    ``base_dist``.

    Only latent variables are supported; this reparameterization does not
    work for likelihoods.
    """

    def __call__(self, name, fn, obs):
        """
        Rewrite the site ``name`` as a base-distribution sample followed by
        a deterministic transform.

        :param name: Name of the site being reparameterized. The auxiliary
            sample site is registered as ``"{name}_base"``.
        :param fn: Distribution at the site; must be an instance of
            ``dist.TransformedDistribution``.
        :param obs: Observed value at the site; must be ``None``.
        :returns: A pair ``(new_fn, value)`` where ``value`` is the base
            sample after all of ``fn.transforms`` have been applied in order
            and ``new_fn`` is a ``dist.Delta`` located at ``value`` with
            ``event_dim=fn.event_dim``.
        :raises AssertionError: If ``obs`` is not ``None`` or ``fn`` is not a
            ``dist.TransformedDistribution`` (checks are skipped under
            ``python -O``).
        """
        assert obs is None, "TransformReparam does not support observe statements"
        assert isinstance(fn, dist.TransformedDistribution)

        # The auxiliary site carries the randomness, sampled in the
        # untransformed space of the base distribution.
        base_site = "{}_base".format(name)
        value = pyro.sample(base_site, fn.base_dist)

        # Push the base sample forward through each transform, in order.
        for transform in fn.transforms:
            value = transform(value)

        # A Delta at the transformed value stands in for the original
        # distribution, mimicking a pyro.deterministic() site.
        return dist.Delta(value, event_dim=fn.event_dim), value