Source code for pyro.poutine.markov_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from collections import Counter
from contextlib import ExitStack
from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Set

from typing_extensions import Self

from pyro.poutine.reentrant_messenger import ReentrantMessenger

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class MarkovMessenger(ReentrantMessenger):
    """
    Declare Markov dependence between sample sites in nested contexts.

    Conceptually this behaves like a memory-management arena for sample
    sites. Each context entry opens a new frame. A site sampled in the
    current frame sees the sites recorded in that frame and in up to
    ``history`` earlier frames.

    :param int history: How many earlier contexts are visible from the
        current one. Defaults to 1. A value of zero makes this behave much
        like :class:`pyro.plate`.
    :param bool keep: Whether frames survive their context's exit so they
        can be replayed. This matters when branching:
        with ``keep=True``, sibling branches at the same depth may depend on
        one another; with ``keep=False``, sibling branches are independent
        given their shared ancestors.
    :param int dim: Optional dimension for this independence index.
        Interface stub only; the behavior is not implemented yet.
    :param str name: Optional unique name that lets inference algorithms
        pair :func:`pyro.markov` sites between models and guides.
        Interface stub only; the behavior is not implemented yet.
    """

    def __init__(
        self,
        history: int = 1,
        keep: bool = False,
        dim: Optional[int] = None,
        name: Optional[str] = None,
    ) -> None:
        assert history >= 0
        self.history, self.keep = history, keep
        self.dim, self.name = dim, name

        # The vectorized options are declared on the interface but rejected.
        if dim is not None:
            raise NotImplementedError(
                "vectorized markov not yet implemented, try setting dim to None"
            )
        if name is not None:
            raise NotImplementedError(
                "vectorized markov not yet implemented, try setting name to None"
            )

        # Iterable driven by __iter__; installed via generator().
        self._iterable: Optional[Iterable[int]] = None
        # Index of the active frame in _stack; -1 until the first __enter__.
        self._pos = -1
        # One set of recorded sample-site names per context depth.
        self._stack: List[Set[str]] = []
        super().__init__()

    def generator(self, iterable: Iterable[int]) -> Self:
        """
        Attach ``iterable`` so that iterating over this messenger yields its
        values, each inside one additional nested Markov context.

        :param iterable: The values to iterate over.
        :returns: This messenger, for use in a ``for`` loop.
        """
        self._iterable = iterable
        return self

    def __iter__(self) -> Iterator[int]:
        # Every yielded value enters one more nested context. The ExitStack
        # unwinds all of them when iteration finishes or the generator is
        # closed.
        with ExitStack() as open_frames:
            assert self._iterable is not None
            for step in self._iterable:
                open_frames.enter_context(self)
                yield step

    def __enter__(self) -> Self:
        self._pos += 1
        # Allocate a frame the first time this depth is reached. When
        # keep=True, an existing frame at this depth is reused (replayed).
        if self._pos >= len(self._stack):
            self._stack.append(set())
        return super().__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        # Unless frames are kept for replay, discard the frame being left.
        if not self.keep:
            self._stack.pop()
        self._pos -= 1
        return super().__exit__(*args, **kwargs)

    def _pyro_sample(self, msg: "Message") -> None:
        # Skip messages already handled, and subsample sites (matched by
        # type name).
        if msg["done"]:
            return
        if type(msg["fn"]).__name__ == "_Subsample":
            return

        assert msg["name"] is not None
        assert msg["infer"] is not None
        site_name = msg["name"]
        infer = msg["infer"]

        # A Counter (not a set) lets a site go out of scope correctly when any
        # one of its markov contexts exits. Consumers of these fields (e.g.
        # EnumMessenger) do that accounting. For each site name, it counts
        # how many visible frames recorded that site.
        scope = infer.setdefault("_markov_scope", Counter())
        oldest_visible = max(0, self._pos - self.history)
        for frame in self._stack[oldest_visible : self._pos + 1]:
            scope.update(frame)

        # One increment per MarkovMessenger that processes this site.
        previous_depth = infer.get("_markov_depth", 0)
        infer["_markov_depth"] = 1 + previous_depth

        # Record the site in the active frame so later sites can see it.
        self._stack[self._pos].add(site_name)