# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Dict, Optional
from pyro.poutine.messenger import Messenger
import torch
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace
class ReplayMessenger(Messenger):
    """
    Given a callable that contains Pyro primitive calls,
    return a callable that runs the original, reusing the values recorded in
    ``trace`` at the corresponding sites of the new trace.

    Consider the following Pyro program:

        >>> def model(x):
        ...     s = pyro.param("s", torch.tensor(0.5))
        ...     z = pyro.sample("z", dist.Normal(x, s))
        ...     return z ** 2

    ``replay`` makes ``sample`` statements behave as if they had sampled the values
    at the corresponding sites in the trace:

        >>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
        >>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
        >>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
        True

    Observed sample sites are never replayed; they keep their observed values.
    Param sites are replayed only when ``params`` is given.

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param trace: a :class:`~pyro.poutine.Trace` data structure to replay against
    :param params: dict of names of param sites and constrained values
        in fn to replay against
    :returns: a stochastic function decorated with a :class:`~pyro.poutine.replay_messenger.ReplayMessenger`
    """

    def __init__(
        self,
        trace: Optional["Trace"] = None,
        params: Optional[Dict[str, "torch.Tensor"]] = None,
    ) -> None:
        """
        :param trace: a trace whose sample-site values should be reused
        :param params: a dict mapping param-site names to the constrained
            values that should be reused at those sites
        :raises ValueError: if neither ``trace`` nor ``params`` is provided

        Stores ``trace`` and ``params`` in attributes.
        """
        super().__init__()
        if trace is None and params is None:
            raise ValueError("must provide trace or params to replay against")
        self.trace = trace
        self.params = params

    def _pyro_sample(self, msg: "Message") -> None:
        """
        :param msg: current message at a trace site.
        :raises RuntimeError: if the matching site in ``self.trace`` is not an
            unobserved sample site.

        At an unobserved sample site that appears in ``self.trace``, copies the
        recorded ``value`` and ``infer`` dict from ``self.trace`` into ``msg``
        and sets ``msg["done"]``, instead of sampling from the stochastic
        function at the site.

        At an observed site, or a sample site that does not appear in
        ``self.trace``, leaves ``msg`` untouched, so the default
        Messenger._pyro_sample behavior applies with no additional side effects.
        """
        assert msg["name"] is not None
        name = msg["name"]
        if self.trace is not None and name in self.trace:
            guide_msg = self.trace.nodes[name]
            # Observed sites keep their observed value; there is nothing to replay.
            if msg["is_observed"]:
                return None
            # Only a latent (sampled, unobserved) site in the trace can supply a value.
            if guide_msg["type"] != "sample" or guide_msg["is_observed"]:
                raise RuntimeError("site {} must be sampled in trace".format(name))
            # NOTE(review): "done" presumably tells downstream handling that the
            # value is final and must not be resampled; confirm against Messenger.
            msg["done"] = True
            msg["value"] = guide_msg["value"]
            msg["infer"] = guide_msg["infer"]

    def _pyro_param(self, msg: "Message") -> None:
        """
        :param msg: current message at a param site.

        If ``self.params`` has an entry for this site's name, sets ``msg["done"]``
        and uses that entry as the site's value. Otherwise ``msg`` is left
        untouched.
        """
        name = msg["name"]
        if self.params is not None and name in self.params:
            # Values exposing an ``unconstrained`` attribute are treated as
            # constrained params; anything else is rejected.
            # NOTE(review): ``assert`` is stripped under ``python -O``.
            assert hasattr(
                self.params[name], "unconstrained"
            ), "param {} must be constrained value".format(name)
            msg["done"] = True
            msg["value"] = self.params[name]