Source code for pyro.poutine.indep_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import numbers
from collections import namedtuple

import torch

from pyro.util import ignore_jit_warnings
from .messenger import Messenger
from .runtime import _DIM_ALLOCATOR


class CondIndepStackFrame(namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])):
    """
    Immutable record with the fields ``name``, ``dim``, ``size`` and
    ``counter``.

    Equality and hashing are defined through :meth:`_key`, so a ``size`` held
    as a ``torch.Tensor`` is compared and hashed by its Python scalar value.
    """

    @property
    def vectorized(self):
        """``True`` when ``dim`` is set (i.e. is not ``None``)."""
        return self.dim is not None

    def _key(self):
        """Return ``(name, dim, size, counter)`` with a tensor ``size`` unwrapped via ``.item()``."""
        with ignore_jit_warnings(["Converting a tensor to a Python number"]):
            size = self.size
            if isinstance(size, torch.Tensor):
                size = size.item()
            return self.name, self.dim, size, self.counter

    def __eq__(self, other):
        # Only frames of exactly the same type can be equal.
        if type(self) != type(other):
            return False
        return self._key() == other._key()

    def __ne__(self, other):
        # tuple supplies its own __ne__, so it is overridden to stay
        # consistent with the key-based __eq__ above.
        is_equal = self.__eq__(other)
        return not is_equal

    def __hash__(self):
        key = self._key()
        return hash(key)

    def __str__(self):
        return self.name
class IndepMessenger(Messenger):
    """
    This messenger keeps track of stack of independence information declared by
    nested ``plate`` contexts. This information is stored in a
    ``cond_indep_stack`` at each sample/observe site for consumption by
    ``TraceMessenger``.

    Example::

        x_axis = IndepMessenger('outer', 320, dim=-1)
        y_axis = IndepMessenger('inner', 200, dim=-2)
        with x_axis:
            x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
        with y_axis:
            y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
        with x_axis, y_axis:
            xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))

    :param name: name recorded in every ``CondIndepStackFrame`` this messenger
        emits and passed to the dim allocator.
    :param size: number of indices produced by :attr:`indices`; recorded in
        each frame. Must not be zero (checked only when not JIT-tracing).
    :param dim: requested dim, forwarded to ``_DIM_ALLOCATOR.allocate`` on
        entry. Passing a non-``None`` value marks the messenger as vectorized.
    :param device: target of the ``.to(...)`` call applied to :attr:`indices`.
    :raises ZeroDivisionError: if ``size == 0`` while not JIT-tracing.
    """

    def __init__(self, name=None, size=None, dim=None, device=None):
        if not torch._C._get_tracing_state() and size == 0:
            raise ZeroDivisionError("size cannot be zero")

        super(IndepMessenger, self).__init__()
        # Tri-state flag: None = undecided, True = used as a context manager
        # (vectorized), False = used as an iterator (sequential).
        self._vectorized = None
        if dim is not None:
            self._vectorized = True

        self._indices = None  # lazily built by the ``indices`` property
        self.name = name
        self.dim = dim
        self.size = size
        self.device = device
        self.counter = 0

    def next_context(self):
        """
        Increments the counter.
        """
        self.counter += 1

    def __enter__(self):
        # Entering directly commits an undecided messenger to vectorized mode;
        # a messenger already switched to sequential mode by __iter__ stays so.
        if self._vectorized is not False:
            self._vectorized = True

        if self._vectorized is True:
            self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)

        return super(IndepMessenger, self).__enter__()

    def __exit__(self, *args):
        # Release the dim obtained in __enter__ (vectorized mode only).
        if self._vectorized is True:
            _DIM_ALLOCATOR.free(self.name, self.dim)
        return super(IndepMessenger, self).__exit__(*args)

    def __iter__(self):
        """
        Iterate sequentially over :attr:`indices`, yielding Python numbers.

        Each iteration increments the counter and yields while this messenger
        is entered as a context.

        :raises ValueError: if the messenger is already vectorized or has a
            ``dim`` set.
        """
        if self._vectorized is True or self.dim is not None:
            # Fix: the two literals were previously joined without a space,
            # producing "non-vectorizedindependence context".
            raise ValueError(
                "cannot use plate {} as both vectorized and non-vectorized "
                "independence context".format(self.name))

        self._vectorized = False
        self.dim = None
        with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
            for i in self.indices:
                self.next_context()
                with self:
                    yield i if isinstance(i, numbers.Number) else i.item()

    def _reset(self):
        """Free the dim if vectorized, then return to the undecided state with counter 0."""
        if self._vectorized:
            _DIM_ALLOCATOR.free(self.name, self.dim)
        self._vectorized = None
        self.counter = 0

    @property
    def indices(self):
        """``torch.arange(size)`` as a long tensor moved via ``.to(device)``; built once and cached."""
        if self._indices is None:
            self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
        return self._indices

    def _process_message(self, msg):
        # Prepend this messenger's frame to the site's ``cond_indep_stack``.
        frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]