import warnings
from pyro import params
from pyro.distributions.distribution import Distribution
from pyro.poutine.util import is_validation_enabled
from .messenger import Messenger
class LiftMessenger(Messenger):
    """
    Messenger which "lifts" parameters to random samples.

    Given a stochastic function with param calls and a prior,
    creates a stochastic function where all param calls are
    replaced by sampling from prior.

    Prior should be a callable or a dict of names to callables.
    """

    def __init__(self, prior):
        """
        :param prior: prior used to lift parameters. Prior can be of type
            dict, pyro.distributions, or a python stochastic fn
        """
        super(LiftMessenger, self).__init__()
        self.prior = prior
        # Maps a param site name to the first lifted message seen for it,
        # so repeated pyro.param calls with the same name reuse one value.
        self._samples_cache = {}
        # Bookkeeping for dict priors: user param names that did / did not
        # match a prior key. Created here (not only in __enter__) so the
        # attributes always exist, whatever the validation setting.
        self._param_hits = set()
        self._param_misses = set()

    def __enter__(self):
        self._samples_cache = {}
        # Reset unconditionally. Hits/misses are always recorded (it is
        # cheap), so enabling validation mid-context cannot raise
        # AttributeError or produce a partial record.
        self._param_hits = set()
        self._param_misses = set()
        return super(LiftMessenger, self).__enter__()

    def __exit__(self, *args, **kwargs):
        self._samples_cache = {}
        if is_validation_enabled() and isinstance(self.prior, dict):
            # Prior keys that never matched any pyro.param call.
            extra = set(self.prior) - self._param_hits
            if extra:
                # Sorted so the warning text is deterministic across runs
                # (plain set iteration order is not).
                warnings.warn(
                    "pyro.module prior did not find params ['{}']. "
                    "Did you instead mean one of ['{}']?"
                    .format("', '".join(sorted(extra)),
                            "', '".join(sorted(self._param_misses))))
        return super(LiftMessenger, self).__exit__(*args, **kwargs)

    def _pyro_sample(self, msg):
        # This messenger does not modify ordinary sample sites.
        return None

    def _pyro_param(self, msg):
        """
        Overrides the `pyro.param` call with samples sampled from the
        distribution specified in the prior. The prior can be a
        pyro.distributions object, a python stochastic fn, or a dict of
        these keyed on the param names. If the param name does not match
        any of the keys in a dict prior, that param is left unchanged.

        :param msg: the message dict for the param site; it is mutated in
            place to turn the site into a sample site.
        """
        name = msg["name"]
        param_name = params.user_param_name(name)
        if isinstance(self.prior, dict):
            # prior is a dict of distributions / stochastic fns
            if param_name not in self.prior:
                # Unmatched param: leave the site as a plain param.
                self._param_misses.add(param_name)
                return None
            self._param_hits.add(param_name)
            msg["fn"] = self.prior[param_name]
            # Drop the first positional arg of the param call; presumably
            # the param name itself -- TODO confirm against pyro.param.
            msg["args"] = msg["args"][1:]
            if isinstance(msg["fn"], Distribution):
                # Distributions are sampled without the param's own args.
                msg["args"] = ()
                msg["kwargs"] = {}
                msg["infer"] = {}
            # NOTE(review): unlike the top-level stochastic-fn branch below,
            # msg["stop"] is not set here for non-Distribution callables --
            # confirm this asymmetry is intended.
        elif isinstance(self.prior, Distribution):
            # prior is a distribution
            msg["fn"] = self.prior
            msg["args"] = ()
            msg["kwargs"] = {}
            msg["infer"] = {}
        elif callable(self.prior):
            # prior is a stochastic fn (Distributions were handled by the
            # previous branch). block sample
            msg["stop"] = True
            msg["fn"] = self.prior
            msg["args"] = msg["args"][1:]
        else:
            # otherwise leave as is
            return None

        msg["type"] = "sample"
        if name in self._samples_cache:
            # Multiple pyro.param statements with the same name: block the
            # site and fix its value to the one stored on the first message.
            msg["value"] = self._samples_cache[name]["value"]
            msg["is_observed"] = True
            msg["stop"] = True
        else:
            self._samples_cache[name] = msg
            msg["is_observed"] = False
        return self._pyro_sample(msg)