Source code for pyro.contrib.easyguide.easyguide

import re
import weakref
from abc import ABCMeta, abstractmethod

import torch
from contextlib import ExitStack
from torch.distributions import biject_to

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import pyro.poutine.runtime as runtime
from pyro.infer.autoguide.initialization import InitMessenger
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.poutine.util import prune_subsample_sites


class EasyGuide(object, metaclass=ABCMeta):
    """
    Base class for "easy guides".

    Derived classes should define a :meth:`guide` method. This :meth:`guide`
    method can combine ordinary guide statements (e.g. ``pyro.sample`` and
    ``pyro.param``) with the following special statements:

    - ``group = self.group(...)`` selects multiple ``pyro.sample`` sites in the
      model. See :class:`Group` for subsequent methods.
    - ``with self.plate(...): ...`` should be used instead of ``pyro.plate``.
    - ``self.map_estimate(...)`` uses a ``Delta`` guide for a single site.

    Derived classes may also override the :meth:`init` method to provide custom
    initialization for models sites.

    :param callable model: A Pyro model.
    """
    def __init__(self, model):
        self.model = model
        self.prototype_trace = None  # built lazily on the first __call__
        self.frames = {}  # plate name -> frame found in the prototype trace
        self.groups = {}  # match pattern -> cached Group
        self.plates = {}  # plate name -> plate; reset after every __call__

    def _setup_prototype(self, *args, **kwargs):
        """
        Trace the model once and record its structure.

        Stores the pruned trace in ``self.prototype_trace`` and every plate
        frame encountered in ``self.frames``.

        :raises NotImplementedError: If any sample site sits inside a
            non-vectorized (sequential) plate.
        """
        # run the model so we can inspect its structure
        model = InitMessenger(self.init)(self.model)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)

        for _, site in self.prototype_trace.iter_stochastic_nodes():
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    raise NotImplementedError("EasyGuide does not support sequential pyro.plate")
                self.frames[frame.name] = frame

    @abstractmethod
    def guide(self, *args, **kwargs):
        """
        Guide implementation, to be overridden by user.
        """
        raise NotImplementedError

    def init(self, site):
        """
        Model initialization method, may be overridden by user.

        This should input a site and output a valid sample from that site.
        The default behavior is to draw a random sample::

            return site["fn"]()

        For other possible initialization functions see
        http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
        """
        return site["fn"]()

    def __call__(self, *args, **kwargs):
        """
        Runs the guide. This is typically used by inference algorithms.

        On the first call the model is traced to build the prototype trace.
        The per-call plate cache is always emptied afterwards, including when
        :meth:`guide` raises.

        :return: Whatever :meth:`guide` returns.
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)
        try:
            return self.guide(*args, **kwargs)
        finally:
            # Cached plates belong to a single guide invocation. Clearing in
            # ``finally`` keeps a failed call from leaking its plates into
            # the next one via self.plate().
            self.plates.clear()

    def plate(self, name, size=None, subsample_size=None, subsample=None, *args, **kwargs):
        """
        A wrapper around :class:`pyro.plate` to allow `EasyGuide` to
        automatically construct plates. You should use this rather than
        :class:`pyro.plate` inside your :meth:`guide` implementation.

        The plate is created on first use of ``name`` within a guide call and
        cached; later calls with the same ``name`` return the cached plate and
        ignore the remaining arguments.
        """
        if name not in self.plates:
            self.plates[name] = pyro.plate(name, size, subsample_size, subsample, *args, **kwargs)
        return self.plates[name]

    def group(self, match=".*"):
        """
        Select a :class:`Group` of model sites for joint guidance.

        Site names are tested with ``re.match``, so the pattern is anchored at
        the start of the name only. Groups are cached per pattern string.

        :param str match: A regex string matching names of model sample sites.
        :return: A group of model sites.
        :rtype: Group
        :raises ValueError: If the pattern matches no model sample site.
        """
        if match not in self.groups:
            sites = [site
                     for name, site in self.prototype_trace.iter_stochastic_nodes()
                     if re.match(match, name)]
            if not sites:
                raise ValueError("EasyGuide.group() pattern {} matched no model sites"
                                 .format(repr(match)))
            self.groups[match] = Group(self, sites)
        return self.groups[match]

    def map_estimate(self, name):
        """
        Construct a maximum a posteriori (MAP) guide using Delta distributions.

        The point estimate is stored as the param ``"auto_<name>"``. If that
        param does not exist yet, it is initialized from the site's value in
        the prototype trace.

        :param str name: The name of a model sample site.
        :return: A sampled value.
        :rtype: torch.Tensor
        """
        site = self.prototype_trace.nodes[name]
        fn = site["fn"]
        param_name = "auto_{}".format(name)
        # An initial value is only needed the first time the param is created.
        init_value = None
        if param_name not in pyro.get_param_store():
            init_value = site["value"].detach()
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                plate = self.plate(frame.name)
                if plate not in runtime._PYRO_STACK:
                    stack.enter_context(plate)
                elif init_value is not None and plate.subsample_size < plate.size:
                    # Repeat the init_value to full size.
                    dim = plate.dim - fn.event_dim
                    assert init_value.size(dim) == plate.subsample_size
                    ind = torch.arange(plate.size, device=init_value.device)
                    ind = ind % plate.subsample_size
                    init_value = init_value.index_select(dim, ind)
            value = pyro.param(param_name, init_value,
                               constraint=fn.support, event_dim=fn.event_dim)
            return pyro.sample(name, dist.Delta(value, event_dim=fn.event_dim))
class Group(object):
    """
    An autoguide helper to match a group of model sites.

    :ivar torch.Size event_shape: The total flattened concatenated shape of all
        matching sample sites in the model.
    :ivar list prototype_sites: A list of all matching sample sites in a
        prototype trace of the model.
    :param EasyGuide guide: An easyguide instance.
    :param list sites: A list of model sites.
    """
    def __init__(self, guide, sites):
        assert isinstance(sites, list)
        assert sites
        # Held weakly; the guide already caches this group in guide.groups.
        self._guide = weakref.ref(guide)
        self.prototype_sites = sites
        self._site_sizes = {}  # site name -> flattened size inside the packed sample
        self._site_batch_shapes = {}  # site name -> per-site (non-common) batch shape

        # A group is in a frame only if all its sample sites are in that frame.
        # Thus a group can be subsampled only if all its sites can be subsampled.
        self.common_frames = frozenset.intersection(*(
            frozenset(f for f in site["cond_indep_stack"] if f.vectorized)
            for site in sites))
        # With no common frames this stays -inf, so the per-site check below
        # never fires.
        rightmost_common_dim = -float('inf')
        if self.common_frames:
            rightmost_common_dim = max(f.dim for f in self.common_frames)

        # Compute flattened concatenated event_shape and split batch_shape into
        # a common batch_shape (which can change each SVI step due to
        # subsampling) and site batch_shapes (which must remain constant size).
        for site in sites:
            site_event_numel = torch.Size(site["fn"].event_shape).numel()
            site_batch_shape = list(site["fn"].batch_shape)
            for f in self.common_frames:
                # Consider this dim part of the common_batch_shape.
                site_batch_shape[f.dim] = 1
            # Strip leading size-1 dims.
            while site_batch_shape and site_batch_shape[0] == 1:
                site_batch_shape = site_batch_shape[1:]
            if len(site_batch_shape) > -rightmost_common_dim:
                raise ValueError(
                    "Group expects all per-site plates to be right of all common plates, "
                    "but found a per-site plate {} on left at site {}"
                    .format(-len(site_batch_shape), repr(site["name"])))
            site_batch_shape = torch.Size(site_batch_shape)
            self._site_batch_shapes[site["name"]] = site_batch_shape
            self._site_sizes[site["name"]] = site_batch_shape.numel() * site_event_numel
        self.event_shape = torch.Size([sum(self._site_sizes.values())])

    @property
    def guide(self):
        """
        The owning guide, dereferenced from a weak reference (``None`` once
        the guide object no longer exists).
        """
        return self._guide()

    def sample(self, guide_name, fn, infer=None):
        """
        Wrapper around ``pyro.sample()`` to create a single auxiliary sample
        site and then unpack to multiple sample sites for model replay.

        :param str guide_name: The name of the auxiliary guide site.
        :param callable fn: A distribution with shape ``self.event_shape``.
        :param dict infer: Optional inference configuration dict. It is not
            modified; a copy with ``is_auxiliary=True`` is passed on.
        :returns: A pair ``(guide_z, model_zs)`` where ``guide_z`` is the
            single concatenated blob and ``model_zs`` is a dict mapping
            site name to constrained model sample.
        :rtype: tuple
        :raises ValueError: If ``fn.event_shape`` differs from
            ``self.event_shape``.
        """
        # Sample a packed tensor.
        if fn.event_shape != self.event_shape:
            raise ValueError("Invalid fn.event_shape for group: expected {}, actual {}"
                             .format(tuple(self.event_shape), tuple(fn.event_shape)))
        # Copy so the caller's dict is never mutated by the flag set below.
        infer = {} if infer is None else dict(infer)
        infer["is_auxiliary"] = True
        guide_z = pyro.sample(guide_name, fn, infer=infer)
        common_batch_shape = guide_z.shape[:-1]

        model_zs = {}
        pos = 0
        for site in self.prototype_sites:
            name = site["name"]
            site_fn = site["fn"]

            # Extract slice from packed sample.
            size = self._site_sizes[name]
            batch_shape = broadcast_shape(common_batch_shape, self._site_batch_shapes[name])
            unconstrained_z = guide_z[..., pos: pos + size]
            unconstrained_z = unconstrained_z.reshape(batch_shape + site_fn.event_shape)
            pos += size

            # Transform to constrained space.
            transform = biject_to(site_fn.support)
            z = transform(unconstrained_z)
            log_density = transform.inv.log_abs_det_jacobian(z, unconstrained_z)
            log_density = sum_rightmost(log_density,
                                        log_density.dim() - z.dim() + site_fn.event_dim)
            delta_dist = dist.Delta(z, log_density=log_density, event_dim=site_fn.event_dim)

            # Replay model sample statement.
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    plate = self.guide.plate(frame.name)
                    if plate not in runtime._PYRO_STACK:
                        stack.enter_context(plate)
                model_zs[name] = pyro.sample(name, delta_dist)

        return guide_z, model_zs

    def map_estimate(self):
        """
        Construct a maximum a posteriori (MAP) guide using Delta distributions.

        :return: A dict mapping model site name to sampled value.
        :rtype: dict
        """
        return {site["name"]: self.guide.map_estimate(site["name"])
                for site in self.prototype_sites}
def easy_guide(model):
    """
    Decorator that turns a guide function into an :class:`EasyGuide` instance.

    The decorated function becomes the ``guide`` method of a new
    :class:`EasyGuide` subclass named after the function, and the decorator
    returns an instance of that subclass bound to ``model``. The two
    spellings below are equivalent::

        # Version 1. Decorate a function.
        @easy_guide(model)
        def guide(self, foo, bar):
            return my_guide(foo, bar)

        # Version 2. Create and instantiate a subclass of EasyGuide.
        class Guide(EasyGuide):
            def guide(self, foo, bar):
                return my_guide(foo, bar)
        guide = Guide(model)

    :param callable model: a Pyro model.
    """
    def decorator(fn):
        # Build the subclass dynamically, installing ``fn`` as its guide().
        namespace = {"guide": fn}
        guide_cls = type(fn.__name__, (EasyGuide,), namespace)
        return guide_cls(model)

    return decorator