Source code for pyro.infer.svgd

from abc import ABCMeta, abstractmethod
import math

import torch
from torch.distributions import biject_to

import pyro
from pyro import poutine
from pyro.distributions import Delta
from pyro.infer.trace_elbo import Trace_ELBO
from pyro.infer.autoguide.guides import AutoContinuous
from pyro.infer.autoguide.initialization import init_to_sample
from pyro.distributions.util import copy_docs_from


def vectorize(fn, num_particles, max_plate_nesting):
    """
    Wrap ``fn`` so that every call to it runs inside a :func:`pyro.plate`
    named ``"num_particles_vectorized"``.

    :param callable fn: The callable to wrap. It receives exactly the
        positional and keyword arguments passed to the returned wrapper.
    :param int num_particles: Size of the enclosing plate.
    :param int max_plate_nesting: The plate is placed at dimension
        ``-max_plate_nesting - 1``.
    :returns: A callable that forwards its arguments to ``fn`` and returns
        whatever ``fn`` returns.
    """

    def _fn(*args, **kwargs):
        particle_plate = pyro.plate(
            "num_particles_vectorized",
            num_particles,
            dim=-max_plate_nesting - 1,
        )
        with particle_plate:
            return fn(*args, **kwargs)

    return _fn
class _SVGDGuide(AutoContinuous):
    """
    Internal helper for the :class:`SVGD` inference algorithm: a variant of
    :class:`AutoContinuous` whose posterior is a :class:`Delta` placed on the
    ``"svgd_particles"`` parameter.
    """

    def __init__(self, model):
        # The guide is registered under the "svgd" prefix and hands
        # `init_to_sample` to the parent as its location initializer.
        super().__init__(model, prefix="svgd", init_loc_fn=init_to_sample)

    def get_posterior(self, *args, **kwargs):
        """
        Return a :class:`Delta` distribution (``event_dim=1``) located at the
        ``"svgd_particles"`` parameter. The parameter is fetched via
        :func:`pyro.param` with ``self._init_loc`` as its initial value.
        The positional and keyword arguments are accepted but unused.
        """
        particles = pyro.param("svgd_particles", self._init_loc)
        return Delta(particles, event_dim=1)
class SteinKernel(metaclass=ABCMeta):
    """
    Abstract base class for the kernels consumed by the :class:`SVGD`
    inference algorithm.
    """

    @abstractmethod
    def log_kernel_and_grad(self, particles):
        """
        Compute the component kernels and their gradients.

        :param particles: a tensor with shape (N, D)
        :returns: A pair (`log_kernel`, `kernel_grad`). `log_kernel` is a
            (N, N, D)-shaped tensor holding the logarithm of the kernel.
            `kernel_grad` is a (N, N, D)-shaped tensor whose entry (n, m, d)
            represents the derivative of `log_kernel` w.r.t. x_{m,d}, where
            x_{m,d} is the d^th dimension of particle m.
        """
        raise NotImplementedError
