# Source code for pyro.infer.autoguide.guides

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""
The :mod:`pyro.infer.autoguide` module provides algorithms to automatically
generate guides from simple models, for use in :class:`~pyro.infer.svi.SVI`.
For example, to generate a mean field Gaussian guide::

    def model():
        ...

    guide = AutoNormal(model)  # a mean field guide
    svi = SVI(model, guide, Adam({'lr': 1e-3}), Trace_ELBO())

Automatic guides can also be combined using :func:`pyro.poutine.block` and
:class:`AutoGuideList`.
"""
import functools
import operator
import warnings
import weakref
from contextlib import ExitStack

import torch
from torch import nn
from torch.distributions import biject_to

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import constraints
from pyro.distributions.transforms import affine_autoregressive, iterated
from pyro.distributions.util import eye_like, sum_rightmost
from pyro.infer.enum import config_enumerate
from pyro.nn.module import PyroModule, PyroParam
from pyro.ops.hessian import hessian
from pyro.ops.tensor_utils import periodic_repeat
from pyro.poutine.util import site_is_subsample

from .initialization import InitMessenger, init_to_feasible, init_to_median
from .utils import _product, deep_getattr, deep_setattr, helpful_support_errors


def prototype_hide_fn(msg):
    # Record only stochastic sites in the prototype_trace.
    return msg["type"] != "sample" or msg["is_observed"] or site_is_subsample(msg)


class AutoGuide(PyroModule):
    """
    Base class for automatic guides.

    Derived classes must implement the :meth:`forward` method, with the
    same ``*args, **kwargs`` as the base ``model``.

    Auto guides can be used individually or combined in an
    :class:`AutoGuideList` object.

    :param callable model: A pyro model.
    :param callable create_plates: An optional function inputting the same
        ``*args,**kwargs`` as ``model()`` and returning a :class:`pyro.plate`
        or iterable of plates. Plates not returned will be created
        automatically as usual. This is useful for data subsampling.
    """

    def __init__(self, model, *, create_plates=None):
        super().__init__(name=type(self).__name__)
        self.master = None
        # Do not register model as submodule
        self._model = (model,)
        self.create_plates = create_plates
        self.prototype_trace = None
        self._prototype_frames = {}

    @property
    def model(self):
        return self._model[0]

    def __getstate__(self):
        # Do not pickle weakrefs.
        self._model = None
        self.master = None
        return getattr(super(), "__getstate__", self.__dict__.copy)()

    def __setstate__(self, state):
        getattr(super(), "__setstate__", self.__dict__.update)(state)
        assert self.master is None
        master_ref = weakref.ref(self)
        for _, mod in self.named_modules():
            if mod is not self and isinstance(mod, AutoGuide):
                mod._update_master(master_ref)

    def _update_master(self, master_ref):
        self.master = master_ref
        for _, mod in self.named_modules():
            if mod is not self and isinstance(mod, AutoGuide):
                mod._update_master(master_ref)

    def call(self, *args, **kwargs):
        """
        Method that calls :meth:`forward` and returns parameter values of the
        guide as a `tuple` instead of a `dict`, which is a requirement for
        JIT tracing. Unlike :meth:`forward`, this method can be traced by
        :func:`torch.jit.trace_module`.

        .. warning::
            This method may be removed once the PyTorch JIT tracer starts
            accepting `dict` as a valid return type. See
            `issue <https://github.com/pytorch/pytorch/issues/27743>`_.
        """
        result = self(*args, **kwargs)
        return tuple(v for _, v in sorted(result.items()))

    def sample_latent(self, *args, **kwargs):
        """
        Samples an encoded latent given the same ``*args, **kwargs`` as the
        base ``model``.
        """
        pass

    def __setattr__(self, name, value):
        if isinstance(value, AutoGuide):
            master_ref = weakref.ref(self) if self.master is None else self.master
            value._update_master(master_ref)
        super().__setattr__(name, value)

    def _create_plates(self, *args, **kwargs):
        if self.master is None:
            if self.create_plates is None:
                self.plates = {}
            else:
                plates = self.create_plates(*args, **kwargs)
                if isinstance(plates, pyro.plate):
                    plates = [plates]
                assert all(
                    isinstance(p, pyro.plate) for p in plates
                ), "create_plates() returned a non-plate"
                self.plates = {p.name: p for p in plates}
            for name, frame in sorted(self._prototype_frames.items()):
                if name not in self.plates:
                    full_size = frame.full_size or frame.size
                    self.plates[name] = pyro.plate(
                        name, full_size, dim=frame.dim, subsample_size=frame.size
                    )
        else:
            assert (
                self.create_plates is None
            ), "Cannot pass create_plates() to non-master guide"
            self.plates = self.master().plates
        return self.plates

    _prototype_hide_fn = staticmethod(prototype_hide_fn)

    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = poutine.block(self.model, self._prototype_hide_fn)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
            *args, **kwargs
        )
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._prototype_frames = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._prototype_frames[frame.name] = frame
                else:
                    raise NotImplementedError(
                        "AutoGuide does not support sequential pyro.plate"
                    )

    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        raise NotImplementedError

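# Illustrative sketch (not part of the library): the ``call`` method above
# exists because the PyTorch JIT tracer cannot handle dict-valued outputs.
# The model, guide class, optimizer, and data below are assumptions for
# demonstration only.
#
#     guide = AutoNormal(model)
#     SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO()).step(data)  # initialize params
#     traced = torch.jit.trace_module(guide, {"call": (data,)}, check_trace=False)
#     latents = traced.call(data)  # a tuple of tensors, sorted by site name
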
class AutoGuideList(AutoGuide, nn.ModuleList):
    """
    Container class to combine multiple automatic or custom guides.

    Example usage::

        guide = AutoGuideList(model)
        guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
        guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
        svi = SVI(model, guide, optim, Trace_ELBO())

    :param callable model: a Pyro model
    """

    def _check_prototype(self, part_trace):
        for name, part_site in part_trace.iter_stochastic_nodes():
            self_site = self.prototype_trace.nodes[name]
            assert part_site["fn"].batch_shape == self_site["fn"].batch_shape
            assert part_site["fn"].event_shape == self_site["fn"].event_shape
            assert part_site["value"].shape == self_site["value"].shape

    def append(self, part):
        """
        Add an automatic or custom guide for part of the model. The guide should
        have been created by blocking the model to restrict to a subset of sample
        sites. No two parts should operate on any one sample site.

        :param part: a partial guide to add
        :type part: AutoGuide or callable
        """
        if not isinstance(part, AutoGuide):
            part = AutoCallable(self.model, part)
        if part.master is not None:
            raise RuntimeError(
                "The module `{}` is already added.".format(self._pyro_name)
            )
        setattr(self, str(len(self)), part)

    def add(self, part):
        """Deprecated alias for :meth:`append`."""
        warnings.warn(
            "The method `.add` has been deprecated in favor of `.append`.",
            DeprecationWarning,
        )
        self.append(part)

    def forward(self, *args, **kwargs):
        """
        A composite guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        # create all plates
        self._create_plates(*args, **kwargs)

        # run the component guides
        result = {}
        for part in self:
            result.update(part(*args, **kwargs))
        return result

    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        result = {}
        for part in self:
            result.update(part.median(*args, **kwargs))
        return result

    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns the posterior quantile values of each latent variable.

        :param list quantiles: A list of requested quantiles between 0 and 1.
        :returns: A dict mapping sample site name to quantiles tensor.
        :rtype: dict
        """
        result = {}
        for part in self:
            result.update(part.quantiles(quantiles, *args, **kwargs))
        return result

class AutoCallable(AutoGuide):
    """
    :class:`AutoGuide` wrapper for simple callable guides.

    This is used internally for composing autoguides with custom user-defined
    guides that are simple callables, e.g.::

        def my_local_guide(*args, **kwargs):
            ...

        guide = AutoGuideList(model)
        guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
        guide.append(my_local_guide)  # automatically wrapped in an AutoCallable

    To specify a median callable, you can instead write::

        def my_local_median(*args, **kwargs):
            ...

        guide.append(AutoCallable(model, my_local_guide, my_local_median))

    For more complex guides that need e.g. access to plates, users should
    instead subclass ``AutoGuide``.

    :param callable model: a Pyro model
    :param callable guide: a Pyro guide (typically over only part of the model)
    :param callable median: an optional callable returning a dict mapping
        sample site name to computed median tensor.
    """

    def __init__(self, model, guide, median=lambda *args, **kwargs: {}):
        super().__init__(model)
        self._guide = guide
        self.median = median

    def forward(self, *args, **kwargs):
        result = self._guide(*args, **kwargs)
        return {} if result is None else result

class AutoDelta(AutoGuide):
    """
    This implementation of :class:`AutoGuide` uses Delta distributions to
    construct a MAP guide over the entire latent space. The guide does not
    depend on the model's ``*args, **kwargs``.

    .. note:: This class does MAP inference in constrained space.

    Usage::

        guide = AutoDelta(model)
        svi = SVI(model, guide, ...)

    Latent variables are initialized using ``init_loc_fn()``. To change the
    default behavior, create a custom ``init_loc_fn()`` as described in
    :ref:`autoguide-initialization`, for example::

        def my_init_fn(site):
            if site["name"] == "level":
                return torch.tensor([-1., 0., 1.])
            if site["name"] == "concentration":
                return torch.ones(k)
            return init_to_sample(site)

    :param callable model: A Pyro model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param callable create_plates: An optional function inputting the same
        ``*args,**kwargs`` as ``model()`` and returning a :class:`pyro.plate`
        or iterable of plates. Plates not returned will be created
        automatically as usual. This is useful for data subsampling.
    """

    def __init__(self, model, init_loc_fn=init_to_median, *, create_plates=None):
        self.init_loc_fn = init_loc_fn
        model = InitMessenger(self.init_loc_fn)(model)
        super().__init__(model, create_plates=create_plates)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            value = site["value"].detach()
            event_dim = site["fn"].event_dim

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = frame.full_size or frame.size
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    value = periodic_repeat(value, full_size, dim).contiguous()

            value = PyroParam(value, site["fn"].support, event_dim)
            with helpful_support_errors(site):
                deep_setattr(self, name, value)

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])
                attr_get = operator.attrgetter(name)
                result[name] = pyro.sample(
                    name, dist.Delta(attr_get(self), event_dim=site["fn"].event_dim)
                )
        return result

    @torch.no_grad()
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        result = self(*args, **kwargs)
        return {k: v.detach() for k, v in result.items()}

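# Illustrative sketch (not part of the library): MAP estimation with AutoDelta.
# The toy model and data below are assumptions for demonstration only.
#
#     import torch
#     from pyro.infer import SVI, Trace_ELBO
#     from pyro.optim import Adam
#
#     def model(data):
#         loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
#         with pyro.plate("data", len(data)):
#             pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)
#
#     data = torch.randn(100) + 3.0
#     guide = AutoDelta(model)
#     svi = SVI(model, guide, Adam({"lr": 0.05}), Trace_ELBO())
#     for _ in range(500):
#         svi.step(data)
#     print(guide.median(data))  # the MAP estimate of "loc"
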
class AutoNormal(AutoGuide):
    """
    This implementation of :class:`AutoGuide` uses a Normal distribution
    with a diagonal covariance matrix to construct a guide over the entire
    latent space. The guide does not depend on the model's ``*args, **kwargs``.

    It should be equivalent to :class:`AutoDiagonalNormal`, but with more
    convenient site names and with better support for
    :class:`~pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO`.

    In :class:`AutoDiagonalNormal`, if your model has N named parameters with
    dimensions k_i and sum k_i = D, you get a single vector of length D for your
    mean, and a single vector of length D for sigmas. This guide gives you N
    distinct normals that you can call by name.

    Usage::

        guide = AutoNormal(model)
        svi = SVI(model, guide, ...)

    :param callable model: A Pyro model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param float init_scale: Initial scale for the standard deviation of each
        (unconstrained transformed) latent variable.
    :param callable create_plates: An optional function inputting the same
        ``*args,**kwargs`` as ``model()`` and returning a :class:`pyro.plate`
        or iterable of plates. Plates not returned will be created
        automatically as usual. This is useful for data subsampling.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(
        self, model, *, init_loc_fn=init_to_feasible, init_scale=0.1, create_plates=None
    ):
        self.init_loc_fn = init_loc_fn
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        model = InitMessenger(self.init_loc_fn)(model)
        super().__init__(model, create_plates=create_plates)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)

        self._event_dims = {}
        self.locs = PyroModule()
        self.scales = PyroModule()

        # Initialize guide params
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect unconstrained event_dims, which may differ from constrained event_dims.
            with helpful_support_errors(site):
                init_loc = (
                    biject_to(site["fn"].support).inv(site["value"].detach()).detach()
                )
            event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
            self._event_dims[name] = event_dim

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = frame.full_size or frame.size
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)

            deep_setattr(
                self.locs, name, PyroParam(init_loc, constraints.real, event_dim)
            )
            deep_setattr(
                self.scales,
                name,
                PyroParam(init_scale, self.scale_constraint, event_dim),
            )

    def _get_loc_and_scale(self, name):
        site_loc = deep_getattr(self.locs, name)
        site_scale = deep_getattr(self.scales, name)
        return site_loc, site_scale

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name)
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                if poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result

    @torch.no_grad()
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        medians = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, _ = self._get_loc_and_scale(name)
            median = biject_to(site["fn"].support)(site_loc)
            if median is site_loc:
                median = median.clone()
            medians[name] = median
        return medians

    @torch.no_grad()
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles of each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a tensor of quantile values.
        :rtype: dict
        """
        results = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            site_loc, site_scale = self._get_loc_and_scale(name)

            site_quantiles = torch.tensor(
                quantiles, dtype=site_loc.dtype, device=site_loc.device
            )
            site_quantiles = site_quantiles.reshape((-1,) + (1,) * site_loc.dim())
            site_quantiles_values = dist.Normal(site_loc, site_scale).icdf(
                site_quantiles
            )
            constrained_site_quantiles = biject_to(site["fn"].support)(
                site_quantiles_values
            )
            results[name] = constrained_site_quantiles
        return results

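# Illustrative sketch (not part of the library): quantiles are computed under
# the Normal in unconstrained space and then mapped through ``biject_to``, so
# they respect each site's support. The toy model below is an assumption.
#
#     def model():
#         pyro.sample("rate", dist.Gamma(2.0, 2.0))  # positive-constrained site
#
#     guide = AutoNormal(model)
#     guide()  # trigger lazy parameter setup
#     q = guide.quantiles([0.05, 0.5, 0.95])
#     assert (q["rate"] > 0).all()  # quantiles remain in the constrained space
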
class AutoContinuous(AutoGuide):
    """
    Base class for implementations of continuous-valued Automatic
    Differentiation Variational Inference [1].

    This uses :mod:`torch.distributions.transforms` to transform each
    constrained latent variable to an unconstrained space, then concatenates
    all variables into a single unconstrained latent variable. Each derived
    class implements a :meth:`get_posterior` method returning a distribution
    over this single unconstrained latent variable.

    Assumes model structure and latent dimension are fixed, and all latent
    variables are continuous.

    Reference:

    [1] `Automatic Differentiation Variational Inference`,
        Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei

    :param callable model: A Pyro model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    """

    def __init__(self, model, init_loc_fn=init_to_median):
        model = InitMessenger(init_loc_fn)(model)
        super().__init__(model)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        self._unconstrained_shapes = {}
        self._cond_indep_stacks = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            # Collect the shapes of unconstrained values.
            # These may differ from the shapes of constrained values.
            with helpful_support_errors(site):
                self._unconstrained_shapes[name] = (
                    biject_to(site["fn"].support).inv(site["value"]).shape
                )

            # Collect independence contexts.
            self._cond_indep_stacks[name] = site["cond_indep_stack"]

        self.latent_dim = sum(
            _product(shape) for shape in self._unconstrained_shapes.values()
        )
        if self.latent_dim == 0:
            raise RuntimeError(
                "{} found no latent variables; Use an empty guide instead".format(
                    type(self).__name__
                )
            )

    def _init_loc(self):
        """
        Creates an initial latent vector using a per-site init function.
        """
        parts = []
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            constrained_value = site["value"].detach()
            unconstrained_value = biject_to(site["fn"].support).inv(constrained_value)
            parts.append(unconstrained_value.reshape(-1))
        latent = torch.cat(parts)
        assert latent.size() == (self.latent_dim,)
        return latent

    def get_base_dist(self):
        """
        Returns the base distribution of the posterior when reparameterized
        as a :class:`~pyro.distributions.TransformedDistribution`. This
        should not depend on the model's `*args, **kwargs`.

        .. code-block:: python

            posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

        :return: :class:`~pyro.distributions.TorchDistribution` instance
            representing the base distribution.
        """
        raise NotImplementedError

    def get_transform(self, *args, **kwargs):
        """
        Returns the transform applied to the base distribution when the
        posterior is reparameterized as a
        :class:`~pyro.distributions.TransformedDistribution`. This may depend
        on the model's `*args, **kwargs`.

        .. code-block:: python

            posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

        :return: a :class:`~torch.distributions.Transform` instance.
        """
        raise NotImplementedError

    def get_posterior(self, *args, **kwargs):
        """
        Returns the posterior distribution.
        """
        base_dist = self.get_base_dist()
        transform = self.get_transform(*args, **kwargs)
        return dist.TransformedDistribution(base_dist, transform)

    def sample_latent(self, *args, **kwargs):
        """
        Samples an encoded latent given the same ``*args, **kwargs`` as the
        base ``model``.
        """
        pos_dist = self.get_posterior(*args, **kwargs)
        return pyro.sample(
            "_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True}
        )

    def _unpack_latent(self, latent):
        """
        Unpacks a packed latent tensor, iterating over tuples of the form::

            (site, unconstrained_value)
        """
        batch_shape = latent.shape[
            :-1
        ]  # for plates outside of _setup_prototype, e.g. parallel particles
        pos = 0
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            constrained_shape = site["value"].shape
            unconstrained_shape = self._unconstrained_shapes[name]
            size = _product(unconstrained_shape)
            event_dim = (
                site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
            )
            unconstrained_shape = torch.broadcast_shapes(
                unconstrained_shape, batch_shape + (1,) * event_dim
            )
            unconstrained_value = latent[..., pos : pos + size].view(
                unconstrained_shape
            )
            yield site, unconstrained_value
            pos += size
        if not torch._C._get_tracing_state():
            assert pos == latent.size(-1)

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            if poutine.get_mask() is False:
                log_density = 0.0
            else:
                log_density = transform.inv.log_abs_det_jacobian(
                    value,
                    unconstrained_value,
                )
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() + site["fn"].event_dim,
                )
            delta_dist = dist.Delta(
                value,
                log_density=log_density,
                event_dim=site["fn"].event_dim,
            )

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result

    def _loc_scale(self, *args, **kwargs):
        """
        :returns: a tuple ``(loc, scale)`` used by :meth:`median` and
            :meth:`quantiles`
        """
        raise NotImplementedError

    @torch.no_grad()
    def median(self, *args, **kwargs):
        """
        Returns the posterior median value of each latent variable.

        :return: A dict mapping sample site name to median tensor.
        :rtype: dict
        """
        loc, _ = self._loc_scale(*args, **kwargs)
        loc = loc.detach()
        return {
            site["name"]: biject_to(site["fn"].support)(unconstrained_value)
            for site, unconstrained_value in self._unpack_latent(loc)
        }

    @torch.no_grad()
    def quantiles(self, quantiles, *args, **kwargs):
        """
        Returns posterior quantiles of each latent variable. Example::

            print(guide.quantiles([0.05, 0.5, 0.95]))

        :param quantiles: A list of requested quantiles between 0 and 1.
        :type quantiles: torch.Tensor or list
        :return: A dict mapping sample site name to a tensor of quantile values.
        :rtype: dict
        """
        loc, scale = self._loc_scale(*args, **kwargs)
        quantiles = torch.tensor(
            quantiles, dtype=loc.dtype, device=loc.device
        ).unsqueeze(-1)
        latents = dist.Normal(loc, scale).icdf(quantiles)
        result = {}
        for latent in latents:
            for site, unconstrained_value in self._unpack_latent(latent):
                result.setdefault(site["name"], []).append(
                    biject_to(site["fn"].support)(unconstrained_value)
                )
        result = {k: torch.stack(v) for k, v in result.items()}
        return result

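# Illustrative sketch (not part of the library): every AutoContinuous guide
# models the posterior over a single packed vector in unconstrained space.
# The toy model below is an assumption for demonstration.
#
#     def model():
#         pyro.sample("x", dist.Normal(0.0, 1.0))
#         pyro.sample("y", dist.LogNormal(0.0, 1.0).expand([3]).to_event(1))
#
#     guide = AutoDiagonalNormal(model)
#     guide()  # trigger lazy setup
#     print(guide.latent_dim)  # 4 == 1 (for "x") + 3 (for "y")
#     posterior = guide.get_posterior()  # a distribution over one flat vector
#     packed = posterior.sample()        # shape: (4,)
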
class AutoMultivariateNormal(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a Cholesky
    factorization of a Multivariate Normal distribution to construct a guide
    over the entire latent space. The guide does not depend on the model's
    ``*args, **kwargs``.

    Usage::

        guide = AutoMultivariateNormal(model)
        svi = SVI(model, guide, ...)

    By default the mean vector is initialized by ``init_loc_fn()`` and the
    Cholesky factor is initialized to the identity times a small factor.

    :param callable model: A generative model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param float init_scale: Initial scale for the standard deviation of each
        (unconstrained transformed) latent variable.
    """

    scale_constraint = constraints.softplus_positive
    scale_tril_constraint = constraints.unit_lower_cholesky

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        super().__init__(model, init_loc_fn=init_loc_fn)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        self.scale = PyroParam(
            torch.full_like(self.loc, self._init_scale), self.scale_constraint
        )
        self.scale_tril = PyroParam(
            eye_like(self.loc, self.latent_dim), self.scale_tril_constraint
        )

    def get_base_dist(self):
        return dist.Normal(
            torch.zeros_like(self.loc), torch.ones_like(self.loc)
        ).to_event(1)

    def get_transform(self, *args, **kwargs):
        scale_tril = self.scale[..., None] * self.scale_tril
        return dist.transforms.LowerCholeskyAffine(self.loc, scale_tril=scale_tril)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a MultivariateNormal posterior distribution.
        """
        scale_tril = self.scale[..., None] * self.scale_tril
        return dist.MultivariateNormal(self.loc, scale_tril=scale_tril)

    def _loc_scale(self, *args, **kwargs):
        scale_tril = self.scale[..., None] * self.scale_tril
        scale = scale_tril.pow(2).sum(-1).sqrt()
        assert scale.shape == self.loc.shape
        return self.loc, scale

class AutoDiagonalNormal(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a Normal distribution
    with a diagonal covariance matrix to construct a guide over the entire
    latent space. The guide does not depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, ...)

    By default the mean vector is initialized by ``init_loc_fn()`` and the
    scale is initialized to a small constant.

    :param callable model: A generative model.
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param float init_scale: Initial scale for the standard deviation of each
        (unconstrained transformed) latent variable.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        super().__init__(model, init_loc_fn=init_loc_fn)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        self.scale = PyroParam(
            self.loc.new_full((self.latent_dim,), self._init_scale),
            self.scale_constraint,
        )

    def get_base_dist(self):
        return dist.Normal(
            torch.zeros_like(self.loc), torch.ones_like(self.loc)
        ).to_event(1)

    def get_transform(self, *args, **kwargs):
        return dist.transforms.AffineTransform(self.loc, self.scale)

    def get_posterior(self, *args, **kwargs):
        """
        Returns a diagonal Normal posterior distribution.
        """
        return dist.Normal(self.loc, self.scale).to_event(1)

    def _loc_scale(self, *args, **kwargs):
        return self.loc, self.scale

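# Illustrative sketch (not part of the library): unlike AutoNormal, which keeps
# one (loc, scale) pair per site, AutoDiagonalNormal packs all sites into flat
# vectors that can be inspected directly once parameters are initialized.
# ``model`` is an assumed continuous model as in the sketch above.
#
#     guide = AutoDiagonalNormal(model)
#     guide()  # trigger lazy setup
#     print(guide.loc.shape, guide.scale.shape)  # both torch.Size([latent_dim])
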
class AutoLowRankMultivariateNormal(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a low rank plus
    diagonal Multivariate Normal distribution to construct a guide
    over the entire latent space. The guide does not depend on the model's
    ``*args, **kwargs``.

    Usage::

        guide = AutoLowRankMultivariateNormal(model, rank=10)
        svi = SVI(model, guide, ...)

    By default the ``cov_diag`` is initialized to a small constant and the
    ``cov_factor`` is initialized randomly such that on average
    ``cov_factor.matmul(cov_factor.t())`` has the same scale as ``cov_diag``.

    :param callable model: A generative model.
    :param rank: The rank of the low-rank part of the covariance matrix.
        Defaults to approximately ``sqrt(latent dim)``.
    :type rank: int or None
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param float init_scale: Approximate initial scale for the standard
        deviation of each (unconstrained transformed) latent variable.
    """

    scale_constraint = constraints.softplus_positive

    def __init__(self, model, init_loc_fn=init_to_median, init_scale=0.1, rank=None):
        if not isinstance(init_scale, float) or not (init_scale > 0):
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        if not (rank is None or isinstance(rank, int) and rank > 0):
            raise ValueError("Expected rank > 0 but got {}".format(rank))
        self._init_scale = init_scale
        self.rank = rank
        super().__init__(model, init_loc_fn=init_loc_fn)

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())
        if self.rank is None:
            self.rank = int(round(self.latent_dim**0.5))
        self.scale = PyroParam(
            self.loc.new_full((self.latent_dim,), 0.5**0.5 * self._init_scale),
            constraint=self.scale_constraint,
        )
        self.cov_factor = nn.Parameter(
            self.loc.new_empty(self.latent_dim, self.rank).normal_(
                0, 1 / self.rank**0.5
            )
        )

    def get_posterior(self, *args, **kwargs):
        """
        Returns a LowRankMultivariateNormal posterior distribution.
        """
        scale = self.scale
        cov_factor = self.cov_factor * scale.unsqueeze(-1)
        cov_diag = scale * scale
        return dist.LowRankMultivariateNormal(self.loc, cov_factor, cov_diag)

    def _loc_scale(self, *args, **kwargs):
        scale = self.scale * (self.cov_factor.pow(2).sum(-1) + 1).sqrt()
        return self.loc, scale

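# Illustrative sketch (not part of the library): per get_posterior() above, the
# guide's full covariance is the low-rank-plus-diagonal matrix
#
#     cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)
#
# with cov_factor = guide.cov_factor * guide.scale[:, None] and
# cov_diag = guide.scale ** 2. The marginal standard deviations returned by
# _loc_scale() are the square roots of cov's diagonal, which simplify to
#
#     scale = guide.scale * (guide.cov_factor.pow(2).sum(-1) + 1).sqrt()
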
class AutoNormalizingFlow(AutoContinuous):
    """
    This implementation of :class:`AutoContinuous` uses a Diagonal Normal
    distribution transformed via a sequence of bijective transforms
    (e.g. various :mod:`~pyro.distributions.TransformModule` subclasses)
    to construct a guide over the entire latent space. The guide does not
    depend on the model's ``*args, **kwargs``.

    Usage::

        transform_init = partial(iterated, 2, block_autoregressive)
        guide = AutoNormalizingFlow(model, transform_init)
        svi = SVI(model, guide, ...)

    :param callable model: a generative model
    :param init_transform_fn: a callable which when provided with the latent
        dimension returns an instance of
        :class:`~torch.distributions.Transform`, or
        :class:`~pyro.distributions.TransformModule` if the transform has
        trainable params.
    """

    def __init__(self, model, init_transform_fn):
        super().__init__(model, init_loc_fn=init_to_feasible)
        self._init_transform_fn = init_transform_fn
        self.transform = None
        self._prototype_tensor = torch.tensor(0.0)

    def get_base_dist(self):
        loc = self._prototype_tensor.new_zeros(1)
        scale = self._prototype_tensor.new_ones(1)
        return dist.Normal(loc, scale).expand([self.latent_dim]).to_event(1)

    def get_transform(self, *args, **kwargs):
        return self.transform

    def get_posterior(self, *args, **kwargs):
        if self.transform is None:
            self.transform = self._init_transform_fn(self.latent_dim)

            # Update prototype tensor in case transform parameters
            # device/dtype is not the same as default tensor type.
            for _, p in self.named_pyro_params():
                self._prototype_tensor = p
                break
        return super().get_posterior(*args, **kwargs)

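# Illustrative sketch (not part of the library): building a two-layer
# affine-autoregressive flow guide. Note that ``iterated`` takes the repeat
# count as its first argument; ``model`` is an assumed generative model.
#
#     from functools import partial
#     from pyro.distributions.transforms import affine_autoregressive, iterated
#
#     guide = AutoNormalizingFlow(model, partial(iterated, 2, affine_autoregressive))
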
class AutoIAFNormal(AutoNormalizingFlow):
    """
    This implementation of :class:`AutoContinuous` uses a Diagonal Normal
    distribution transformed via a
    :class:`~pyro.distributions.transforms.AffineAutoregressive`
    to construct a guide over the entire latent space. The guide does not
    depend on the model's ``*args, **kwargs``.

    Usage::

        guide = AutoIAFNormal(model, hidden_dim=latent_dim)
        svi = SVI(model, guide, ...)

    :param callable model: a generative model
    :param list[int] hidden_dim: number of hidden dimensions in the IAF
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.

        .. warning::

            This argument is only to preserve backwards compatibility
            and has no effect in practice.

    :param int num_transforms: number of
        :class:`~pyro.distributions.transforms.AffineAutoregressive` transforms
        to use in sequence.
    :param init_transform_kwargs: other keyword arguments taken by
        :func:`~pyro.distributions.transforms.affine_autoregressive`.
    """

    def __init__(
        self,
        model,
        hidden_dim=None,
        init_loc_fn=None,
        num_transforms=1,
        **init_transform_kwargs,
    ):
        if init_loc_fn:
            warnings.warn(
                "The `init_loc_fn` argument to AutoIAFNormal is not used in practice. "
                "Please consider removing, as this may be removed in a future release.",
                category=FutureWarning,
            )
        super().__init__(
            model,
            init_transform_fn=functools.partial(
                iterated,
                num_transforms,
                affine_autoregressive,
                hidden_dims=hidden_dim,
                **init_transform_kwargs,
            ),
        )

class AutoLaplaceApproximation(AutoContinuous):
    r"""
    Laplace approximation (quadratic approximation) approximates the posterior
    :math:`\log p(z | x)` by a multivariate normal distribution in the
    unconstrained space. Under the hood, it uses Delta distributions to
    construct a MAP guide over the entire (unconstrained) latent space. Its
    covariance is given by the inverse of the Hessian of :math:`-\log p(x, z)`
    at the MAP point of `z`.

    Usage::

        delta_guide = AutoLaplaceApproximation(model)
        svi = SVI(model, delta_guide, ...)
        # ...then train the delta_guide...
        guide = delta_guide.laplace_approximation()

    By default the mean vector is initialized to an empirical prior median.

    :param callable model: a generative model
    :param callable init_loc_fn: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    """

    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # Initialize guide params
        self.loc = nn.Parameter(self._init_loc())

    def get_posterior(self, *args, **kwargs):
        """
        Returns a Delta posterior distribution for MAP inference.
        """
        return dist.Delta(self.loc).to_event(1)

    def laplace_approximation(self, *args, **kwargs):
        """
        Returns an :class:`AutoMultivariateNormal` instance whose posterior's
        `loc` and `scale_tril` are given by Laplace approximation.
        """
        guide_trace = poutine.trace(self).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.model, trace=guide_trace)
        ).get_trace(*args, **kwargs)
        loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

        H = hessian(loss, self.loc)
        with torch.no_grad():
            loc = self.loc.detach()
            cov = H.inverse()
            scale = cov.diagonal().sqrt()
            cov /= scale[:, None]
            cov /= scale[None, :]
            scale_tril = torch.linalg.cholesky(cov)

        gaussian_guide = AutoMultivariateNormal(self.model)
        gaussian_guide._setup_prototype(*args, **kwargs)
        # Set detached loc, scale, scale_tril parameters as computed above.
        del gaussian_guide.loc
        del gaussian_guide.scale
        del gaussian_guide.scale_tril
        gaussian_guide.register_buffer("loc", loc)
        gaussian_guide.register_buffer("scale", scale)
        gaussian_guide.register_buffer("scale_tril", scale_tril)
        return gaussian_guide

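# Illustrative sketch (not part of the library): the full Laplace workflow.
# Train the MAP (Delta) guide first, then convert it to a Gaussian; the model,
# data, and optimizer names are assumptions for demonstration.
#
#     delta_guide = AutoLaplaceApproximation(model)
#     svi = SVI(model, delta_guide, Adam({"lr": 0.05}), Trace_ELBO())
#     for _ in range(500):
#         svi.step(data)
#     guide = delta_guide.laplace_approximation(data)  # an AutoMultivariateNormal
#     samples = [guide(data) for _ in range(100)]      # draw posterior samples
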
class AutoDiscreteParallel(AutoGuide):
    """
    A discrete mean-field guide that learns a latent discrete distribution for
    each discrete site in the model.
    """

    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = poutine.block(config_enumerate(self.model), self._prototype_hide_fn)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
            *args, **kwargs
        )
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._prototype_frames = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError(
                    'Expected sample site "{}" to be discrete and '
                    "configured for parallel enumeration".format(name)
                )

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
                params = [
                    ("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])
                ]
            else:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._prototype_frames[frame.name] = frame
                else:
                    raise NotImplementedError(
                        "AutoDiscreteParallel does not support sequential pyro.plate"
                    )

        # Initialize guide params
        for site, Dist, param_spec in self._discrete_sites:
            name = site["name"]
            for param_name, param_init, param_constraint in param_spec:
                deep_setattr(
                    self,
                    "{}_{}".format(name, param_name),
                    PyroParam(param_init, constraint=param_constraint),
                )

    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)

        # enumerate discrete latent samples
        result = {}
        for site, Dist, param_spec in self._discrete_sites:
            name = site["name"]
            dist_params = {
                param_name: operator.attrgetter("{}_{}".format(name, param_name))(self)
                for param_name, param_init, param_constraint in param_spec
            }
            discrete_dist = Dist(**dist_params)

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(
                    name, discrete_dist, infer={"enumerate": "parallel"}
                )

        return result

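# Illustrative sketch (not part of the library): combining AutoDiscreteParallel
# with a continuous guide via AutoGuideList, trained with TraceEnum_ELBO so the
# discrete site is enumerated in parallel. All names below are assumptions.
#
#     from pyro.infer import SVI, TraceEnum_ELBO
#     from pyro.optim import Adam
#
#     def model(data):
#         weight = pyro.sample("weight", dist.Dirichlet(torch.ones(3)))
#         with pyro.plate("data", len(data)):
#             assignment = pyro.sample("assignment", dist.Categorical(weight))
#             pyro.sample("obs", dist.Normal(assignment.float(), 1.0), obs=data)
#
#     guide = AutoGuideList(model)
#     guide.append(AutoNormal(poutine.block(model, hide=["assignment"])))
#     guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
#     svi = SVI(model, guide, Adam({"lr": 0.05}), TraceEnum_ELBO(max_plate_nesting=1))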