# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from collections import Counter, OrderedDict, namedtuple
from enum import Enum
class StackFrame:
    """
    Two-way lookup table pairing integer positional dimensions with string names.
    Supports dictionary-style access in either direction
    (``value = frame[key]``, ``frame[key] = value``): an ``int`` key is looked
    up as a dim and a ``str`` key is looked up as a name.

    :param OrderedDict name_to_dim: initial mapping from names (``str``) to dims (``int``).
    :param OrderedDict dim_to_name: initial mapping from dims (``int``) to names (``str``).
    :param int history: stored as-is on the frame; not interpreted by this class.
    :param bool keep: stored as-is on the frame; not interpreted by this class.
    """

    def __init__(self, name_to_dim, dim_to_name, history=1, keep=False):
        # Validate the name -> dim table first, then the dim -> name table.
        assert isinstance(name_to_dim, OrderedDict)
        for name, dim in name_to_dim.items():
            assert isinstance(name, str) and isinstance(dim, int)

        assert isinstance(dim_to_name, OrderedDict)
        for dim, name in dim_to_name.items():
            assert isinstance(name, str) and isinstance(dim, int)

        self.name_to_dim = name_to_dim
        self.dim_to_name = dim_to_name
        self.history = history
        self.keep = keep

    def __setitem__(self, key, value):
        """Record a (name, dim) pair; exactly one of ``key``/``value`` must be an int."""
        assert isinstance(key, (int, str))
        assert isinstance(value, (int, str))
        assert type(key) is not type(value)

        if isinstance(key, int):
            dim, name = key, value
        else:
            name, dim = key, value

        # Both directions are written; existing entries for other keys are
        # left untouched.
        self.name_to_dim[name] = dim
        self.dim_to_name[dim] = name

    def __getitem__(self, key):
        """Return the name for an int ``key`` or the dim for a str ``key``."""
        assert isinstance(key, (int, str))
        if isinstance(key, int):
            return self.dim_to_name[key]
        return self.name_to_dim[key]

    def __delitem__(self, key):
        """Remove the pair identified by ``key`` from both directions."""
        assert isinstance(key, (int, str))
        if isinstance(key, int):
            forward, backward = self.dim_to_name, self.name_to_dim
        else:
            forward, backward = self.name_to_dim, self.dim_to_name

        # Drop the reverse entry before the forward one so that a failed
        # lookup leaves the forward table unmodified.
        partner = forward[key]
        del backward[partner]
        del forward[key]

    def __contains__(self, key):
        """Report whether ``key`` is known as a dim (int) or as a name (str)."""
        assert isinstance(key, (int, str))
        table = self.dim_to_name if isinstance(key, int) else self.name_to_dim
        return key in table
class DimType(Enum):
    """Kinds of dimension that can be requested from the dim stack."""

    # Written to the frames of the current write environment on allocation.
    LOCAL = 0
    # Written to the current global frame on allocation.
    GLOBAL = 1
    # Written to the current global frame; the dim search starts from the
    # requested dim (or -1) instead of the first available dim.
    VISIBLE = 2
