Source code for pyro.poutine.reparam_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    TypeVar,
    Union,
)

import torch
from typing_extensions import ParamSpec

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import effectful

if TYPE_CHECKING:
    from pyro.distributions.torch_distribution import TorchDistributionMixin
    from pyro.infer.reparam.reparam import Reparam
    from pyro.poutine.runtime import Message

_P = ParamSpec("_P")
_T = TypeVar("_T")


@effectful(type="get_init_messengers")
def _get_init_messengers() -> List[Messenger]:
    """
    Effectful query for the messengers that should be applied early to a
    site being reparametrized.

    The undecorated body contributes no messengers of its own and simply
    yields an empty list.
    NOTE(review): presumably enclosing handlers respond to the
    ``"get_init_messengers"`` message type and add themselves to the
    result -- confirm against the handlers that implement it.

    :returns: A list of messengers; empty when nothing intercepts the call.
    :rtype: list
    """
    return list()


class ReparamMessenger(Messenger):
    """
    Effect handler that reparametrizes sample sites: every affected site is
    rewritten as one or more auxiliary sample sites followed by a
    deterministic transformation [1].

    Which sites are affected, and how, is controlled by the ``config``
    argument given to the constructor. The available reparameterizers live in
    the :mod:`pyro.infer.reparam` module.

    .. note:: Some reparameterizers inspect the ``*args, **kwargs`` inputs of
        the function they affect. Those reparameterizers only work when
        ``poutine.reparam`` is used as a decorator, not as a context manager.

    [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
        "Automatic Reparameterisation of Probabilistic Programs"
        https://arxiv.org/pdf/1906.03028.pdf

    :param config: Configuration, either a dict mapping site name to
        :class:`~pyro.infer.reparam.reparam.Reparam`, or a function mapping
        a site to a :class:`~pyro.infer.reparam.reparam.Reparam` or None.
        See :mod:`pyro.infer.reparam.strategies` for built-in configuration
        strategies.
    :type config: dict or callable
    """

    def __init__(
        self,
        config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
    ) -> None:
        super().__init__()
        assert callable(config) or isinstance(config, dict)
        self.config = config
        # Side channel holding the (args, kwargs) of the wrapped function;
        # ReparamHandler sets it for the duration of a call.
        self._args_kwargs = None

    def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
        """
        Wrap ``fn`` in a :class:`ReparamHandler` bound to this messenger.
        """
        return ReparamHandler(self, fn)

    def _pyro_sample(self, msg: "Message") -> None:
        """
        Reparametrize the sample site described by ``msg`` in place, if the
        configuration provides a reparameterizer for it.

        On return ``msg["fn"]``, ``msg["value"]`` and ``msg["is_observed"]``
        hold whatever the reparameterizer produced. Sites whose ``fn`` is a
        ``_Subsample`` and sites without a configured reparameterizer are
        left untouched.
        """
        # Subsample sites are skipped entirely.
        if type(msg["fn"]).__name__ == "_Subsample":
            return
        assert msg["name"] is not None
        if TYPE_CHECKING:
            assert isinstance(msg["fn"], TorchDistributionMixin)

        # A dict config is keyed by site name; a callable config receives the
        # whole message.
        config = self.config
        reparam = config.get(msg["name"]) if isinstance(config, dict) else config(msg)
        if reparam is None:
            return

        # See https://github.com/pyro-ppl/pyro/issues/2878
        # Tricky hack: apply messengers in an order that differs from the
        # standard _PYRO_STACK order, so that a (model, initializer) pair can
        # be reparametrized as a unit. Messengers are normally composed as
        #
        #   InitMessenger(init_to_value(...))(ReparamMessenger(...)(model))
        #
        # which means the original model sites have already been
        # reparametrized by the time init_to_value() sees them. As a
        # workaround, ReparamMessenger applies the enclosing InitMessengers
        # early, simulating a priority system for messengers (prioritizing
        # messengers properly might be worth considering). The enclosing
        # InitMessenger runs a second time after ReparamMessenger, which is
        # harmless because InitMessenger does not overwrite values.
        #
        # Extending this to ConditionMessenger or ReplayMessenger would
        # require those messengers to be safe to apply twice as well, with
        # the second application not overwriting the first.
        init_messengers = _get_init_messengers()
        assert init_messengers is not None
        for init_messenger in init_messengers:
            init_messenger._process_message(msg)

        # Hand args_kwargs to the reparam through a side channel, and always
        # clear it again afterwards.
        reparam.args_kwargs = self._args_kwargs  # type: ignore[attr-defined]
        try:
            new_msg = reparam.apply(
                {
                    "name": msg["name"],
                    "fn": msg["fn"],
                    "value": msg["value"],
                    "is_observed": msg["is_observed"],
                }
            )
        finally:
            reparam.args_kwargs = None  # type: ignore[attr-defined]

        old_value = msg["value"]
        new_value = new_msg["value"]
        if new_value is not None:
            # The original msg["fn"] is still available here, so validate
            # against it before it is replaced below.
            if getattr(msg["fn"], "_validation_enabled", False):
                msg["fn"]._validate_sample(new_value)
            if old_value is not None and old_value is not new_value:
                # An existing initialization is being replaced: its shape
                # must be preserved (the check is skipped while tracing).
                if not torch._C._get_tracing_state():
                    assert new_value.shape == old_value.shape
                # Warn when a custom init method is overwritten by another
                # init method.
                if getattr(old_value, "_pyro_custom_init", True):
                    warnings.warn(
                        f"At pyro.sample({repr(msg['name'])},...), "
                        f"{type(reparam).__name__} "
                        "does not commute with initialization; "
                        "falling back to default initialization.",
                        RuntimeWarning,
                    )

        # Publish the reparametrized site back into the original message.
        msg["fn"] = new_msg["fn"]
        msg["value"] = new_value
        msg["is_observed"] = new_msg["is_observed"]
class ReparamHandler(Generic[_P, _T]):
    """
    Reparameterization poutine.

    Callable wrapper that runs ``fn`` inside the messenger ``msngr`` and
    exposes the call's arguments to that messenger while the call is in
    progress.
    """

    def __init__(self, msngr, fn: Callable[_P, _T]) -> None:
        self.msngr = msngr
        self.fn = fn
        super().__init__()

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        """
        Invoke the wrapped function with the messenger active.

        The positional and keyword arguments are stored on the messenger as
        ``_args_kwargs`` for optional use by reparameterizers, and reset to
        ``None`` once the call finishes, whether it returns or raises.
        """
        messenger = self.msngr
        messenger._args_kwargs = (args, kwargs)
        try:
            with messenger:
                return self.fn(*args, **kwargs)
        finally:
            messenger._args_kwargs = None