# Source code for pyro.infer.reparam.loc_scale

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.distributions.util import is_identically_one, is_validation_enabled

from .reparam import Reparam


class LocScaleReparam(Reparam):
    """
    Generic decentering reparameterizer [1] for latent variables parameterized
    by ``loc`` and ``scale`` (and possibly additional ``shape_params``).

    This reparameterization works only for latent variables, not likelihoods.

    [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
        "Automatic Reparameterisation of Probabilistic Programs"
        https://arxiv.org/pdf/1906.03028.pdf

    :param float centered: optional centered parameter. If None (default) learn
        a per-site per-element centering parameter in ``[0,1]``. If 0, fully
        decenter the distribution; if 1, preserve the centered distribution
        unchanged. Plain integers (e.g. ``0`` or ``1``) are accepted and
        converted to float; a tensor may be passed for per-element centering.
    :param shape_params: list of additional parameter names to copy unchanged
        from the centered to decentered distribution.
    :type shape_params: tuple or list
    """

    def __init__(self, centered=None, shape_params=()):
        # The docstring advertises ``0`` and ``1`` as meaningful settings, so
        # accept integer literals by normalizing them to float. ``bool`` is a
        # subclass of ``int`` and is deliberately still rejected below.
        if isinstance(centered, int) and not isinstance(centered, bool):
            centered = float(centered)
        assert centered is None or isinstance(centered, (float, torch.Tensor))
        assert isinstance(shape_params, (tuple, list))
        assert all(isinstance(name, str) for name in shape_params)
        if is_validation_enabled():
            # Range checks are only performed when Pyro validation is enabled.
            if isinstance(centered, float):
                assert 0 <= centered <= 1
            elif isinstance(centered, torch.Tensor):
                assert (0 <= centered).all()
                assert (centered <= 1).all()
            else:
                assert centered is None
        self.centered = centered
        self.shape_params = shape_params

    def __call__(self, name, fn, obs):
        """
        Reparameterize the latent sample site ``name`` with distribution ``fn``.

        With centering ``c`` (fixed, or learned via ``pyro.param`` when
        ``self.centered`` is None), this draws auxiliary noise from a
        distribution of the same type as ``fn`` with ``loc * c`` and
        ``scale ** c``, and maps it back differentiably via
        ``value = loc + scale ** (1 - c) * (noise - c * loc)``.

        :param str name: name of the sample site being reparameterized.
        :param fn: the site's distribution; must expose ``loc``, ``scale`` and
            any attributes listed in ``shape_params``.
        :param obs: observed value; must be None (observed sites are not
            supported).
        :returns: a 2-tuple ``(new_fn, value)``. When fully centered
            (``centered`` identically 1) this is ``(fn, obs)``, i.e. the site
            is left unchanged. Otherwise ``new_fn`` is a masked ``Delta`` at
            the transformed ``value``. Both paths return a pair so callers can
            unpack the result uniformly.
        """
        assert obs is None, "LocScaleReparam does not support observe statements"
        centered = self.centered
        if is_identically_one(centered):
            # Fully centered: nothing to reparameterize. Return the same
            # (fn, value) pair shape as the decentered path below; the
            # previous ``name, fn, obs`` 3-tuple broke pair unpacking.
            return fn, obs
        # Capture the event shape before unwrapping; the learned centering
        # parameter is initialized with this shape.
        event_shape = fn.event_shape
        # NOTE(review): ``_unwrap``/``_wrap`` are inherited from ``Reparam``;
        # presumably they strip/restore event-dim reinterpretation wrappers.
        fn, event_dim = self._unwrap(fn)

        # Apply a partial decentering transform.
        params = {key: getattr(fn, key) for key in self.shape_params}
        if self.centered is None:
            centered = pyro.param("{}_centered".format(name),
                                  lambda: fn.loc.new_full(event_shape, 0.5),
                                  constraint=constraints.unit_interval)
        params["loc"] = fn.loc * centered
        params["scale"] = fn.scale ** centered
        decentered_fn = type(fn)(**params)

        # Draw decentered noise.
        decentered_value = pyro.sample("{}_decentered".format(name),
                                       self._wrap(decentered_fn, event_dim))

        # Differentiably transform back to the original parameterization.
        delta = decentered_value - centered * fn.loc
        value = fn.loc + fn.scale.pow(1 - centered) * delta

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return new_fn, value