# Source code for pyro.poutine.reparam_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .messenger import Messenger


class ReparamMessenger(Messenger):
    """
    Reparametrizes each affected sample site into one or more auxiliary sample
    sites followed by a deterministic transformation [1].

    To specify reparameterizers, pass a ``config`` dict or callable to the
    constructor.  See the :mod:`pyro.infer.reparam` module for available
    reparameterizers.

    [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
        "Automatic Reparameterisation of Probabilistic Programs"
        https://arxiv.org/pdf/1906.03028.pdf

    :param config: Configuration, either a dict mapping site name to
        :class:`~pyro.infer.reparam.reparam.Reparameterizer` ,
        or a function mapping site to
        :class:`~pyro.infer.reparam.reparam.Reparameterizer` or None.
    :type config: dict or callable
    :raises TypeError: If ``config`` is neither a dict nor a callable.
    """

    def __init__(self, config):
        super().__init__()
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an invalid config through and
        # defer the failure to the first sample site.
        if not (isinstance(config, dict) or callable(config)):
            raise TypeError(
                "config must be a dict or a callable, "
                f"but got {type(config).__name__}"
            )
        self.config = config

    def _pyro_sample(self, msg):
        """
        Apply the configured reparameterizer, if any, to a sample message.

        The reparameterizer is looked up by ``msg["name"]`` when the config is
        a dict, or obtained by calling the config with ``msg`` otherwise. When
        the lookup yields ``None`` the message is left untouched.

        :param dict msg: The sample site message; its ``"fn"``, ``"value"``
            and ``"is_observed"`` entries may be modified in place.
        """
        if isinstance(self.config, dict):
            reparam = self.config.get(msg["name"])
        else:
            reparam = self.config(msg)
        if reparam is None:
            return

        # The reparameterizer returns a replacement fn and an optional value.
        new_fn, value = reparam(msg["name"], msg["fn"], msg["value"])

        if value is not None:
            # A site that had no value but receives one here is marked as
            # observed (presumably so that it is not sampled afresh later --
            # confirm against the Messenger/runtime handling of is_observed).
            if msg["value"] is None:
                msg["is_observed"] = True
            msg["value"] = value
            if getattr(msg["fn"], "_validation_enabled", False):
                # Validate while the original msg["fn"] is known.
                msg["fn"]._validate_sample(value)

        msg["fn"] = new_fn