# Source code for pyro.poutine.runtime

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
from typing_extensions import Literal, ParamSpec, TypedDict

from pyro.params.param_store import (  # noqa: F401
    _MODULE_NAMESPACE_DIVIDER,
    ParamStoreDict,
)

_P = ParamSpec("_P")
_T = TypeVar("_T")

    from collections import Counter

    from pyro.distributions.score_parts import ScoreParts
    from pyro.distributions.torch_distribution import TorchDistributionMixin
    from pyro.poutine.indep_messenger import CondIndepStackFrame
    from pyro.poutine.messenger import Messenger

# The global effect-handler stack. ``apply_stack`` walks it from the last
# element back to the first when processing a message, and ``am_i_wrapped``
# reports whether it is non-empty.
_PYRO_STACK: List["Messenger"] = []

# The global ParamStore: a single ``ParamStoreDict`` instance created when this
# module is imported.
_PYRO_PARAM_STORE = ParamStoreDict()

class InferDict(TypedDict, total=False):
    """
    Per-site inference options, passed through the ``infer`` argument of Pyro
    primitives to select inference strategies site by site, e.g.::

        pyro.sample(
            "x",
            dist.Bernoulli(0.5),
            infer={"enumerate": "parallel"},
        )

    Every key is optional (``total=False``).

    Keys:
        enumerate (str): ``"sequential"`` or ``"parallel"`` to turn on
            enumeration of the site. Parallel enumeration is usually faster,
            but it requires broadcasting-safe operations and a static model
            structure.
        expand (bool): Whether the distribution is expanded while enumerating.
            Treated as False when absent.
        is_auxiliary (bool): Marks an auxiliary sample site, e.g. in guides
            that deterministically transform auxiliary variables. Treated as
            False when absent.
        is_observed (bool): Marks the site as observed rather than latent.
            Treated as False when absent.
        num_samples (int): How many samples to draw. Treated as 1 when absent.
        obs (optional torch.Tensor): The observed value, or None for a latent
            variable. Treated as None when absent.
        prior (optional torch.distributions.Distribution): (internal) Used by
            GuideMessenger to hold the model's prior distribution, conditioned
            on upstream sites.
        tmc (str): ``"diagonal"`` or ``"mixture"``: the approximation used for
            Tensor Monte Carlo in TraceTMC_ELBO.
        was_observed (bool): (internal) Whether the site was observed
            originally, for Reweighted Wake Sleep or Compiled Sequential
            Importance Sampling.

    Keys beginning with an underscore are internal bookkeeping.
    """

    enumerate: Literal["sequential", "parallel"]
    expand: bool
    is_auxiliary: bool
    is_observed: bool
    num_samples: int
    obs: Optional[torch.Tensor]
    prior: "TorchDistributionMixin"
    tmc: Literal["diagonal", "mixture"]
    was_observed: bool
    # Internal keys (underscore-prefixed).
    _deterministic: bool
    _dim_to_symbol: Dict[int, str]
    _do_not_trace: bool
    _enumerate_symbol: str
    _markov_scope: "Counter"
    _enumerate_dim: int
    _dim_to_id: Dict[int, int]
    _markov_depth: int
class Message(TypedDict, Generic[_P, _T], total=False):
    """
    The message type that flows through Pyro's effect handlers. Trace objects
    also store messages, e.g.::

        trace.nodes["my_site_name"]  # a Message

    Keys:
        type (str): Kind of message, usually ``"sample"``, ``"param"``,
            ``"plate"`` or ``"markov"``, but custom kinds are possible.
        name (str): Name of the site, typically a sample or parameter name.
        fn (callable): Distribution or function that produces the value.
        is_observed (bool): Whether the value is observed.
        args (tuple): Positional arguments for ``fn``.
        kwargs (dict): Keyword arguments for ``fn``.
        value (torch.Tensor): The observed or sampled value.
        scale (float or torch.Tensor): Scaling factor applied to the log
            probability.
        mask (bool or torch.Tensor): Mask applied to the log probability.
        cond_indep_stack (tuple): Immutable local stack of conditional
            independence metadata for the site.
        done (bool): Whether the message has already been handled.
        stop (bool): Whether processing of the message should stop.
        continuation (callable): Called with the message after processing.
        infer (optional InferDict): Inference options for the site.
        obs (torch.Tensor): The observed value.
        log_prob (torch.Tensor): Log probability of the value.
        log_prob_sum (torch.Tensor): Sum of the log probability.
        unscaled_log_prob (torch.Tensor): Log probability before scaling.
        score_parts (pyro.distributions.ScoreParts): Collection of score parts.
        packed (Message): Packed version of the message, used in enumeration.
    """

    type: str
    name: Optional[str]
    fn: Callable[_P, _T]
    is_observed: bool
    args: Tuple
    kwargs: Dict
    value: Optional[_T]
    scale: Union[torch.Tensor, float]
    mask: Union[bool, torch.Tensor, None]
    cond_indep_stack: Tuple["CondIndepStackFrame", ...]
    done: bool
    stop: bool
    continuation: Optional[Callable[["Message"], None]]
    infer: Optional[InferDict]
    obs: Optional[torch.Tensor]
    log_prob: torch.Tensor
    log_prob_sum: torch.Tensor
    unscaled_log_prob: torch.Tensor
    score_parts: "ScoreParts"
    packed: "Message"
    _intervener_id: Optional[str]
