# Source code for pyro.poutine.handlers

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
Poutine is a library of composable effect handlers for recording and modifying the
behavior of Pyro programs. These lower-level ingredients simplify the implementation
of new inference algorithms and behavior.

Handlers can be used as higher-order functions, decorators, or context managers
to modify the behavior of functions or blocks of code.

For example, consider the following Pyro program:

    >>> def model(x):
    ...     s = pyro.param("s", torch.tensor(0.5))
    ...     z = pyro.sample("z", dist.Normal(x, s))
    ...     return z ** 2

We can mark sample sites as observed using ``condition``,
which returns a callable with the same input and output signatures as ``model``:

    >>> conditioned_model = poutine.condition(model, data={"z": 1.0})

We can also use handlers as decorators:

    >>> @pyro.condition(data={"z": 1.0})
    ... def model(x):
    ...     s = pyro.param("s", torch.tensor(0.5))
    ...     z = pyro.sample("z", dist.Normal(x, s))
    ...     return z ** 2

Or as context managers:

    >>> with pyro.condition(data={"z": 1.0}):
    ...     s = pyro.param("s", torch.tensor(0.5))
    ...     z = pyro.sample("z", dist.Normal(0., s))
    ...     y = z ** 2

Handlers compose freely:

    >>> conditioned_model = poutine.condition(model, data={"z": 1.0})
    >>> traced_model = poutine.trace(conditioned_model)

Many inference algorithms or algorithmic components can be implemented
in just a few lines of code::

    guide_tr = poutine.trace(guide).get_trace(...)
    model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
    monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
"""

import collections
import functools
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    TypeVar,
    Union,
    overload,
)

from typing_extensions import ParamSpec

from pyro.poutine import util
from pyro.poutine.block_messenger import BlockMessenger
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.collapse_messenger import CollapseMessenger
from pyro.poutine.condition_messenger import ConditionMessenger
from pyro.poutine.do_messenger import DoMessenger
from pyro.poutine.enum_messenger import EnumMessenger
from pyro.poutine.escape_messenger import EscapeMessenger
from pyro.poutine.infer_config_messenger import InferConfigMessenger
from pyro.poutine.lift_messenger import LiftMessenger
from pyro.poutine.markov_messenger import MarkovMessenger
from pyro.poutine.mask_messenger import MaskMessenger
from pyro.poutine.reparam_messenger import ReparamMessenger
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.runtime import NonlocalExit
from pyro.poutine.scale_messenger import ScaleMessenger
from pyro.poutine.seed_messenger import SeedMessenger
from pyro.poutine.substitute_messenger import SubstituteMessenger
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.uncondition_messenger import UnconditionMessenger

if TYPE_CHECKING:
    import numbers

    import torch

    from pyro.distributions.distribution import Distribution
    from pyro.infer.reparam.reparam import Reparam
    from pyro.poutine.runtime import InferDict, Message
    from pyro.poutine.trace_struct import Trace

# Generic placeholders shared by the handler signatures in this module:
# ``_P`` stands for the parameters of the wrapped callable and ``_T`` for its
# return type, so ``Callable[_P, _T]`` in / ``Callable[_P, _T]`` out expresses
# that a handler preserves the signature of the function it wraps.
_P = ParamSpec("_P")
_T = TypeVar("_T")

############################################
# Begin primitive operations
############################################


def _make_handler(msngr_cls, module=None):
    def handler_decorator(func):
        @functools.wraps(func)
        def handler(fn=None, *args, **kwargs):
            if fn is not None and not (
                callable(fn) or isinstance(fn, collections.abc.Iterable)
            ):
                raise ValueError(
                    f"{fn} is not callable, did you mean to pass it as a keyword arg?"
                )
            msngr = msngr_cls(*args, **kwargs)
            return (
                functools.update_wrapper(msngr(fn), fn, updated=())
                if fn is not None
                else msngr
            )

        handler.__doc__ = (
            """Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format(
                func.__name__ + "_messenger", msngr_cls.__name__
            )
            + (msngr_cls.__doc__ if msngr_cls.__doc__ else "")
        )
        if module is not None:
            handler.__module__ = module
        return handler

    return handler_decorator


# Overload 1 of 2 for ``block``: when no ``fn`` is supplied (``fn=None``) the
# call is typed as returning the configured ``BlockMessenger`` itself.
@overload
def block(
    fn: None = ...,
    hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    hide_all: bool = True,
    expose_all: bool = False,
    hide: Optional[List[str]] = None,
    expose: Optional[List[str]] = None,
    hide_types: Optional[List[str]] = None,
    expose_types: Optional[List[str]] = None,
) -> BlockMessenger: ...


