# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import re
import weakref
from abc import ABCMeta, abstractmethod
from contextlib import ExitStack
import torch
from torch.distributions import biject_to
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import pyro.poutine.runtime as runtime
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.infer.autoguide.guides import prototype_hide_fn
from pyro.infer.autoguide.initialization import InitMessenger
from pyro.nn.module import PyroModule, PyroParam
class _EasyGuideMeta(type(PyroModule), ABCMeta):
pass
[docs]class EasyGuide(PyroModule, metaclass=_EasyGuideMeta):
"""
Base class for "easy guides", which are more flexible than
:class:`~pyro.infer.AutoGuide` s, but are easier to write than raw Pyro guides.
Derived classes should define a :meth:`guide` method. This :meth:`guide`
method can combine ordinary guide statements (e.g. ``pyro.sample`` and
``pyro.param``) with the following special statements:
- ``group = self.group(...)`` selects multiple ``pyro.sample`` sites in the
model. See :class:`Group` for subsequent methods.
- ``with self.plate(...): ...`` should be used instead of ``pyro.plate``.
- ``self.map_estimate(...)`` uses a ``Delta`` guide for a single site.
Derived classes may also override the :meth:`init` method to provide custom
initialization for models sites.
:param callable model: A Pyro model.
"""
def __init__(self, model):
super().__init__()
self._pyro_name = type(self).__name__
self._model = (model,)
self.prototype_trace = None
self.frames = {}
self.groups = {}
self.plates = {}
@property
def model(self):
return self._model[0]
def _setup_prototype(self, *args, **kwargs):
# run the model so we can inspect its structure
model = poutine.block(InitMessenger(self.init)(self.model), prototype_hide_fn)
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
*args, **kwargs
)
for name, site in self.prototype_trace.iter_stochastic_nodes():
for frame in site["cond_indep_stack"]:
if not frame.vectorized:
raise NotImplementedError(
"EasyGuide does not support sequential pyro.plate"
)
self.frames[frame.name] = frame
[docs] @abstractmethod
def guide(self, *args, **kargs):
"""
Guide implementation, to be overridden by user.
"""
raise NotImplementedError
[docs] def init(self, site):
"""
Model initialization method, may be overridden by user.
This should input a site and output a valid sample from that site.
The default behavior is to draw a random sample::
return site["fn"]()
For other possible initialization functions see
http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
"""
return site["fn"]()
[docs] def forward(self, *args, **kwargs):
"""
Runs the guide. This is typically used by inference algorithms.
.. note:: This method is used internally by :class:`~torch.nn.Module`.
Users should instead use :meth:`~torch.nn.Module.__call__`.
"""
if self.prototype_trace is None:
self._setup_prototype(*args, **kwargs)
result = self.guide(*args, **kwargs)
self.plates.clear()
return result
[docs] def plate(
self, name, size=None, subsample_size=None, subsample=None, *args, **kwargs
):
"""
A wrapper around :class:`pyro.plate` to allow `EasyGuide` to
automatically construct plates. You should use this rather than
:class:`pyro.plate` inside your :meth:`guide` implementation.
"""
if name not in self.plates:
self.plates[name] = pyro.plate(
name, size, subsample_size, subsample, *args, **kwargs
)
return self.plates[name]
[docs] def group(self, match=".*"):
"""
Select a :class:`Group` of model sites for joint guidance.
:param str match: A regex string matching names of model sample sites.
:return: A group of model sites.
:rtype: Group
"""
if match not in self.groups:
sites = [
site
for name, site in self.prototype_trace.iter_stochastic_nodes()
if re.match(match, name)
]
if not sites:
raise ValueError(
"EasyGuide.group() pattern {} matched no model sites".format(
repr(match)
)
)
self.groups[match] = Group(self, sites)
return self.groups[match]
[docs] def map_estimate(self, name):
"""
Construct a maximum a posteriori (MAP) guide using Delta distributions.
:param str name: The name of a model sample site.
:return: A sampled value.
:rtype: torch.Tensor
"""
site = self.prototype_trace.nodes[name]
fn = site["fn"]
event_dim = fn.event_dim
init_needed = not hasattr(self, name)
if init_needed:
init_value = site["value"].detach()
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
plate = self.plate(frame.name)
if plate not in runtime._PYRO_STACK:
stack.enter_context(plate)
elif init_needed and plate.subsample_size < plate.size:
# Repeat the init_value to full size.
dim = plate.dim - event_dim
assert init_value.size(dim) == plate.subsample_size
ind = torch.arange(plate.size, device=init_value.device)
ind = ind % plate.subsample_size
init_value = init_value.index_select(dim, ind)
if init_needed:
setattr(self, name, PyroParam(init_value, fn.support, event_dim))
value = getattr(self, name)
return pyro.sample(name, dist.Delta(value, event_dim=event_dim))
[docs]class Group:
"""
An autoguide helper to match a group of model sites.
:ivar torch.Size event_shape: The total flattened concatenated shape of all
matching sample sites in the model.
:ivar list prototype_sites: A list of all matching sample sites in a
prototype trace of the model.
:param EasyGuide guide: An easyguide instance.
:param list sites: A list of model sites.
"""
def __init__(self, guide, sites):
assert isinstance(sites, list)
assert sites
self._guide = weakref.ref(guide)
self.prototype_sites = sites
self._site_sizes = {}
self._site_batch_shapes = {}
# A group is in a frame only if all its sample sites are in that frame.
# Thus a group can be subsampled only if all its sites can be subsampled.
self.common_frames = frozenset.intersection(
*(
frozenset(f for f in site["cond_indep_stack"] if f.vectorized)
for site in sites
)
)
rightmost_common_dim = -float("inf")
if self.common_frames:
rightmost_common_dim = max(f.dim for f in self.common_frames)
# Compute flattened concatenated event_shape and split batch_shape into
# a common batch_shape (which can change each SVI step due to
# subsampling) and site batch_shapes (which must remain constant size).
for site in sites:
site_event_numel = torch.Size(site["fn"].event_shape).numel()
site_batch_shape = list(site["fn"].batch_shape)
for f in self.common_frames:
# Consider this dim part of the common_batch_shape.
site_batch_shape[f.dim] = 1
while site_batch_shape and site_batch_shape[0] == 1:
site_batch_shape = site_batch_shape[1:]
if len(site_batch_shape) > -rightmost_common_dim:
raise ValueError(
"Group expects all per-site plates to be right of all common plates, "
"but found a per-site plate {} on left at site {}".format(
-len(site_batch_shape), repr(site["name"])
)
)
site_batch_shape = torch.Size(site_batch_shape)
self._site_batch_shapes[site["name"]] = site_batch_shape
self._site_sizes[site["name"]] = site_batch_shape.numel() * site_event_numel
self.event_shape = torch.Size([sum(self._site_sizes.values())])
def __getstate__(self):
state = getattr(super(), "__getstate__", self.__dict__.copy)()
state["_guide"] = state["_guide"]() # weakref -> ref
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._guide = weakref.ref(self._guide) # ref -> weakref
@property
def guide(self):
return self._guide()
[docs] def sample(self, guide_name, fn, infer=None):
"""
Wrapper around ``pyro.sample()`` to create a single auxiliary sample
site and then unpack to multiple sample sites for model replay.
:param str guide_name: The name of the auxiliary guide site.
:param callable fn: A distribution with shape ``self.event_shape``.
:param dict infer: Optional inference configuration dict.
:returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
single concatenated blob and ``model_zs`` is a dict mapping
site name to constrained model sample.
:rtype: tuple
"""
# Sample a packed tensor.
if fn.event_shape != self.event_shape:
raise ValueError(
"Invalid fn.event_shape for group: expected {}, actual {}".format(
tuple(self.event_shape), tuple(fn.event_shape)
)
)
if infer is None:
infer = {}
infer["is_auxiliary"] = True
guide_z = pyro.sample(guide_name, fn, infer=infer)
common_batch_shape = guide_z.shape[:-1]
model_zs = {}
pos = 0
for site in self.prototype_sites:
name = site["name"]
fn = site["fn"]
# Extract slice from packed sample.
size = self._site_sizes[name]
batch_shape = broadcast_shape(
common_batch_shape, self._site_batch_shapes[name]
)
unconstrained_z = guide_z[..., pos : pos + size]
unconstrained_z = unconstrained_z.reshape(batch_shape + fn.event_shape)
pos += size
# Transform to constrained space.
transform = biject_to(fn.support)
z = transform(unconstrained_z)
log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
log_density = sum_rightmost(
log_density, log_density.dim() - z.dim() + fn.event_dim
)
delta_dist = dist.Delta(z, log_density=log_density, event_dim=fn.event_dim)
# Replay model sample statement.
with ExitStack() as stack:
for frame in site["cond_indep_stack"]:
plate = self.guide.plate(frame.name)
if plate not in runtime._PYRO_STACK:
stack.enter_context(plate)
model_zs[name] = pyro.sample(name, delta_dist)
return guide_z, model_zs
[docs] def map_estimate(self):
"""
Construct a maximum a posteriori (MAP) guide using Delta distributions.
:return: A dict mapping model site name to sampled value.
:rtype: dict
"""
return {
site["name"]: self.guide.map_estimate(site["name"])
for site in self.prototype_sites
}
[docs]def easy_guide(model):
"""
Convenience decorator to create an :class:`EasyGuide` .
The following are equivalent::
# Version 1. Decorate a function.
@easy_guide(model)
def guide(self, foo, bar):
return my_guide(foo, bar)
# Version 2. Create and instantiate a subclass of EasyGuide.
class Guide(EasyGuide):
def guide(self, foo, bar):
return my_guide(foo, bar)
guide = Guide(model)
Note ``@easy_guide`` wrappers cannot be pickled; to build a guide that can
be pickled, instead subclass from :class:`EasyGuide`.
:param callable model: a Pyro model.
"""
def decorator(fn):
Guide = type(fn.__name__, (EasyGuide,), {"guide": fn})
return Guide(model)
return decorator