Source code for pyro.poutine.condition_messenger

from .messenger import Messenger
from .trace_struct import Trace


class ConditionMessenger(Messenger):
    """
    Messenger that attaches caller-supplied values to named sample sites.

    For every sample site whose name is found in ``data``, the site's
    ``"value"`` entry is overwritten with the supplied value and its
    ``"is_observed"`` flag is set according to whether that value is
    ``None``. Sites whose names are absent from ``data`` are left untouched.
    """

    def __init__(self, data):
        """
        :param data: either a ``Trace`` (values are read from
            ``data.nodes[name]["value"]``) or a mapping indexed directly by
            site name (values are read from ``data[name]``).

        Only records ``data`` on the instance; no other setup is performed
        beyond the base ``Messenger`` initialisation.
        """
        super(ConditionMessenger, self).__init__()
        self.data = data

    def _pyro_sample(self, msg):
        """
        Condition a single sample site if ``self.data`` has a value for it.

        :param msg: the message for the current sample site; its ``"name"``
            key is used to look the site up in ``self.data``.
        :returns: ``None`` always; any effect is applied by mutating ``msg``.

        When the site name is present in ``self.data``:

        * ``msg["value"]`` is set to the supplied value;
        * ``msg["is_observed"]`` is set to ``True`` unless that value is
          ``None``, in which case it is set to ``False``.

        When the site name is absent, ``msg`` is not modified.
        """
        site_name = msg["name"]

        # Guard clause: nothing to condition on for this site.
        if site_name not in self.data:
            return None

        source = self.data
        # A Trace stores per-site info under .nodes; a plain mapping is
        # indexed by site name directly.
        observed = (
            source.nodes[site_name]["value"]
            if isinstance(source, Trace)
            else source[site_name]
        )

        msg["value"] = observed
        # A None value leaves the site marked as not observed.
        msg["is_observed"] = observed is not None
        return None