class _DimAllocator:
    """
    Internal allocator of batch dimensions for :class:`plate`.

    A single global instance, ``_DIM_ALLOCATOR``, is created below. Dims are
    negative and counted from the right: -1 is the rightmost dim, -2 the one
    to its left, and so on.
    """

    def __init__(self) -> None:
        # Entry ``i`` names the plate owning dim ``-1 - i`` (None means free),
        # so the list runs in the reverse orientation of log_prob.shape.
        self._stack: List[Optional[str]] = []

    def allocate(self, name: str, dim: Optional[int]) -> int:
        """
        Reserve a dimension for the :class:`plate` called ``name``.

        :param name: plate name; it must not already own a dim.
        :param dim: a negative integer requesting a specific dim, or None to
            take the rightmost free dim.
        :returns: the allocated (negative) dim.
        :raises ValueError: for a duplicate name, a non-negative ``dim``, or a
            requested dim that another plate already owns.
        """
        slots = self._stack
        if name in slots:
            raise ValueError(f"duplicate plate '{name}'")

        if dim is None:
            # Walk leftwards to the first slot that is free or past the end.
            dim = -1
            while -dim <= len(slots) and slots[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError(f"Expected dim < 0 to index from the right, actual {dim}")

        position = -1 - dim
        # Pad with free slots until ``position`` is a valid index.
        missing = position + 1 - len(slots)
        if missing > 0:
            slots.extend([None] * missing)

        holder = slots[position]
        if holder is not None:
            collision = 'at plates "{}" and "{}", collide at dim={}'.format(
                name, holder, dim
            )
            hint = "\nTry moving the dim of one plate to the left, e.g. dim={}".format(
                dim - 1
            )
            raise ValueError("\n".join([collision, hint]))

        slots[position] = name
        return dim

    def free(self, name: str, dim: int) -> None:
        """
        Release ``dim``, which must currently be owned by plate ``name``.
        """
        position = -1 - dim
        assert self._stack[position] == name
        self._stack[position] = None
        # Trim free slots at the end of the list, i.e. the leftmost dims.
        while self._stack and self._stack[-1] is None:
            self._stack.pop()


# Handles placement of plate dimensions
_DIM_ALLOCATOR = _DimAllocator()


class _EnumAllocator:
    """
    Internal allocator of enumeration dimensions for
    :func:`~pyro.poutine.markov`. There is one global instance,
    ``_ENUM_ALLOCATOR``.

    Dims are negative and counted from the right (-1, -2, ...); ids are plain
    nonnegative integers.
    """

    def set_first_available_dim(self, first_available_dim: int) -> None:
        """
        Reset the allocator so that enumeration starts at
        ``first_available_dim``. This dim must lie left of every
        :class:`plate` dim, e.g. ``-1 - max_plate_nesting``.

        Call this once per program; in SVI, once per (guide, model) pair.
        """
        assert first_available_dim < 0, first_available_dim
        self.next_available_dim = first_available_dim
        self.next_available_id = 0
        self.dim_to_id: Dict[int, int] = {}  # records global dims only

    def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int, int]:
        """
        Allocate an enumeration dim together with a fresh unique id.

        With ``scope_dims`` None, a global dim is allocated and never recycled.
        Otherwise a local dim is chosen by skipping the (negative) dims in
        ``scope_dims``. That dim may be reused by any other local site whose
        scope does not include this one.

        :param set scope_dims: optional set of local enumeration dims to avoid.
        :return: a pair ``(dim, id)`` with negative ``dim`` and nonnegative
            ``id``.
        :rtype: tuple
        """
        new_id = self.next_available_id
        self.next_available_id = new_id + 1  # consumed even if we raise below

        candidate = self.next_available_dim
        if candidate == -float("inf"):
            raise ValueError(
                "max_plate_nesting must be set to a finite value for parallel enumeration"
            )

        if scope_dims is not None:
            # Local dim: step left past any dim already in scope.
            while candidate in scope_dims:
                candidate -= 1
        else:
            # Global dim: claim it permanently.
            self.next_available_dim = candidate - 1
            self.dim_to_id[candidate] = new_id
        return candidate, new_id


