Source code for pyro.poutine.condition_messenger

from __future__ import absolute_import, division, print_function

from .messenger import Messenger
from .trace_struct import Trace

[docs]class ConditionMessenger(Messenger): """ Adds values at observe sites to condition on data and override sampling """ def __init__(self, data): """ :param data: a dict or a Trace Constructor. Doesn't do much, just stores the stochastic function and the data to condition on. """ super(ConditionMessenger, self).__init__() = data def _pyro_sample(self, msg): """ :param msg: current message at a trace site. :returns: a sample from the stochastic function at the site. If msg["name"] appears in, convert the sample site into an observe site whose observed value is the value from[msg["name"]]. Otherwise, implements default sampling behavior with no additional effects. """ name = msg["name"] if name in assert not msg["is_observed"], \ "should not change values of existing observes" if isinstance(, Trace): msg["value"] =[name]["value"] else: msg["value"] =[name] msg["is_observed"] = True return None def _pyro_param(self, msg): return None