Source code for pyro.infer.reparam.transform

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist

from .reparam import Reparam

[docs]class TransformReparam(Reparam): """ Reparameterizer for :class:`pyro.distributions.torch.TransformedDistribution` latent variables. This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has simple shape in the space of ``base_dist``. This reparameterization works only for latent variables, not likelihoods. """
[docs] def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] fn, event_dim = self._unwrap(fn) assert isinstance(fn, torch.distributions.TransformedDistribution) # Differentiably invert transform. value_base = value if value is not None: for t in reversed(fn.transforms): value_base = t.inv(value_base) # Draw noise from the base distribution. base_event_dim = event_dim for t in reversed(fn.transforms): base_event_dim += t.domain.event_dim - t.codomain.event_dim value_base = pyro.sample( f"{name}_base", self._wrap(fn.base_dist, base_event_dim), obs=value_base, infer={"is_observed": is_observed}, ) # Differentiably transform. if value is None: value = value_base for t in fn.transforms: value = t(value) # Simulate a pyro.deterministic() site. new_fn = dist.Delta(value, event_dim=event_dim).mask(False) return {"fn": new_fn, "value": value, "is_observed": True}