# Handles placement of enumeration dimensions
_ENUM_ALLOCATOR = _EnumAllocator()
class NonlocalExit(Exception):
    """
    Exception for exiting nonlocally from poutine execution.

    Used by poutine.EscapeMessenger to return site information.
    """

    def __init__(self, site: "Message", *args, **kwargs) -> None:
        """
        :param site: message at a pyro site constructor.
            Just stores the input site.
        """
        super().__init__(*args, **kwargs)
        # Keep the site so whoever catches this exception can read it, and so
        # reset_stack can hand it to BlockMessenger.hide_fn.
        self.site = site

    def reset_stack(self) -> None:
        """
        Reset the state of the frames remaining in the stack.
        Necessary for multiple re-executions in poutine.queue.

        Frames are reset from the end of ``_PYRO_STACK`` toward its start. The
        walk stops after the first ``BlockMessenger`` whose ``hide_fn``
        returns a truthy value for this exception's site.
        """
        # Local import, as in the original code (presumably to avoid an import
        # cycle with pyro.poutine.block_messenger).
        from pyro.poutine.block_messenger import BlockMessenger

        for frame in reversed(_PYRO_STACK):
            frame._reset()
            if isinstance(frame, BlockMessenger) and frame.hide_fn(self.site):
                break
def default_process_message(msg: Message) -> None:
    """
    Fallback processing for a message once the effect handlers have run.

    If the message is not yet ``done``, is not observed, and has no value, its
    ``fn`` is called with ``args``/``kwargs`` and the result becomes
    ``value``. In every case the message ends up marked ``done``.

    :param msg: a message to be processed
    :returns: None
    """
    already_resolved = msg["done"] or msg["is_observed"] or msg["value"] is not None
    if not already_resolved:
        msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
    # Marking the message done prevents fn from being called for it again.
    msg["done"] = True
def apply_stack(initial_msg: Message) -> None:
    """
    Run the effect stack for a single site:

    1. Walk ``_PYRO_STACK`` from its last element toward its first, calling
       each frame's ``_process_message``. Stop early once the message's
       ``"stop"`` field is truthy.
    2. Finish the site with :func:`default_process_message`.
    3. Call ``_postprocess_message`` on exactly the frames visited in step 1,
       in the reverse of the order they were visited.
    4. If the message's ``"continuation"`` is not ``None``, call it with the
       message.

    :param dict initial_msg: the starting version of the trace site; it is
        updated in place.
    :returns: ``None``
    """
    stack = _PYRO_STACK
    # TODO check at runtime if stack is valid
    msg = initial_msg

    # ``visited`` counts frames whose _process_message ran; it stays 0 only
    # when the stack is empty.
    visited = 0
    for visited, frame in enumerate(reversed(stack), start=1):
        frame._process_message(msg)
        if msg["stop"]:
            break

    default_process_message(msg)

    # The visited frames are the last ``visited`` entries of the stack. When
    # visited == 0 the stack is empty, so stack[-0:] is empty as well.
    for frame in stack[-visited:]:
        frame._postprocess_message(msg)

    continuation = msg["continuation"]
    if continuation is not None:
        continuation(msg)
def am_i_wrapped() -> bool:
    """
    Report whether any effect handler is currently active, i.e. whether the
    global Pyro stack is non-empty.

    :returns: bool
    """
    return bool(_PYRO_STACK)
@overload
def effectful(
    fn: None = ..., type: Optional[str] = ...
) -> Callable[[Callable[_P, _T]], Callable[..., _T]]: ...


