Source code for pyro.poutine.lift_messenger

from __future__ import absolute_import, division, print_function

from pyro import params
from pyro.distributions import Distribution

from .messenger import Messenger


[docs]class LiftMessenger(Messenger): """ Messenger which "lifts" parameters to random samples. Given a stochastic function with param calls and a prior, creates a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables. """ def __init__(self, prior): """ :param prior: prior used to lift parameters. Prior can be of type dict, pyro.distributions, or a python stochastic fn Constructor """ super(LiftMessenger, self).__init__() self.prior = prior def _pyro_sample(self, msg): return None def _pyro_param(self, msg): """ Overrides the `pyro.param` call with samples sampled from the distribution specified in the prior. The prior can be a pyro.distributions object or a dict of distributions keyed on the param names. If the param name does not match the name the keys in the prior, that param name is unchanged. """ name = msg["name"] param_name = params.user_param_name(name) if isinstance(self.prior, dict): # prior is a dict of distributions if param_name in self.prior.keys(): msg["fn"] = self.prior[param_name] if isinstance(msg['fn'], Distribution): msg["args"] = () msg["kwargs"] = {} msg["infer"] = {} else: return None elif isinstance(self.prior, Distribution): # prior is a distribution msg["fn"] = self.prior msg["args"] = () msg["kwargs"] = {} msg["infer"] = {} elif callable(self.prior): if not isinstance(self.prior, Distribution): # prior is a stochastic fn. block sample msg["stop"] = True msg["fn"] = self.prior else: # otherwise leave as is return None msg["type"] = "sample" msg["is_observed"] = False return self._pyro_sample(msg)