# Request record pairing a proposed value (a name, a dim, or None for
# "no preference") with the type of dimension being asked for.
DimRequest = namedtuple("DimRequest", ("value", "dim_type"))
# Both fields are optional: ``DimRequest()`` means "any value, LOCAL dim".
DimRequest.__new__.__defaults__ = (None, DimType.LOCAL)
class DimStack:
    """
    Single piece of global state to keep track of the mapping between names and dimensions.

    Replaces the plate :class:`~pyro.poutine.runtime._DimAllocator`,
    the enum :class:`~pyro.poutine.runtime._EnumAllocator`, the ``stack`` in :class:`~MarkovMessenger`,
    ``_param_dims`` and ``_value_dims`` in :class:`~EnumMessenger`, and ``dim_to_symbol`` in ``msg['infer']``

    Three stacks of :class:`StackFrame` are maintained (global, local and
    iter). All three start out holding the same base frame, which must never
    be popped.
    """

    # Lower bound for dims: allocation fails once the search goes below it,
    # and ``set_first_available_dim`` only accepts dims strictly above it.
    MAX_DIM = -25
    # Initial starting point of the dim search for non-VISIBLE requests.
    DEFAULT_FIRST_DIM = -5

    def __init__(self):
        global_frame = StackFrame(
            name_to_dim=OrderedDict(),
            dim_to_name=OrderedDict(),
            history=0,
            keep=False,
        )
        # The same base frame sits at the bottom of every stack.
        self._local_stack = [global_frame]
        self._iter_stack = [global_frame]
        self._global_stack = [global_frame]
        self._first_available_dim = self.DEFAULT_FIRST_DIM
        # Not read or modified by any method of this class.
        self.outermost = None

    def set_first_available_dim(self, dim):
        """
        Replace the starting point used when searching for a free dim.

        :param dim: ``None`` or a negative int strictly greater than ``MAX_DIM``.
        :return: the previously configured value.
        """
        assert dim is None or (self.MAX_DIM < dim < 0)
        old_dim, self._first_available_dim = self._first_available_dim, dim
        return old_dim

    def push_global(self, frame):
        """Push ``frame`` onto the global stack."""
        self._global_stack.append(frame)

    def pop_global(self):
        """Pop and return the top global frame; the base frame cannot be popped."""
        # The base frame lives at index 0, so at least one pushed frame must
        # be present for the pop to be legal.
        assert len(self._global_stack) > 1, "cannot pop the global frame"
        return self._global_stack.pop()

    def push_iter(self, frame):
        """Push ``frame`` onto the iter stack."""
        self._iter_stack.append(frame)

    def pop_iter(self):
        """Pop and return the top iter frame; the base frame cannot be popped."""
        assert len(self._iter_stack) > 1, "cannot pop the global frame"
        return self._iter_stack.pop()

    def push_local(self, frame):
        """Push ``frame`` onto the local stack."""
        self._local_stack.append(frame)

    def pop_local(self):
        """Pop and return the top local frame; the base frame cannot be popped."""
        assert len(self._local_stack) > 1, "cannot pop the global frame"
        return self._local_stack.pop()

    @property
    def global_frame(self):
        """The frame currently on top of the global stack."""
        return self._global_stack[-1]

    @property
    def local_frame(self):
        """The frame currently on top of the local stack."""
        return self._local_stack[-1]

    @property
    def current_write_env(self):
        """
        Frames that receive newly allocated LOCAL pairs: only the top local
        frame, unless that frame has ``keep`` set, in which case the last
        ``history + 1`` local frames are used.
        """
        return (
            self._local_stack[-1:]
            if not self.local_frame.keep
            else self._local_stack[-self.local_frame.history - 1 :]
        )

    @property
    def current_read_env(self):
        """
        Collect all frames necessary to compute the full name <--> dim mapping
        and interpret Funsor inputs or batch shapes at any point in a computation.
        """
        return (
            self._global_stack
            + self._local_stack[-self.local_frame.history - 1 :]
            + self._iter_stack
        )

    def _genvalue(self, key, value_request):
        """
        Given proposed values for a fresh (name, dim) pair, computes a new, possibly
        identical (name, dim) pair consistent with the current name <--> dim mapping.

        This function is pure and does not update the name <--> dim mapping itself.

        The implementation here is only one of several possibilities, and was chosen
        to match the behavior of Pyro's old enumeration machinery as closely as possible.

        :param key: a dim (``int``) or a name (``str``).
        :param DimRequest value_request: proposed counterpart and dim type.
        :return: ``(key, fresh_value)`` where ``fresh_value`` is a name for an
            int key and a dim for a str key.
        :raises ValueError: if no free dim is left, or if ``key`` is neither
            an int nor a str.
        """
        if isinstance(key, int):
            dim, name = key, value_request.value
            # Fall back to a generated name when none was proposed.
            fresh_name = "_pyro_dim_{}".format(-key) if name is None else name
            return dim, fresh_name

        elif isinstance(key, str):
            name, dim, dim_type = key, value_request.value, value_request.dim_type
            if dim_type == DimType.VISIBLE:
                fresh_dim = -1 if dim is None else dim
            else:
                fresh_dim = self._first_available_dim  # discard input...
            # Walk leftwards until a dim unused by every readable frame is found.
            while any(fresh_dim in p for p in self.current_read_env):
                fresh_dim -= 1

            if fresh_dim < self.MAX_DIM or (
                dim_type == DimType.VISIBLE and fresh_dim <= self._first_available_dim
            ):
                raise ValueError(
                    "Ran out of free dims during allocation for {}".format(name)
                )

            return name, fresh_dim

        raise ValueError(
            "{} and {} not a valid name-dim pair".format(key, value_request)
        )

    def allocate(self, key_to_value_request):
        """
        Resolve a batch of requests into concrete name <--> dim assignments.

        Requests that match an entry in the current read environment are
        answered from it; the remaining ones get fresh values that are also
        recorded in the global frame (non-LOCAL requests) or in the current
        write environment (LOCAL requests).

        :param key_to_value_request: mutable mapping from keys (names or dims)
            to :class:`DimRequest`. Entries answered from existing frames are
            deleted from this mapping in place.
        :return: ``OrderedDict`` mapping every requested key to its value.
        :raises ValueError: if two keys resolve to the same existing value, or
            if no free dim is available.
        """
        # step 1: split into fresh and non-fresh
        key_to_value = OrderedDict()
        # Iterate over a snapshot because matched entries are deleted below.
        for key, value_request in tuple(key_to_value_request.items()):
            value = value_request.value
            for frame in self.current_read_env:
                if value is None and key in frame:
                    key_to_value[key] = frame[key]
                    del key_to_value_request[key]
                    break
                elif value is not None and value in frame:
                    key_to_value[key] = value
                    del key_to_value_request[key]
                    break

        # step 2: check that the non-fresh input mapping from keys to values is 1-1
        if max(Counter(key_to_value.values()).values(), default=0) > 1:
            raise ValueError("{} is not a valid shape request".format(key_to_value))

        # step 3: allocate fresh values for all fresh
        for key, value_request in key_to_value_request.items():
            key, fresh_value = self._genvalue(key, value_request)
            # if this key is already active but inconsistent with the fresh value,
            # generate a fresh_key for future conversions via _genvalue in reverse
            if value_request.dim_type != DimType.VISIBLE or any(
                key in frame for frame in self.current_read_env
            ):
                _, fresh_key = self._genvalue(
                    fresh_value, DimRequest(key, value_request.dim_type)
                )
            else:
                fresh_key = key

            for frame in (
                [self.global_frame]
                if value_request.dim_type != DimType.LOCAL
                else self.current_write_env
            ):
                frame[fresh_key] = fresh_value
            # use the user-provided key rather than fresh_key for satisfying this request only
            key_to_value[key] = fresh_value

        assert not any(isinstance(value, DimRequest) for value in key_to_value.values())
        return key_to_value

    def names_from_batch_shape(self, batch_shape, dim_type=DimType.LOCAL):
        """
        Allocate a name for every dim of ``batch_shape`` whose size exceeds 1.

        :param batch_shape: sequence of ints indexable by negative dims.
        :param DimType dim_type: type of dim to request for each entry.
        :return: ``OrderedDict`` mapping each such negative dim to its name.
        """
        # ``allocate`` is the only allocation entry point defined on this class.
        return self.allocate(
            OrderedDict(
                (dim, DimRequest(None, dim_type))
                for dim in range(-len(batch_shape), 0)
                if batch_shape[dim] > 1
            )
        )
# Module-level DimStack created at import time; only one global instance
# is meant to exist.
_DIM_STACK = DimStack() # only one global instance