@overload
def effectful(
    fn: Callable[_P, _T] = ..., type: Optional[str] = ...
) -> Callable[..., _T]: ...


def effectful(
    fn: Optional[Callable[_P, _T]] = None, type: Optional[str] = None
) -> Callable:
    """
    Wrap ``fn`` so that, while any effect handler is active, its calls are
    routed through :func:`~pyro.poutine.runtime.apply_stack`.

    Usable directly, ``effectful(fn, type="sample")``, or as a decorator
    factory, ``effectful(type="sample")(fn)``.

    :param fn: function or callable that performs an effectful computation.
        If None, a partial application of ``effectful`` awaiting ``fn`` is
        returned.
    :param str type: the type label of the operation, e.g. `"sample"`. It is
        required (asserted) when ``fn`` is given and may not be ``"message"``.
    :returns: ``fn`` unchanged if it is already effectful; otherwise a wrapper
        that additionally accepts the keyword arguments ``name``, ``infer``
        and ``obs``.
    """
    if fn is None:
        # Decorator-factory form: wait for the function to wrap.
        return functools.partial(effectful, type=type)
    if getattr(fn, "_is_effectful", None):
        # Already wrapped by effectful; return it unchanged.
        return fn
    assert type is not None, f"must provide a type label for operation {fn}"
    assert type != "message", "cannot use 'message' as keyword"

    @functools.wraps(fn)
    def _fn(
        *args: _P.args,
        name: Optional[str] = None,
        infer: Optional[InferDict] = None,
        obs: Optional[_T] = None,
        **kwargs: _P.kwargs,
    ) -> _T:
        # With no active handler, call fn directly; ``name``, ``infer`` and
        # ``obs`` are not forwarded in that case.
        if not am_i_wrapped():
            return fn(*args, **kwargs)

        msg = Message(
            type=type,
            name=name,
            fn=fn,
            is_observed=obs is not None,
            args=args,
            kwargs=kwargs,
            value=obs,
            scale=1.0,
            mask=None,
            cond_indep_stack=(),
            done=False,
            stop=False,
            continuation=None,
            infer={} if infer is None else infer,
        )
        # apply_stack runs the handlers; default_process_message sets
        # msg["value"] if nothing has resolved it yet.
        apply_stack(msg)
        if TYPE_CHECKING:
            assert msg["value"] is not None
        return msg["value"]

    _fn._is_effectful = True  # type: ignore[attr-defined]
    return _fn
def _inspect() -> Message:
    """
    EXPERIMENTAL Inspect the Pyro stack.

    Sends a probe message of type ``"inspect"`` (flagged ``_do_not_trace``)
    through :func:`apply_stack` and returns it with every active effect
    applied.

    .. warning:: The format of the returned message may change at any time
        and does not guarantee backwards compatibility.

    :returns: A message with all effects applied.
    :rtype: dict
    """
    probe = Message(
        # Identification of the probe site.
        type="inspect",
        name="_pyro_inspect",
        # Trivial payload: nothing is observed and no value is preset.
        fn=lambda: True,
        is_observed=False,
        args=(),
        kwargs={},
        value=None,
        infer={"_do_not_trace": True},
        # Neutral defaults for handlers to modify.
        scale=1.0,
        mask=None,
        cond_indep_stack=(),
        done=False,
        stop=False,
        continuation=None,
    )
    apply_stack(probe)
    return probe
def get_mask() -> Union[bool, torch.Tensor, None]:
    """
    Records the effects of enclosing ``poutine.mask`` handlers.

    This helps skip expensive ``pyro.factor()`` computations during prediction,
    when the log density does not need to be computed, e.g.::

        def model():
            # ...
            if poutine.get_mask() is not False:
                log_density = my_expensive_computation()
                pyro.factor("foo", log_density)
            # ...

    :returns: The mask.
    :rtype: None, bool, or torch.Tensor
    """
    inspected = _inspect()
    return inspected["mask"]
def get_plates() -> Tuple["CondIndepStackFrame", ...]:
    """
    Records the effects of enclosing ``pyro.plate`` contexts.

    :returns: A tuple of
        :class:`pyro.poutine.indep_messenger.CondIndepStackFrame` objects.
    :rtype: tuple
    """
    inspected = _inspect()
    return inspected["cond_indep_stack"]