# Overload 2 of 2 for ``block``: when a callable ``fn`` is supplied the call is
# typed as returning a callable with the same parameters and return type.
@overload
def block(
    fn: Callable[_P, _T] = ...,
    hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    hide_all: bool = True,
    expose_all: bool = False,
    hide: Optional[List[str]] = None,
    expose: Optional[List[str]] = None,
    hide_types: Optional[List[str]] = None,
    expose_types: Optional[List[str]] = None,
) -> Callable[_P, _T]: ...


# Typed signature for ``block``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``BlockMessenger``.
@_make_handler(BlockMessenger)
def block(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
    hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
    hide_all: bool = True,
    expose_all: bool = False,
    hide: Optional[List[str]] = None,
    expose: Optional[List[str]] = None,
    hide_types: Optional[List[str]] = None,
    expose_types: Optional[List[str]] = None,
) -> Union[BlockMessenger, Callable[_P, _T]]: ...


# Overloads for ``broadcast``: without ``fn`` the call is typed as returning a
# ``BroadcastMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
@overload
def broadcast(
    fn: None = ...,
) -> BroadcastMessenger: ...


@overload
def broadcast(
    fn: Callable[_P, _T] = ...,
) -> Callable[_P, _T]: ...
# Typed signature for ``broadcast``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``BroadcastMessenger``.
@_make_handler(BroadcastMessenger)
def broadcast(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
) -> Union[BroadcastMessenger, Callable[_P, _T]]: ...


# Overloads for ``collapse``: without ``fn`` the call is typed as returning a
# ``CollapseMessenger``; with a callable ``fn`` it returns a callable with the
# same signature. Extra arguments are forwarded untyped.
@overload
def collapse(
    fn: None = ...,
    *args: Any,
    **kwargs: Any,
) -> CollapseMessenger: ...


