Source code for pyro.infer.reparam.transform

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import pyro
import pyro.distributions as dist

from .reparam import Reparam


class TransformReparam(Reparam):
    """
    Reparameterizer for
    :class:`pyro.distributions.torch.TransformedDistribution` latent
    variables.

    Instead of sampling the transformed variable directly, an auxiliary
    site named ``"{name}_base"`` is sampled from ``base_dist`` and then
    pushed through the distribution's transforms. This helps when the
    transforms are complex and geometry-changing, but the posterior has a
    simple shape in the space of ``base_dist``.

    Only latent variables are supported; likelihoods (observed sites) are
    rejected.
    """

    def __call__(self, name, fn, obs):
        """
        Reparameterize a single sample site.

        :param name: Name of the original site; the auxiliary site is
            named ``"{}_base".format(name)``.
        :param fn: Distribution at the site. After ``self._unwrap`` it must
            be a :class:`~pyro.distributions.TransformedDistribution`.
        :param obs: Observed value for the site; must be ``None``.
        :returns: A pair ``(new_fn, value)`` where ``value`` is the base
            sample pushed through every transform and ``new_fn`` is a
            masked-out ``Delta`` at ``value``.
        :raises AssertionError: If ``obs`` is not ``None`` or the unwrapped
            ``fn`` is not a ``TransformedDistribution``.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.TransformedDistribution)

        # Work out the event_dim of the base distribution by walking the
        # transforms from last to first and undoing each one's change in
        # event dimension.
        base_event_dim = event_dim
        try:
            # Transform domains/codomains exposing ``event_dim`` requires
            # https://github.com/pyro-ppl/pyro/pull/2739
            for transform in reversed(fn.transforms):
                shift = transform.domain.event_dim - transform.codomain.event_dim
                base_event_dim += shift
        except AttributeError:
            # Best effort: keep whatever value was reached before the
            # attribute lookup failed.
            pass

        # Draw noise from the base distribution at the auxiliary site.
        base_site_name = "{}_base".format(name)
        base_fn = self._wrap(fn.base_dist, base_event_dim)
        value = pyro.sample(base_site_name, base_fn)

        # Differentiably push the noise through each transform in order.
        for transform in fn.transforms:
            value = transform(value)

        # Simulate a pyro.deterministic() site: a point mass at ``value``
        # with masking applied via ``.mask(False)``.
        point_mass = dist.Delta(value, event_dim=event_dim)
        new_fn = point_mass.mask(False)
        return new_fn, value