Source code for pyro.poutine.plate_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional

from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.messenger import Messenger, block_messengers
from pyro.poutine.subsample_messenger import SubsampleMessenger

if TYPE_CHECKING:
    import torch

    from pyro.poutine.runtime import Message


class PlateMessenger(SubsampleMessenger):
    """
    Swiss army knife of broadcasting amazingness:
    combines shape inference, independence annotation, and subsampling.

    Subsampling behaviour is inherited from :class:`SubsampleMessenger`;
    each processed message is additionally passed through
    ``BroadcastMessenger._pyro_sample``.
    """

    def _process_message(self, msg: "Message") -> None:
        """
        Process ``msg`` with the parent implementation, then hand it to
        ``BroadcastMessenger._pyro_sample``.

        :param msg: The message to process.
        """
        # Order matters: the inherited processing runs before broadcasting.
        super()._process_message(msg)
        BroadcastMessenger._pyro_sample(msg)

    def __enter__(self) -> Optional["torch.Tensor"]:  # type: ignore[override]
        """
        Enter the plate context via the parent class.

        :returns: ``self.indices`` when the plate is vectorized and
            ``self._indices`` is set; otherwise ``None``.
        """
        super().__enter__()
        # Guard clause: nothing to hand back unless vectorized with indices.
        if not self._vectorized or self._indices is None:
            return None
        return self.indices
@contextmanager
def block_plate(
    name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True
) -> Iterator[None]:
    """
    EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

    This is useful for sampling auxiliary variables or lazily sampling global
    variables that are needed in a plated context. For example the following
    models are equivalent:

    Example::

        def model_1(data):
            loc = pyro.sample("loc", dist.Normal(0, 1))
            with pyro.plate("data", len(data)):
                with block_plate("data"):
                    scale = pyro.sample("scale", dist.LogNormal(0, 1))
                pyro.sample("x", dist.Normal(loc, scale))

        def model_2(data):
            loc = pyro.sample("loc", dist.Normal(0, 1))
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
            with pyro.plate("data", len(data)):
                pyro.sample("x", dist.Normal(loc, scale))

    Exactly one of ``name`` and ``dim`` must be given.

    :param str name: Optional name of plate to match.
    :param int dim: Optional dim of plate to match. Must be negative.
    :param bool strict: Whether to error unless exactly one matching plate is
        found. Defaults to True.
    :raises TypeError: If ``name`` is not a ``str`` or ``dim`` is not an
        ``int``.
    :raises ValueError: If neither or both of ``name`` and ``dim`` are given,
        if ``dim`` is not negative, or if ``strict=True`` and the number of
        matching plates is not exactly one.
    """
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")

    # Validate with explicit raises rather than ``assert`` so the checks
    # still run when Python is started with -O (which strips asserts).
    if name is not None and not isinstance(name, str):
        raise TypeError(
            f"block_plate name must be a str, got {type(name).__name__}"
        )
    if dim is not None:
        if not isinstance(dim, int):
            raise TypeError(
                f"block_plate dim must be an int, got {type(dim).__name__}"
            )
        if dim >= 0:
            raise ValueError(f"block_plate dim must be negative, got {dim}")

    def predicate(messenger: Messenger) -> bool:
        # Only plates are candidates; match on whichever key was supplied.
        if not isinstance(messenger, PlateMessenger):
            return False
        if name is not None:
            return messenger.name == name
        if dim is not None:
            return messenger.dim == dim
        raise ValueError("Unreachable")

    with block_messengers(predicate) as matches:
        # Raising inside the ``with`` lets block_messengers unwind normally.
        if strict and len(matches) != 1:
            raise ValueError(
                f"block_plate matched {len(matches)} messengers. "
                "Try either removing the block_plate or "
                "setting strict=False."
            )
        yield