# Source code for pyro.poutine.lift_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import TYPE_CHECKING, Callable, Dict, Set, Union

from typing_extensions import Self

from pyro import params
from pyro.distributions.distribution import Distribution
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class LiftMessenger(Messenger):
    """
    Given a stochastic function with param calls and a prior distribution,
    create a stochastic function where all param calls are replaced by sampling from prior.
    Prior should be a callable, a :class:`~pyro.distributions.Distribution`,
    or a dict of names to distributions / callables.

    Consider the following Pyro program:

        >>> def model(x):
        ...     s = pyro.param("s", torch.tensor(0.5))
        ...     z = pyro.sample("z", dist.Normal(x, s))
        ...     return z ** 2
        >>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

    ``lift`` makes ``param`` statements behave like ``sample`` statements
    using the distributions in ``prior``.  In this example, site `s` will now behave
    as if it was replaced with ``s = pyro.sample("s", dist.Exponential(0.3))``:

        >>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
        >>> tr.nodes["s"]["type"] == "sample"
        True
        >>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
        >>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
        False

    :param fn: function whose parameters will be lifted to random values
    :param prior: prior function in the form of a Distribution or a dict of stochastic fns
    :returns: ``fn`` decorated with a :class:`~pyro.poutine.lift_messenger.LiftMessenger`
    """

    def __init__(
        self,
        prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]],
    ) -> None:
        """
        :param prior: prior used to lift parameters. Prior can be of type
            dict, pyro.distributions, or a python stochastic fn
        """
        super().__init__()
        self.prior = prior
        # Full param site name -> first message seen for that name while the
        # messenger is active. Later param sites with the same name reuse the
        # cached message's value (see ``_pyro_param``).
        self._samples_cache: Dict[str, "Message"] = {}
        # Param names (as returned by ``params.user_param_name``) that did /
        # did not match a key of a dict prior. Used by ``__exit__`` to warn
        # about unused prior entries. Always initialised (not only when
        # validation is on) so that enabling validation while the messenger
        # is active cannot raise AttributeError.
        self._param_hits: Set[str] = set()
        self._param_misses: Set[str] = set()

    def __enter__(self) -> Self:
        # Every activation starts with fresh bookkeeping.
        self._samples_cache = {}
        self._param_hits = set()
        self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            # Prior entries that never matched any param site.
            extra = set(self.prior) - self._param_hits
            if extra:
                # Names are sorted so the warning text is deterministic
                # (set iteration order of strings varies between runs).
                warnings.warn(
                    "pyro.module prior did not find params ['{}']. "
                    "Did you instead mean one of ['{}']?".format(
                        "', '".join(sorted(extra)),
                        "', '".join(sorted(self._param_misses)),
                    )
                )
        return super().__exit__(*args, **kwargs)

    def _pyro_sample(self, msg: "Message") -> None:
        # Sample sites are left untouched. This hook is also called at the end
        # of ``_pyro_param`` for lifted param sites.
        return None

    def _pyro_param(self, msg: "Message") -> None:
        """
        Overrides the `pyro.param` call with samples sampled from the
        distribution specified in the prior. The prior can be a
        pyro.distributions object, a stochastic fn, or a dict of
        distributions / stochastic fns keyed on the param names. With a dict
        prior, a param whose name does not match any key of the prior is
        left unchanged.

        :param msg: the param site message; mutated in place.
        :raises TypeError: if the prior is not a dict, Distribution or callable.
        """
        assert msg["name"] is not None
        name = msg["name"]
        param_name = params.user_param_name(name)
        prior = self.prior

        if isinstance(prior, dict):
            # prior is a dict of distributions / stochastic fns
            if param_name not in prior:
                # No prior for this param: record the miss and leave the site as is.
                self._param_misses.add(param_name)
                return None
            self._param_hits.add(param_name)
            msg["fn"] = prior[param_name]
            if isinstance(msg["fn"], Distribution):
                msg["args"] = ()
                msg["kwargs"] = {}
                msg["infer"] = {}
            else:
                # Drop the first positional arg of the param call before
                # passing the rest on to the stochastic fn.
                msg["args"] = msg["args"][1:]
        elif isinstance(prior, Distribution):
            # prior is a distribution. Checked before ``callable`` because a
            # Distribution is itself callable.
            msg["fn"] = prior
            msg["args"] = ()
            msg["kwargs"] = {}
            msg["infer"] = {}
        elif callable(prior):
            # prior is a stochastic fn. block sample
            msg["stop"] = True
            msg["fn"] = prior
            msg["args"] = msg["args"][1:]
        else:
            raise TypeError(
                "LiftMessenger prior must be a dict, a Distribution, or a "
                "callable, but got {}".format(type(prior).__name__)
            )

        msg["type"] = "sample"
        if name in self._samples_cache:
            # Multiple pyro.param statements with the same
            # name. Block the site and fix the value.
            msg["value"] = self._samples_cache[name]["value"]
            msg["is_observed"] = True
            msg["stop"] = True
        else:
            self._samples_cache[name] = msg
            msg["is_observed"] = False
        return self._pyro_sample(msg)