# Source code for pyro.poutine.plate_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager

from .broadcast_messenger import BroadcastMessenger
from .messenger import block_messengers
from .subsample_messenger import SubsampleMessenger


class PlateMessenger(SubsampleMessenger):
    """
    Swiss army knife of broadcasting amazingness: combines shape inference,
    independence annotation, and subsampling.

    Every message is first handled by the parent
    :class:`SubsampleMessenger` logic and is then passed through
    :meth:`BroadcastMessenger._pyro_sample`.
    """

    def _process_message(self, msg):
        # Let the parent class process the message before broadcasting it;
        # the broadcast hook's return value is what this method returns.
        super()._process_message(msg)
        broadcast = BroadcastMessenger._pyro_sample
        return broadcast(msg)

    def __enter__(self):
        super().__enter__()
        # Only hand back indices (for ``with plate(...) as idx``) when the
        # plate is vectorized and indices have been set; otherwise ``None``.
        has_indices = self._vectorized and self._indices is not None
        return self.indices if has_indices else None
@contextmanager
def block_plate(name=None, dim=None, *, strict=True):
    """
    EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

    This is useful for sampling auxiliary variables or lazily sampling global
    variables that are needed in a plated context. For example the following
    models are equivalent:

    Example::

        def model_1(data):
            loc = pyro.sample("loc", dist.Normal(0, 1))
            with pyro.plate("data", len(data)):
                with block_plate("data"):
                    scale = pyro.sample("scale", dist.LogNormal(0, 1))
                pyro.sample("x", dist.Normal(loc, scale))

        def model_2(data):
            loc = pyro.sample("loc", dist.Normal(0, 1))
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
            with pyro.plate("data", len(data)):
                pyro.sample("x", dist.Normal(loc, scale))

    Exactly one of ``name`` or ``dim`` must be given.

    :param str name: Optional name of plate to match.
    :param int dim: Optional dim of plate to match. Must be negative.
    :param bool strict: Whether to raise an error unless exactly one
        enclosing plate matches. Defaults to True.
    :raises TypeError: if ``name`` is not a ``str`` or ``dim`` is not an
        ``int``.
    :raises ValueError: if both or neither of ``name`` and ``dim`` are given,
        if ``dim`` is not negative, or if ``strict=True`` and the number of
        matching enclosing plates is not exactly one.
    """
    if (name is not None) == (dim is not None):
        raise ValueError("Exactly one of name,dim must be specified")
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid arguments through silently.
    if name is not None and not isinstance(name, str):
        raise TypeError(
            f"block_plate expected name to be a str, got {type(name).__name__}"
        )
    if dim is not None:
        if not isinstance(dim, int):
            raise TypeError(
                f"block_plate expected dim to be an int, got {type(dim).__name__}"
            )
        if dim >= 0:
            raise ValueError(f"block_plate expected a negative dim, got {dim}")

    def predicate(messenger):
        # Only plates are candidates; match on whichever of name/dim was given.
        if not isinstance(messenger, PlateMessenger):
            return False
        if name is not None:
            return messenger.name == name
        # Validation above guarantees dim is set when name is None.
        return messenger.dim == dim

    with block_messengers(predicate) as matches:
        if strict and len(matches) != 1:
            raise ValueError(f"block_plate matched {len(matches)} messengers. "
                             "Try either removing the block_plate or "
                             "setting strict=False.")
        yield