Source code for pyro.poutine.runtime

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools

from pyro.params.param_store import (  # noqa: F401
    _MODULE_NAMESPACE_DIVIDER,
    ParamStoreDict,
)

# the global pyro stack
_PYRO_STACK = []

# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()


class _DimAllocator:
    """
    Dimension allocator for internal use by :class:`plate`.
    There is a single global instance.

    Note that dimensions are indexed from the right, e.g. -1, -2.
    """
    def __init__(self):
        self._stack = []  # in reverse orientation of log_prob.shape

    def allocate(self, name, dim):
        """
        Allocate a dimension to an :class:`plate` with given name.
        Dim should be either None for automatic allocation or a negative
        integer for manual allocation.
        """
        if name in self._stack:
            raise ValueError('duplicate plate "{}"'.format(name))
        if dim is None:
            # Automatically designate the rightmost available dim for allocation.
            dim = -1
            while -dim <= len(self._stack) and self._stack[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError('Expected dim < 0 to index from the right, actual {}'.format(dim))

        # Allocate the requested dimension.
        while dim < -len(self._stack):
            self._stack.append(None)
        if self._stack[-1 - dim] is not None:
            raise ValueError('\n'.join([
                'at plates "{}" and "{}", collide at dim={}'.format(name, self._stack[-1 - dim], dim),
                '\nTry moving the dim of one plate to the left, e.g. dim={}'.format(dim - 1)]))
        self._stack[-1 - dim] = name
        return dim

    def free(self, name, dim):
        """
        Free a dimension.
        """
        free_idx = -1 - dim  # stack index to free
        assert self._stack[free_idx] == name
        self._stack[free_idx] = None
        while self._stack and self._stack[-1] is None:
            self._stack.pop()


# Handles placement of plate dimensions
_DIM_ALLOCATOR = _DimAllocator()


class _EnumAllocator:
    """
    Dimension allocator for internal use by :func:`~pyro.poutine.markov`.
    There is a single global instance.

    Note that dimensions are indexed from the right, e.g. -1, -2.
    Note that ids are simply nonnegative integers here.
    """
    def set_first_available_dim(self, first_available_dim):
        """
        Set the first available dim, which should be to the left of all
        :class:`plate` dimensions, e.g. ``-1 - max_plate_nesting``. This should
        be called once per program. In SVI this should be called only once per
        (guide,model) pair.
        """
        assert first_available_dim < 0, first_available_dim
        self.next_available_dim = first_available_dim
        self.next_available_id = 0
        self.dim_to_id = {}  # only the global ids

    def allocate(self, scope_dims=None):
        """
        Allocate a new recyclable dim and a unique id.

        If ``scope_dims`` is None, this allocates a global enumeration dim
        that will never be recycled. If ``scope_dims`` is specified, this
        allocates a local enumeration dim that can be reused by at any other
        local site whose scope excludes this site.

        :param set scope_dims: An optional set of (negative integer)
            local enumeration dims to avoid when allocating this dim.
        :return: A pair ``(dim, id)``, where ``dim`` is a negative integer
            and ``id`` is a nonnegative integer.
        :rtype: tuple
        """
        id_ = self.next_available_id
        self.next_available_id += 1

        dim = self.next_available_dim
        if dim == -float('inf'):
            raise ValueError("max_plate_nesting must be set to a finite value for parallel enumeration")
        if scope_dims is None:
            # allocate a new global dimension
            self.next_available_dim -= 1
            self.dim_to_id[dim] = id_
        else:
            # allocate a new local dimension
            while dim in scope_dims:
                dim -= 1

        return dim, id_


# Handles placement of enumeration dimensions
_ENUM_ALLOCATOR = _EnumAllocator()


[docs]class NonlocalExit(Exception): """ Exception for exiting nonlocally from poutine execution. Used by poutine.EscapeMessenger to return site information. """ def __init__(self, site, *args, **kwargs): """ :param site: message at a pyro site constructor. Just stores the input site. """ super().__init__(*args, **kwargs) self.site = site
[docs] def reset_stack(self): """ Reset the state of the frames remaining in the stack. Necessary for multiple re-executions in poutine.queue. """ for frame in reversed(_PYRO_STACK): frame._reset() if type(frame).__name__ == "BlockMessenger" and frame.hide_fn(self.site): break
[docs]def default_process_message(msg): """ Default method for processing messages in inference. :param msg: a message to be processed :returns: None """ if msg["done"] or msg["is_observed"] or msg["value"] is not None: msg["done"] = True return msg msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"]) # after fn has been called, update msg to prevent it from being called again. msg["done"] = True
[docs]def apply_stack(initial_msg): """ Execute the effect stack at a single site according to the following scheme: 1. For each ``Messenger`` in the stack from bottom to top, execute ``Messenger._process_message`` with the message; if the message field "stop" is True, stop; otherwise, continue 2. Apply default behavior (``default_process_message``) to finish remaining site execution 3. For each ``Messenger`` in the stack from top to bottom, execute ``_postprocess_message`` to update the message and internal messenger state with the site results 4. If the message field "continuation" is not ``None``, call it with the message :param dict initial_msg: the starting version of the trace site :returns: ``None`` """ stack = _PYRO_STACK # TODO check at runtime if stack is valid # msg is used to pass information up and down the stack msg = initial_msg pointer = 0 # go until time to stop? for frame in reversed(stack): pointer = pointer + 1 frame._process_message(msg) if msg["stop"]: break default_process_message(msg) for frame in stack[-pointer:]: frame._postprocess_message(msg) cont = msg["continuation"] if cont is not None: cont(msg) return None
[docs]def am_i_wrapped(): """ Checks whether the current computation is wrapped in a poutine. :returns: bool """ return len(_PYRO_STACK) > 0
[docs]def effectful(fn=None, type=None): """ :param fn: function or callable that performs an effectful computation :param str type: the type label of the operation, e.g. `"sample"` Wrapper for calling :func:`~pyro.poutine.runtime.apply_stack` to apply any active effects. """ if fn is None: return functools.partial(effectful, type=type) if getattr(fn, "_is_effectful", None): return fn assert type is not None, "must provide a type label for operation {}".format(fn) assert type != "message", "cannot use 'message' as keyword" @functools.wraps(fn) def _fn(*args, **kwargs): name = kwargs.pop("name", None) infer = kwargs.pop("infer", {}) value = kwargs.pop("obs", None) is_observed = value is not None if not am_i_wrapped(): return fn(*args, **kwargs) else: msg = { "type": type, "name": name, "fn": fn, "is_observed": is_observed, "args": args, "kwargs": kwargs, "value": value, "scale": 1.0, "mask": None, "cond_indep_stack": (), "done": False, "stop": False, "continuation": None, "infer": infer, } # apply the stack and return its return value apply_stack(msg) return msg["value"] _fn._is_effectful = True return _fn
def _inspect(): """ EXPERIMENTAL Inspect the Pyro stack. .. warning:: The format of the returned message may change at any time and does not guarantee backwards compatibility. :returns: A message with all effects applied. :rtype: dict """ msg = { "type": "inspect", "name": "_pyro_inspect", "fn": lambda: True, "is_observed": False, "args": (), "kwargs": {}, "value": None, "infer": {"_do_not_trace": True}, "scale": 1.0, "mask": None, "cond_indep_stack": (), "done": False, "stop": False, "continuation": None, } apply_stack(msg) return msg
[docs]def get_mask(): """ Records the effects of enclosing ``poutine.mask`` handlers. This is useful for avoiding expensive ``pyro.factor()`` computations during prediction, when the log density need not be computed, e.g.:: def model(): # ... if poutine.get_mask() is not False: log_density = my_expensive_computation() pyro.factor("foo", log_density) # ... :returns: The mask. :rtype: None, bool, or torch.Tensor """ return _inspect()["mask"]