@overload
def collapse(
    fn: Callable[_P, _T] = ...,
    *args: Any,
    **kwargs: Any,
) -> Callable[_P, _T]: ...
# Typed signature for ``collapse``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``CollapseMessenger``.
@_make_handler(CollapseMessenger)
def collapse(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[CollapseMessenger, Callable[_P, _T]]: ...


# Overloads for ``condition``: without ``fn`` the call is typed as returning a
# ``ConditionMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``data`` by keyword when ``fn`` is omitted.
@overload
def condition(
    data: Union[Dict[str, "torch.Tensor"], "Trace"],
) -> ConditionMessenger: ...


@overload
def condition(
    fn: Callable[_P, _T],
    data: Union[Dict[str, "torch.Tensor"], "Trace"],
) -> Callable[_P, _T]: ...
# Typed signature for ``condition``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``ConditionMessenger``.
@_make_handler(ConditionMessenger)
def condition(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    data: Union[Dict[str, "torch.Tensor"], "Trace"],
) -> Union[ConditionMessenger, Callable[_P, _T]]: ...


# Overloads for ``do``: without ``fn`` the call is typed as returning a
# ``DoMessenger``; with a callable ``fn`` it returns a callable with the same
# signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``data`` by keyword when ``fn`` is omitted.
@overload
def do(
    data: Dict[str, Union["torch.Tensor", "numbers.Number"]],
) -> DoMessenger: ...


@overload
def do(
    fn: Callable[_P, _T],
    data: Dict[str, Union["torch.Tensor", "numbers.Number"]],
) -> Callable[_P, _T]: ...
# Typed signature for ``do``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``DoMessenger``.
@_make_handler(DoMessenger)
def do(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    data: Dict[str, Union["torch.Tensor", "numbers.Number"]],
) -> Union[DoMessenger, Callable[_P, _T]]: ...


# Overloads for ``enum``: without ``fn`` the call is typed as returning an
# ``EnumMessenger``; with a callable ``fn`` it returns a callable with the same
# signature.
@overload
def enum(
    fn: None = ...,
    first_available_dim: Optional[int] = None,
) -> EnumMessenger: ...


@overload
def enum(
    fn: Callable[_P, _T] = ...,
    first_available_dim: Optional[int] = None,
) -> Callable[_P, _T]: ...
# Typed signature for ``enum``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``EnumMessenger``.
@_make_handler(EnumMessenger)
def enum(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
    first_available_dim: Optional[int] = None,
) -> Union[EnumMessenger, Callable[_P, _T]]: ...


# Overloads for ``escape``: without ``fn`` the call is typed as returning an
# ``EscapeMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``escape_fn`` by keyword when ``fn`` is omitted.
@overload
def escape(
    escape_fn: Callable[["Message"], bool],
) -> EscapeMessenger: ...


@overload
def escape(
    fn: Callable[_P, _T],
    escape_fn: Callable[["Message"], bool],
) -> Callable[_P, _T]: ...
# Typed signature for ``escape``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``EscapeMessenger``.
@_make_handler(EscapeMessenger)
def escape(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    escape_fn: Callable[["Message"], bool],
) -> Union[EscapeMessenger, Callable[_P, _T]]: ...


# Overloads for ``infer_config``: without ``fn`` the call is typed as returning
# an ``InferConfigMessenger``; with a callable ``fn`` it returns a callable with
# the same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``config_fn`` by keyword when ``fn`` is omitted.
@overload
def infer_config(
    config_fn: Callable[["Message"], "InferDict"],
) -> InferConfigMessenger: ...


@overload
def infer_config(
    fn: Callable[_P, _T],
    config_fn: Callable[["Message"], "InferDict"],
) -> Callable[_P, _T]: ...
# Typed signature for ``infer_config``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``InferConfigMessenger``.
@_make_handler(InferConfigMessenger)
def infer_config(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    config_fn: Callable[["Message"], "InferDict"],
) -> Union[InferConfigMessenger, Callable[_P, _T]]: ...


# Overloads for ``lift``: without ``fn`` the call is typed as returning a
# ``LiftMessenger``; with a callable ``fn`` it returns a callable with the same
# signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``prior`` by keyword when ``fn`` is omitted.
@overload
def lift(
    prior: Union[Callable, "Distribution", Dict[str, Union["Distribution", Callable]]],
) -> LiftMessenger: ...


@overload
def lift(
    fn: Callable[_P, _T],
    prior: Union[Callable, "Distribution", Dict[str, Union["Distribution", Callable]]],
) -> Callable[_P, _T]: ...
# Typed signature for ``lift``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``LiftMessenger``.
@_make_handler(LiftMessenger)
def lift(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    prior: Union[Callable, "Distribution", Dict[str, Union["Distribution", Callable]]],
) -> Union[LiftMessenger, Callable[_P, _T]]: ...


# Overloads for ``mask``: without ``fn`` the call is typed as returning a
# ``MaskMessenger``; with a callable ``fn`` it returns a callable with the same
# signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``mask`` by keyword when ``fn`` is omitted.
@overload
def mask(
    mask: Union[bool, "torch.BoolTensor"],
) -> MaskMessenger: ...


@overload
def mask(
    fn: Callable[_P, _T],
    mask: Union[bool, "torch.BoolTensor"],
) -> Callable[_P, _T]: ...
# Typed signature for ``mask``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``MaskMessenger``.
@_make_handler(MaskMessenger)
def mask(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    mask: Union[bool, "torch.BoolTensor"],
) -> Union[MaskMessenger, Callable[_P, _T]]: ...


# Overloads for ``reparam``: without ``fn`` the call is typed as returning a
# ``ReparamMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``config`` by keyword when ``fn`` is omitted.
@overload
def reparam(
    config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
) -> ReparamMessenger: ...


@overload
def reparam(
    fn: Callable[_P, _T],
    config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
) -> Callable[_P, _T]: ...
# Typed signature for ``reparam``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``ReparamMessenger``.
@_make_handler(ReparamMessenger)
def reparam(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
) -> Union[ReparamMessenger, Callable[_P, _T]]: ...


# Overloads for ``replay``: without ``fn`` the call is typed as returning a
# ``ReplayMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
@overload
def replay(
    fn: None = ...,
    trace: Optional["Trace"] = None,
    params: Optional[Dict[str, "torch.Tensor"]] = None,
) -> ReplayMessenger: ...


@overload
def replay(
    fn: Callable[_P, _T] = ...,
    trace: Optional["Trace"] = None,
    params: Optional[Dict[str, "torch.Tensor"]] = None,
) -> Callable[_P, _T]: ...
# Typed signature for ``replay``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``ReplayMessenger``.
@_make_handler(ReplayMessenger)
def replay(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
    trace: Optional["Trace"] = None,
    params: Optional[Dict[str, "torch.Tensor"]] = None,
) -> Union[ReplayMessenger, Callable[_P, _T]]: ...


# Overloads for ``scale``: without ``fn`` the call is typed as returning a
# ``ScaleMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``scale`` by keyword when ``fn`` is omitted.
@overload
def scale(
    scale: Union[float, "torch.Tensor"],
) -> ScaleMessenger: ...


@overload
def scale(
    fn: Callable[_P, _T],
    scale: Union[float, "torch.Tensor"],
) -> Callable[_P, _T]: ...
# Typed signature for ``scale``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``ScaleMessenger``.
@_make_handler(ScaleMessenger)
def scale(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    scale: Union[float, "torch.Tensor"],
) -> Union[ScaleMessenger, Callable[_P, _T]]: ...


# Overloads for ``seed``: without ``fn`` the call is typed as returning a
# ``SeedMessenger``; with a callable ``fn`` it returns a callable with the same
# signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``rng_seed`` by keyword when ``fn`` is omitted.
@overload
def seed(
    rng_seed: int,
) -> SeedMessenger: ...


@overload
def seed(
    fn: Callable[_P, _T],
    rng_seed: int,
) -> Callable[_P, _T]: ...
# Typed signature for ``seed``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``SeedMessenger``.
@_make_handler(SeedMessenger)
def seed(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    rng_seed: int,
) -> Union[SeedMessenger, Callable[_P, _T]]: ...


# Overloads for ``substitute``: without ``fn`` the call is typed as returning a
# ``SubstituteMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
# NOTE: ``_make_handler`` binds the first positional argument to ``fn``, so
# pass ``data`` by keyword when ``fn`` is omitted.
@overload
def substitute(
    data: Dict[str, "torch.Tensor"],
) -> SubstituteMessenger: ...


@overload
def substitute(
    fn: Callable[_P, _T],
    data: Dict[str, "torch.Tensor"],
) -> Callable[_P, _T]: ...
# Typed signature for ``substitute``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``SubstituteMessenger``.
@_make_handler(SubstituteMessenger)
def substitute(  # type: ignore[empty-body]
    fn: Callable[_P, _T],
    data: Dict[str, "torch.Tensor"],
) -> Union[SubstituteMessenger, Callable[_P, _T]]: ...


# Overloads for ``trace``: without ``fn`` the call is typed as returning a
# ``TraceMessenger``; with a callable ``fn`` it returns a callable with the
# same signature.
@overload
def trace(
    fn: None = ...,
    graph_type: Optional[Literal["flat", "dense"]] = None,
    param_only: Optional[bool] = None,
) -> TraceMessenger: ...


@overload
def trace(
    fn: Callable[_P, _T] = ...,
    graph_type: Optional[Literal["flat", "dense"]] = None,
    param_only: Optional[bool] = None,
) -> Callable[_P, _T]: ...
# Typed signature for ``trace``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``TraceMessenger``.
@_make_handler(TraceMessenger)
def trace(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
    graph_type: Optional[Literal["flat", "dense"]] = None,
    param_only: Optional[bool] = None,
) -> Union[TraceMessenger, Callable[_P, _T]]: ...


# Overloads for ``uncondition``: without ``fn`` the call is typed as returning
# an ``UnconditionMessenger``; with a callable ``fn`` it returns a callable with
# the same signature.
@overload
def uncondition(
    fn: None = ...,
) -> UnconditionMessenger: ...


@overload
def uncondition(
    fn: Callable[_P, _T] = ...,
) -> Callable[_P, _T]: ...
# Typed signature for ``uncondition``. The empty body is never executed:
# ``_make_handler`` replaces it with a wrapper around ``UnconditionMessenger``.
@_make_handler(UnconditionMessenger)
def uncondition(  # type: ignore[empty-body]
    fn: Optional[Callable[_P, _T]] = None,
) -> Union[UnconditionMessenger, Callable[_P, _T]]: ...


#########################################
# Begin composite operations
#########################################
def queue(
    fn=None,
    queue=None,
    max_tries=None,
    extend_fn=None,
    escape_fn=None,
    num_samples=None,
):
    """
    Used in sequential enumeration over discrete variables.

    Given a stochastic function and a queue,
    return a return value from a complete trace in the queue.

    :param fn: a stochastic function (callable containing Pyro primitive calls).
        If omitted, a decorator is returned instead of a wrapped function.
    :param queue: a queue data structure like multiprocessing.Queue to hold
        partial traces
    :param max_tries: maximum number of attempts to compute a single complete
        trace. Defaults to ``int(1e6)``.
    :param extend_fn: function (possibly stochastic) that takes a partial trace
        and a site, and returns a list of extended traces. Defaults to
        ``util.enum_extend``.
    :param escape_fn: function (possibly stochastic) that takes a partial trace
        and a site, and returns a boolean value to decide whether to exit.
        Defaults to ``util.discrete_escape``.
    :param num_samples: optional number of extended traces for extend_fn to
        return. Defaults to ``-1``.
    :returns: stochastic function decorated with poutine logic
    :raises AssertionError: (when the returned function is called) if the queue
        is empty before a trace is fetched.
    :raises ValueError: (when the returned function is called) if no complete
        trace is produced within ``max_tries`` attempts.
    """
    if max_tries is None:
        max_tries = int(1e6)
    if extend_fn is None:
        extend_fn = util.enum_extend
    if escape_fn is None:
        escape_fn = util.discrete_escape
    if num_samples is None:
        num_samples = -1

    def wrapper(wrapped):
        def _fn(*args, **kwargs):
            for _ in range(max_tries):
                # Explicit raise instead of ``assert`` so the deadlock guard is
                # not stripped under ``python -O``; AssertionError is kept so
                # the exception type seen by callers is unchanged.
                if queue.empty():
                    raise AssertionError(
                        "trying to get() from an empty queue will deadlock"
                    )

                next_trace = queue.get()
                try:
                    # Replay the partial trace and escape at the first site
                    # accepted by ``escape_fn(next_trace, site)``.
                    ftr = trace(
                        escape(
                            replay(wrapped, trace=next_trace),  # noqa: F821
                            escape_fn=functools.partial(escape_fn, next_trace),
                        )
                    )
                    return ftr(*args, **kwargs)
                except NonlocalExit as site_container:
                    # Execution escaped at a site: extend the partial trace at
                    # that site and enqueue every extension for a later attempt.
                    site_container.reset_stack()
                    for tr in extend_fn(
                        ftr.trace.copy(), site_container.site, num_samples=num_samples
                    ):
                        queue.put(tr)

            raise ValueError(f"max tries ({max_tries}) exceeded")

        return _fn

    return wrapper(fn) if fn is not None else wrapper
# Overloads for ``markov``: with no ``fn`` or with an iterable ``fn`` the call
# is typed as returning a ``MarkovMessenger``; with a callable ``fn`` it returns
# a callable with the same signature.
@overload
def markov(
    fn: None = ...,
    history: int = 1,
    keep: bool = False,
    dim: Optional[int] = None,
    name: Optional[str] = None,
) -> MarkovMessenger: ...


@overload
def markov(
    fn: Iterable[int] = ...,
    history: int = 1,
    keep: bool = False,
    dim: Optional[int] = None,
    name: Optional[str] = None,
) -> MarkovMessenger: ...


@overload
def markov(
    fn: Callable[_P, _T] = ...,
    history: int = 1,
    keep: bool = False,
    dim: Optional[int] = None,
    name: Optional[str] = None,
) -> Callable[_P, _T]: ...
def markov(
    fn: Optional[Union[Iterable[int], Callable[_P, _T]]] = None,
    history: int = 1,
    keep: bool = False,
    dim: Optional[int] = None,
    name: Optional[str] = None,
) -> Union[MarkovMessenger, Callable[_P, _T]]:
    """
    Markov dependency declaration.

    This can be used in a variety of ways:

        - as a context manager
        - as a decorator for recursive functions
        - as an iterator for markov chains

    :param fn: Optional target. If omitted, the configured
        :class:`~pyro.poutine.markov_messenger.MarkovMessenger` is returned.
        If callable, it is wrapped by the messenger. Any other value is
        treated as an iterable and passed to the messenger's ``generator``.
    :param int history: The number of previous contexts visible from the
        current context. Defaults to 1. If zero, this is similar to
        :class:`pyro.plate`.
    :param bool keep: If true, frames are replayable. This is important
        when branching: if ``keep=True``, neighboring branches at the same
        level can depend on each other; if ``keep=False``, neighboring branches
        are independent (conditioned on their shared ancestors).
    :param int dim: An optional dimension to use for this independence index.
        Interface stub, behavior not yet implemented.
    :param str name: An optional unique name to help inference algorithms match
        :func:`pyro.markov` sites between models and guides.
        Interface stub, behavior not yet implemented.
    :returns: the messenger itself, the result of ``messenger.generator(...)``
        for an iterable ``fn``, or the wrapped callable for a callable ``fn``.
    """
    # Every usage mode starts from the same configured messenger.
    messenger = MarkovMessenger(history=history, keep=keep, dim=dim, name=name)

    if fn is None:
        # Used as a context manager, or as a decorator with bound args
        return messenger
    if not callable(fn):
        # Used as a generator
        return messenger.generator(iterable=fn)
    # Used as a decorator without bound args
    return messenger(fn)