Source code for pyro.poutine.guide

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union

import torch

import pyro.distributions as dist
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites, site_is_subsample

if TYPE_CHECKING:
    from pyro.distributions.torch_distribution import TorchDistributionMixin
    from pyro.poutine.runtime import Message


class GuideMessenger(TraceMessenger, ABC):
    """
    Abstract base class for effect-based guides.

    Derived classes must implement the :meth:`get_posterior` method.

    The guide wraps a model. Calling the guide runs the model with this
    messenger active, so that each latent sample site has its distribution
    replaced by the posterior returned from :meth:`get_posterior`, while
    the original prior is remembered in ``msg["infer"]["prior"]``.

    :param callable model: The model to which this guide is attached.
    """

    def __init__(self, model: Callable) -> None:
        super().__init__()
        # Do not register model as submodule.
        # The model is hidden inside a 1-tuple rather than stored directly
        # (presumably so that ``torch.nn.Module``-based subclasses do not
        # pick it up as a child module -- confirm against subclasses).
        self._model = (model,)

    @property
    def model(self) -> Callable:
        """The model callable that was passed to the constructor."""
        return self._model[0]

    def __getstate__(self) -> Dict[str, object]:
        """
        Return the picklable state: a shallow copy of ``__dict__`` without
        the ``"trace"`` entry.
        """
        # Avoid pickling the trace.
        state = self.__dict__.copy()
        # ``pop`` with a default (rather than ``del``) so that an instance
        # that has no "trace" entry can still be pickled.
        state.pop("trace", None)
        return state

    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[override]
        """
        Draws posterior samples from the guide and replays the model against
        those samples.

        :returns: A dict mapping sample site name to sample value.
            This includes latent, deterministic, and observed values.
        :rtype: dict
        """
        # Expose the model inputs to ``get_posterior`` (e.g. for
        # amortization) only for the duration of this call.
        self.args_kwargs = args, kwargs
        try:
            with self:
                self.model(*args, **kwargs)
        finally:
            del self.args_kwargs

        # Only the model trace is needed to collect the sampled values.
        model_trace, _ = self.get_traces()
        samples: Dict[str, torch.Tensor] = {}
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                assert isinstance(site["value"], torch.Tensor)
                samples[name] = site["value"]
        return samples

    def _pyro_sample(self, msg: "Message") -> None:
        """
        Replace the distribution of a latent sample site by its posterior.

        Observed sites and subsample sites are left untouched. For every
        other sample site the original distribution is stored in
        ``msg["infer"]["prior"]`` and ``msg["fn"]`` is overwritten with the
        result of :meth:`get_posterior`.
        """
        if msg["is_observed"] or site_is_subsample(msg):
            return
        if TYPE_CHECKING:
            assert isinstance(msg["name"], str)
            assert isinstance(msg["fn"], TorchDistributionMixin)
            assert msg["infer"] is not None
        prior = msg["fn"]
        # Remember the prior so that ``get_traces`` can rebuild the model
        # trace later.
        msg["infer"]["prior"] = prior
        posterior = self.get_posterior(msg["name"], prior)
        if isinstance(posterior, torch.Tensor):
            # A raw tensor is treated as a point estimate: wrap it in a
            # Delta with the same event_dim as the prior.
            posterior = dist.Delta(posterior, event_dim=prior.event_dim)
        if posterior.batch_shape != prior.batch_shape:
            posterior = posterior.expand(prior.batch_shape)
        msg["fn"] = posterior

    def _pyro_post_sample(self, msg: "Message") -> None:
        """
        Expand the stored prior to the batch shape of the final site
        distribution, then defer to the parent class handler.
        """
        # Manually apply outer plates.
        assert msg["infer"] is not None
        prior = msg["infer"].get("prior")
        if prior is not None:
            if TYPE_CHECKING:
                assert isinstance(msg["fn"], TorchDistributionMixin)
            if prior.batch_shape != msg["fn"].batch_shape:
                msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
        return super()._pyro_post_sample(msg)

    @abstractmethod
    def get_posterior(
        self, name: str, prior: "TorchDistributionMixin"
    ) -> Union["TorchDistributionMixin", torch.Tensor]:
        """
        Abstract method to compute a posterior distribution or sample a
        posterior value given a prior distribution conditioned on upstream
        posterior samples.

        Implementations may use ``pyro.param`` and ``pyro.sample`` inside this
        function, but ``pyro.sample`` statements should set
        ``infer={"is_auxiliary": True}`` .

        Implementations may access further information for computations:

        -   ``value = self.upstream_value(name)`` is the value of an upstream
            sample or deterministic site.
        -   ``self.trace`` is a trace of upstream sites, and may be useful for
            other information such as ``self.trace.nodes["my_site"]["fn"]`` or
            ``self.trace.nodes["my_site"]["cond_indep_stack"]`` .
        -   ``args, kwargs = self.args_kwargs`` are the inputs to the model, and
            may be useful for amortization.

        :param str name: The name of the sample site to sample.
        :param prior: The prior distribution of this sample site
            (conditioned on upstream samples from the posterior).
        :type prior: ~pyro.distributions.Distribution
        :returns: A posterior distribution or sample from the posterior
            distribution.
        :rtype: ~pyro.distributions.Distribution or torch.Tensor
        """
        raise NotImplementedError

    def upstream_value(self, name: str) -> Optional[torch.Tensor]:
        """
        For use in :meth:`get_posterior` .

        :param str name: The name of a site already recorded in
            ``self.trace``.
        :returns: The value of an upstream sample or deterministic site
        :rtype: torch.Tensor
        """
        return self.trace.nodes[name]["value"]

    def get_traces(self) -> Tuple[Trace, Trace]:
        """
        This can be called after running :meth:`__call__` to extract a pair of
        traces.

        In contrast to the trace-replay pattern of generating a pair of traces,
        :class:`GuideMessenger` interleaves model and guide computations, so
        only a single ``guide(*args, **kwargs)`` call is needed to create both
        traces. This function merely extracts the relevant information from
        this guide's ``.trace`` attribute.

        :returns: a pair ``(model_trace, guide_trace)``
        :rtype: tuple
        """
        guide_trace = prune_subsample_sites(self.trace)
        model_trace = guide_trace.copy()
        # Iterate over a snapshot because sites are deleted from
        # ``guide_trace.nodes`` inside the loop.
        for name, guide_site in list(guide_trace.nodes.items()):
            if guide_site["type"] != "sample" or guide_site["is_observed"]:
                # Only unobserved sample sites belong to the guide.
                del guide_trace.nodes[name]
                continue
            # The recorded site holds the posterior in "fn"; the model side
            # gets a copy of the site with the stored prior swapped back in.
            model_site = model_trace.nodes[name].copy()
            assert guide_site["infer"] is not None
            model_site["fn"] = guide_site["infer"]["prior"]
            model_trace.nodes[name] = model_site
        return model_trace, guide_trace