Source code for pyro.infer.reparam.unit_jacobian

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from contextlib import ExitStack

from torch.distributions import biject_to
from torch.distributions.transforms import ComposeTransform

import pyro
import pyro.distributions as dist
from pyro.poutine.plate_messenger import block_plate

from .reparam import Reparam


class UnitJacobianReparam(Reparam):
    """
    Reparameterizer for :class:`~torch.distributions.transforms.Transform`
    objects whose Jacobian determinant is one.

    :param transform: A transform whose Jacobian has determinant 1.
    :type transform: ~torch.distributions.transforms.Transform
    :param str suffix: A suffix to append to the transformed site.
    :param bool experimental_allow_batch: EXPERIMENTAL allow coupling across a
        batch dimension. The targeted batch dimension and all batch dimensions
        to the right will be converted to event dimensions. Defaults to False.
    """

    def __init__(
        self, transform, suffix="transformed", *, experimental_allow_batch=False
    ):
        # Cache the transform so that inverting a freshly transformed value in
        # ``apply`` can reuse the cached pair instead of recomputing it.
        self.transform = transform.with_cache()
        self.suffix = suffix
        self.experimental_allow_batch = experimental_allow_batch

    def apply(self, msg):
        """
        Reparameterize a sample site by sampling it in transformed coordinates.

        An auxiliary site named ``"{name}_{suffix}"`` is sampled from the
        site's distribution pushed through the transform, and the original
        site is replaced by a :class:`~pyro.distributions.Delta` at the
        inverse-transformed value.

        :param dict msg: A site message providing the keys ``"name"``,
            ``"fn"``, ``"value"`` and ``"is_observed"``.
        :returns: A dict with keys ``"fn"`` (a ``Delta`` distribution at the
            reconstructed value), ``"value"`` and ``"is_observed"`` (always
            ``True``).
        :rtype: dict
        :raises ValueError: If the transform's ``event_dim`` exceeds that of
            the site distribution and ``experimental_allow_batch`` is False.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims that the transform would need to treat as
            # event dims; zero when the site already has enough event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True."
                    )

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )

                # Insert ``shift`` singleton dims to the left of the rightmost
                # ``shift`` batch dims, so that those dims can be converted to
                # event dims while the batch rank stays the same.
                old_shape = fn.batch_shape
                new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                transform = reshape_transform_batch(
                    transform, old_shape + fn.event_shape, new_shape + fn.event_shape
                )
                if value is not None:
                    value = value.reshape(
                        value.shape[: -shift - event_dim]
                        + (1,) * shift
                        + value.shape[-shift - event_dim :]
                    )
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Differentiably invert transform.
            # The value is first mapped from the support of ``fn`` to
            # unconstrained space, then passed through the transform.
            # NOTE(review): this composes with ``self.transform`` and thereby
            # replaces the batch-reshaped ``transform`` computed in the
            # ``shift`` branch above; confirm that this is intended.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform]
            )
            value_trans = None
            if value is not None:
                value_trans = transform(value)

            # Draw noise from the base distribution.
            value_trans = pyro.sample(
                f"{name}_{self.suffix}",
                dist.TransformedDistribution(fn, transform),
                obs=value_trans,
                infer={"is_observed": is_observed},
            )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            # Drop the singleton dims that were inserted before sampling.
            value = value.reshape(
                value.shape[: -2 * shift - event_dim]
                + value.shape[-shift - event_dim :]
            )

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}