@copy_docs_from(SteinKernel)
class RBFSteinKernel(SteinKernel):
    """
    A RBF kernel for use in the SVGD inference algorithm. The kernel
    bandwidth is derived from the particles themselves with a simple median
    heuristic, as in reference [1].

    :param float bandwidth_factor: Optional factor by which to scale the
        bandwidth. Defaults to ``None``, in which case no scaling is applied.
    :ivar float bandwidth_factor: Property that controls the factor by which
        to scale the bandwidth at each iteration.

    References

    [1] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,"
        Qiang Liu, Dilin Wang
    """

    def __init__(self, bandwidth_factor=None):
        """
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        self._bandwidth_factor = bandwidth_factor

    def _bandwidth(self, norm_sq):
        """
        Compute the bandwidth along each dimension from the median pairwise
        squared distance between particles.

        :param norm_sq: (N, N, D) tensor of per-dimension squared differences.
        :returns: (D,) tensor: the (optionally scaled) median divided by
            ``log(N + 1)``.
        """
        n_particles = norm_sq.size(0)
        # (N, N) boolean mask that is True strictly above the diagonal, so
        # every unordered pair of distinct particles is selected exactly once.
        idx = torch.arange(n_particles)
        above_diagonal = idx.unsqueeze(-1) < idx
        pairwise = norm_sq[above_diagonal, ...]
        median, _ = pairwise.median(dim=0)
        factor = self.bandwidth_factor
        if factor is not None:
            median = factor * median
        assert median.shape == pairwise.shape[-1:]
        return median / math.log(n_particles + 1)

    @torch.no_grad()
    def log_kernel_and_grad(self, particles):
        # No docstring on purpose: the decorator on the class is expected to
        # supply documentation from SteinKernel for undocumented methods.
        # diff[n, m, :] == particles[m] - particles[n]
        diff = particles.unsqueeze(0) - particles.unsqueeze(1)  # N N D
        assert diff.dim() == 3
        sq_diff = diff.pow(2.0)  # N N D
        bandwidth = self._bandwidth(sq_diff)  # D
        log_kernel = -(sq_diff / bandwidth)  # N N D
        # NOTE(review): with the bandwidth held fixed, this is the derivative
        # of log_kernel[n, m, d] w.r.t. x_{n,d}, i.e. the negative of the
        # derivative w.r.t. x_{m,d} -- confirm against the SteinKernel contract.
        kernel_grad = 2.0 * diff / bandwidth  # N N D
        assert log_kernel.shape == kernel_grad.shape
        return log_kernel, kernel_grad

    @property
    def bandwidth_factor(self):
        return self._bandwidth_factor

    @bandwidth_factor.setter
    def bandwidth_factor(self, bandwidth_factor):
        """
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        self._bandwidth_factor = bandwidth_factor
class SVGD:
    """
    A basic implementation of Stein Variational Gradient Descent as described
    in reference [1].

    :param model: The model (callable containing Pyro primitives). Model must be fully
        vectorized and may only contain continuous latent variables.
    :param kernel: a SVGD compatible kernel like :class:`RBFSteinKernel`.
    :param optim: A wrapper for a PyTorch optimizer.
    :type optim: pyro.optim.PyroOptim
    :param int num_particles: The number of particles used in SVGD.
    :param int max_plate_nesting: The max number of nested :func:`pyro.plate`
        contexts in the model.
    :param str mode: Whether to use a Kernelized Stein Discrepancy that makes use of
        `multivariate` test functions (as in [1]) or `univariate` test functions
        (as in [2]). Defaults to `univariate`.

    Example usage:

    .. code-block:: python

        from pyro.infer import SVGD, RBFSteinKernel
        from pyro.optim import Adam

        kernel = RBFSteinKernel()
        adam = Adam({"lr": 0.1})
        svgd = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

        for step in range(500):
            svgd.step(model_arg1, model_arg2)

        final_particles = svgd.get_named_particles()

    References

    [1] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,"
        Qiang Liu, Dilin Wang
    [2] "Kernelized Complete Conditional Stein Discrepancy,"
        Raghav Singhal, Saad Lahlou, Rajesh Ranganath
    """

    def __init__(self, model, kernel, optim, num_particles, max_plate_nesting, mode="univariate"):
        # Argument validation; the checks run in this order.
        assert callable(model)
        assert isinstance(kernel, SteinKernel), "Must provide a valid SteinKernel"
        assert isinstance(optim, pyro.optim.PyroOptim), "Must provide a valid Pyro optimizer"
        assert num_particles > 1, "Must use at least two particles"
        assert max_plate_nesting >= 0
        assert mode in ("univariate", "multivariate"), "mode must be one of (univariate, multivariate)"

        self.kernel = kernel
        self.optim = optim
        self.num_particles = num_particles
        self.max_plate_nesting = max_plate_nesting
        self.mode = mode

        # The model is wrapped so that it always runs inside the particle
        # plate; the guide is built on top of that wrapped model.
        self.model = vectorize(model, num_particles, max_plate_nesting)
        self.loss = Trace_ELBO().differentiable_loss
        self.guide = _SVGDGuide(self.model)

    def get_named_particles(self):
        """
        Create a dictionary mapping name to vectorized value, of the form ``{name: tensor}``.
        The leading dimension of each tensor corresponds to particles, i.e. this creates a
        struct of arrays.
        """
        flat_particles = pyro.param("svgd_particles")
        named = {}
        for site, unconstrained_value in self.guide._unpack_latent(flat_particles):
            # Map each unconstrained slice into the support of its site.
            to_support = biject_to(site["fn"].support)
            named[site["name"]] = to_support(unconstrained_value)
        return named

    @torch.no_grad()
    def step(self, *args, **kwargs):
        """
        Computes the SVGD gradient, passing args and kwargs to the model,
        and takes a gradient step.

        :return dict: A dictionary of the form {name: float}, where each float is a mean squared
            gradient. This can be used to monitor the convergence of SVGD.
        """
        # Backpropagate the loss with gradients re-enabled, recording the
        # param sites that were touched along the way.
        with torch.enable_grad(), poutine.trace(param_only=True) as param_capture:
            loss = self.loss(self.model, self.guide, *args, **kwargs)
            loss.backward()

        # View the particles and their gradients as (num_particles, D).
        particles = pyro.param("svgd_particles").unconstrained()
        flat_particles = particles.reshape(self.num_particles, -1)
        flat_grad = particles.grad.reshape(self.num_particles, -1)

        log_kernel, kernel_grad = self.kernel.log_kernel_and_grad(flat_particles)

        if self.mode == "multivariate":
            # One scalar kernel value per particle pair.
            kernel = log_kernel.sum(-1).exp()
            assert kernel.shape == (self.num_particles, self.num_particles)
            attractive_grad = torch.mm(kernel, flat_grad)
            repulsive_grad = torch.einsum("nm,nm...->n...", kernel, kernel_grad)
        elif self.mode == "univariate":
            # One kernel value per particle pair and per dimension.
            kernel = log_kernel.exp()
            assert kernel.shape == (self.num_particles, self.num_particles, flat_particles.size(-1))
            attractive_grad = torch.einsum("nmd,md->nd", kernel, flat_grad)
            repulsive_grad = torch.einsum("nmd,nmd->nd", kernel, kernel_grad)

        # Replace the raw gradient with the averaged sum of both terms.
        assert attractive_grad.shape == repulsive_grad.shape
        combined = attractive_grad + repulsive_grad
        particles.grad = combined.reshape(particles.shape) / self.num_particles

        # Per-site mean squared gradient, reported back to the caller.
        squared_gradients = {}
        for site, value in self.guide._unpack_latent(particles.grad.pow(2.0)):
            squared_gradients[site["name"]] = value.mean().item()

        # torch.optim objects gets instantiated for any params that haven't been seen yet
        params = {site["value"].unconstrained() for site in param_capture.trace.nodes.values()}
        self.optim(params)

        # Clear the gradients so they do not accumulate into the next step.
        pyro.infer.util.zero_grads(params)

        return squared_gradients