# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
from abc import ABCMeta, abstractmethod
import torch
from torch.distributions import biject_to
import pyro
from pyro import poutine
from pyro.distributions import Delta
from pyro.distributions.util import copy_docs_from
from pyro.infer.autoguide.guides import AutoContinuous
from pyro.infer.autoguide.initialization import init_to_sample
from pyro.infer.trace_elbo import Trace_ELBO
def vectorize(fn, num_particles, max_plate_nesting):
    """
    Wrap ``fn`` so that all of its sample sites are vectorized across an outer
    particle plate of size ``num_particles`` placed at ``dim=-max_plate_nesting - 1``.
    """

    def _fn(*args, **kwargs):
        with pyro.plate(
            "num_particles_vectorized", num_particles, dim=-max_plate_nesting - 1
        ):
            return fn(*args, **kwargs)

    return _fn
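
# For example (an illustrative sketch; ``toy_model`` below is hypothetical and not part of this module):
#
#     def toy_model():
#         pyro.sample("z", pyro.distributions.Normal(0.0, 1.0))
#
#     vectorized_model = vectorize(toy_model, num_particles=10, max_plate_nesting=0)
#     # tracing vectorized_model now yields a "z" site with a leading particle dimension of size 10 at dim=-1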

class _SVGDGuide(AutoContinuous):
    """
    This modification of :class:`AutoContinuous` is used internally in the
    :class:`SVGD` inference algorithm.
    """

    def __init__(self, model):
        super().__init__(model, init_loc_fn=init_to_sample)

    def get_posterior(self, *args, **kwargs):
        svgd_particles = pyro.param("svgd_particles", self._init_loc)
        return Delta(svgd_particles, event_dim=1)

class SteinKernel(object, metaclass=ABCMeta):
    """
    Abstract class for kernels used in the :class:`SVGD` inference algorithm.
    """

    @abstractmethod
    def log_kernel_and_grad(self, particles):
        """
        Compute the component kernels and their gradients.

        :param particles: a tensor with shape (N, D)
        :returns: A pair (`log_kernel`, `kernel_grad`) where `log_kernel` is a (N, N, D)-shaped
            tensor equal to the logarithm of the kernel and `kernel_grad` is a (N, N, D)-shaped
            tensor where the entry (n, m, d) represents the derivative of `log_kernel` w.r.t.
            x_{m,d}, where x_{m,d} is the d^th dimension of particle m.
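
        For example (illustrative shapes only), for ``particles`` of shape ``(N, D) = (20, 3)``
        both returned tensors have shape ``(20, 20, 3)``.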
"""
raise NotImplementedError

@copy_docs_from(SteinKernel)
class RBFSteinKernel(SteinKernel):
    """
    An RBF kernel for use in the SVGD inference algorithm. The bandwidth of the kernel is chosen
    from the particles using a simple heuristic as in reference [1].

    :param float bandwidth_factor: Optional factor by which to scale the bandwidth, defaults to 1.0.
    :ivar float ~.bandwidth_factor: Property that controls the factor by which to scale the bandwidth
        at each iteration.
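
    Example usage (an illustrative sketch; ``model`` stands in for any fully vectorized Pyro model
    with continuous latent variables):

    .. code-block:: python

        from pyro.infer import SVGD, RBFSteinKernel
        from pyro.optim import Adam

        kernel = RBFSteinKernel(bandwidth_factor=0.5)
        svgd = SVGD(model, kernel, Adam({"lr": 0.1}), num_particles=50, max_plate_nesting=0)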

    References

    [1] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,"
        Qiang Liu, Dilin Wang
    """

    def __init__(self, bandwidth_factor=None):
        """
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        self.bandwidth_factor = bandwidth_factor

    def _bandwidth(self, norm_sq):
        """
        Compute the bandwidth along each dimension using the median pairwise squared distance between particles.
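
        Concretely (mirroring the code below), for particles ``x`` of shape ``(N, D)`` the bandwidth along
        dimension ``d`` is the median of ``(x[n, d] - x[m, d])**2`` over pairs ``n < m``, optionally scaled
        by ``bandwidth_factor``, divided by ``log(N + 1)``.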
"""
num_particles = norm_sq.size(0)
index = torch.arange(num_particles)
norm_sq = norm_sq[index > index.unsqueeze(-1), ...]
median = norm_sq.median(dim=0)[0]
if self.bandwidth_factor is not None:
median = self.bandwidth_factor * median
assert median.shape == norm_sq.shape[-1:]
return median / math.log(num_particles + 1)

    @torch.no_grad()
    def log_kernel_and_grad(self, particles):
        delta_x = particles.unsqueeze(0) - particles.unsqueeze(1)  # N N D
        assert delta_x.dim() == 3
        norm_sq = delta_x.pow(2.0)  # N N D
        h = self._bandwidth(norm_sq)  # D
        log_kernel = -(norm_sq / h)  # N N D
        grad_term = 2.0 * delta_x / h  # N N D
        assert log_kernel.shape == grad_term.shape
        return log_kernel, grad_term

    @property
    def bandwidth_factor(self):
        return self._bandwidth_factor

    @bandwidth_factor.setter
    def bandwidth_factor(self, bandwidth_factor):
        """
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        if bandwidth_factor is not None:
            assert bandwidth_factor > 0.0, "bandwidth_factor must be positive."
        self._bandwidth_factor = bandwidth_factor

@copy_docs_from(SteinKernel)
class IMQSteinKernel(SteinKernel):
    r"""
    An IMQ (inverse multi-quadratic) kernel for use in the SVGD inference algorithm [1]. The bandwidth of the
    kernel is chosen from the particles using a simple heuristic as in reference [2]. The kernel takes the form

    :math:`K(x, y) = (\alpha + ||x-y||^2/h)^{\beta}`

    where :math:`\alpha` and :math:`\beta` are user-specified parameters and :math:`h` is the bandwidth.

    :param float alpha: Kernel hyperparameter, defaults to 0.5.
    :param float beta: Kernel hyperparameter, defaults to -0.5.
    :param float bandwidth_factor: Optional factor by which to scale the bandwidth, defaults to 1.0.
    :ivar float ~.bandwidth_factor: Property that controls the factor by which to scale the bandwidth
        at each iteration.
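
    Example usage (an illustrative sketch; ``model`` stands in for any fully vectorized Pyro model
    with continuous latent variables):

    .. code-block:: python

        from pyro.infer import SVGD, IMQSteinKernel
        from pyro.optim import Adam

        kernel = IMQSteinKernel(alpha=0.5, beta=-0.5)
        svgd = SVGD(model, kernel, Adam({"lr": 0.1}), num_particles=50, max_plate_nesting=0)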

    References

    [1] "Stein Points," Wilson Ye Chen, Lester Mackey, Jackson Gorham, Francois-Xavier Briol, Chris. J. Oates.
    [2] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm," Qiang Liu, Dilin Wang
    """

    def __init__(self, alpha=0.5, beta=-0.5, bandwidth_factor=None):
        """
        :param float alpha: Kernel hyperparameter, defaults to 0.5.
        :param float beta: Kernel hyperparameter, defaults to -0.5.
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        assert alpha > 0.0, "alpha must be positive."
        assert beta < 0.0, "beta must be negative."
        self.alpha = alpha
        self.beta = beta
        self.bandwidth_factor = bandwidth_factor

    def _bandwidth(self, norm_sq):
        """
        Compute the bandwidth along each dimension using the median pairwise squared distance between particles.
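
        Concretely (mirroring the code below), for particles ``x`` of shape ``(N, D)`` the bandwidth along
        dimension ``d`` is the median of ``(x[n, d] - x[m, d])**2`` over pairs ``n < m``, optionally scaled
        by ``bandwidth_factor``, divided by ``log(N + 1)``.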
"""
num_particles = norm_sq.size(0)
index = torch.arange(num_particles)
norm_sq = norm_sq[index > index.unsqueeze(-1), ...]
median = norm_sq.median(dim=0)[0]
if self.bandwidth_factor is not None:
median = self.bandwidth_factor * median
assert median.shape == norm_sq.shape[-1:]
return median / math.log(num_particles + 1)

    @torch.no_grad()
    def log_kernel_and_grad(self, particles):
        delta_x = particles.unsqueeze(0) - particles.unsqueeze(1)  # N N D
        assert delta_x.dim() == 3
        norm_sq = delta_x.pow(2.0)  # N N D
        h = self._bandwidth(norm_sq)  # D
        base_term = self.alpha + norm_sq / h
        log_kernel = self.beta * torch.log(base_term)  # N N D
        grad_term = (-2.0 * self.beta) * delta_x / h  # N N D
        grad_term = grad_term / base_term
        assert log_kernel.shape == grad_term.shape
        return log_kernel, grad_term

    @property
    def bandwidth_factor(self):
        return self._bandwidth_factor

    @bandwidth_factor.setter
    def bandwidth_factor(self, bandwidth_factor):
        """
        :param float bandwidth_factor: Optional factor by which to scale the bandwidth
        """
        if bandwidth_factor is not None:
            assert bandwidth_factor > 0.0, "bandwidth_factor must be positive."
        self._bandwidth_factor = bandwidth_factor

class SVGD:
    """
    A basic implementation of Stein Variational Gradient Descent as described in reference [1].

    :param model: The model (callable containing Pyro primitives). Model must be fully vectorized
        and may only contain continuous latent variables.
    :param kernel: an SVGD-compatible kernel like :class:`RBFSteinKernel`.
    :param optim: A wrapper for a PyTorch optimizer.
    :type optim: pyro.optim.PyroOptim
    :param int num_particles: The number of particles used in SVGD.
    :param int max_plate_nesting: The maximum number of nested :func:`pyro.plate` contexts in the model.
    :param str mode: Whether to use a Kernelized Stein Discrepancy that makes use of `multivariate`
        test functions (as in [1]) or `univariate` test functions (as in [2]). Defaults to `univariate`.

    Example usage:

    .. code-block:: python

        from pyro.infer import SVGD, RBFSteinKernel
        from pyro.optim import Adam

        kernel = RBFSteinKernel()
        adam = Adam({"lr": 0.1})
        svgd = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

        for step in range(500):
            svgd.step(model_arg1, model_arg2)

        final_particles = svgd.get_named_particles()

    References

    [1] "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,"
        Qiang Liu, Dilin Wang
    [2] "Kernelized Complete Conditional Stein Discrepancy,"
        Raghav Singhal, Saad Lahlou, Rajesh Ranganath
    """

    def __init__(
        self, model, kernel, optim, num_particles, max_plate_nesting, mode="univariate"
    ):
        assert callable(model)
        assert isinstance(kernel, SteinKernel), "Must provide a valid SteinKernel"
        assert isinstance(
            optim, pyro.optim.PyroOptim
        ), "Must provide a valid Pyro optimizer"
        assert num_particles > 1, "Must use at least two particles"
        assert max_plate_nesting >= 0
        assert mode in [
            "univariate",
            "multivariate",
        ], "mode must be one of (univariate, multivariate)"

        self.model = vectorize(model, num_particles, max_plate_nesting)
        self.kernel = kernel
        self.optim = optim
        self.num_particles = num_particles
        self.max_plate_nesting = max_plate_nesting
        self.mode = mode

        self.loss = Trace_ELBO().differentiable_loss
        self.guide = _SVGDGuide(self.model)

    def get_named_particles(self):
        """
        Create a dictionary mapping name to vectorized value, of the form ``{name: tensor}``.
        The leading dimension of each tensor corresponds to particles, i.e. this creates a struct of arrays.
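
        For example (illustrative shapes only), with ``num_particles=50`` a latent site that is a scalar
        in the model appears here as a tensor of shape ``(50,)``, and a ``D``-dimensional latent appears
        as a tensor of shape ``(50, D)``.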
"""
return {
site["name"]: biject_to(site["fn"].support)(unconstrained_value)
for site, unconstrained_value in self.guide._unpack_latent(
pyro.param("svgd_particles")
)
}

    @torch.no_grad()
    def step(self, *args, **kwargs):
        """
        Computes the SVGD gradient, passing args and kwargs to the model,
        and takes a gradient step.

        :return dict: A dictionary of the form {name: float}, where each float
            is a mean squared gradient. This can be used to monitor the convergence of SVGD.
        """
        # compute gradients of log model joint
        with torch.enable_grad(), poutine.trace(param_only=True) as param_capture:
            loss = self.loss(self.model, self.guide, *args, **kwargs)
            loss.backward()

        # get particles used in the _SVGDGuide and reshape to have num_particles leading dimension
        particles = pyro.param("svgd_particles").unconstrained()
        reshaped_particles = particles.reshape(self.num_particles, -1)
        reshaped_particles_grad = particles.grad.reshape(self.num_particles, -1)

        # compute kernel ingredients
        log_kernel, kernel_grad = self.kernel.log_kernel_and_grad(reshaped_particles)
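
        # SVGD update direction (reference [1]): for each particle x_n the step is proportional to
        #   phi(x_n) = (1/N) * sum_m [ k(x_m, x_n) * grad_{x_m} log p(x_m) + grad_{x_m} k(x_m, x_n) ]
        # The first ("attractive") term pulls particles toward high-probability regions of the model
        # joint; the second ("repulsive") term pushes particles apart. In "univariate" mode the kernel
        # is applied dimension-by-dimension rather than to whole particles.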
if self.mode == "multivariate":
kernel = log_kernel.sum(-1).exp()
assert kernel.shape == (self.num_particles, self.num_particles)
attractive_grad = torch.mm(kernel, reshaped_particles_grad)
repulsive_grad = torch.einsum("nm,nm...->n...", kernel, kernel_grad)
elif self.mode == "univariate":
kernel = log_kernel.exp()
assert kernel.shape == (
self.num_particles,
self.num_particles,
reshaped_particles.size(-1),
)
attractive_grad = torch.einsum(
"nmd,md->nd", kernel, reshaped_particles_grad
)
repulsive_grad = torch.einsum("nmd,nmd->nd", kernel, kernel_grad)
# combine the attractive and repulsive terms in the SVGD gradient
assert attractive_grad.shape == repulsive_grad.shape
particles.grad = (attractive_grad + repulsive_grad).reshape(
particles.shape
) / self.num_particles

        # compute per-parameter mean squared gradients
        squared_gradients = {
            site["name"]: value.mean().item()
            for site, value in self.guide._unpack_latent(particles.grad.pow(2.0))
        }

        # torch.optim objects get instantiated for any params that haven't been seen yet
        params = set(
            site["value"].unconstrained() for site in param_capture.trace.nodes.values()
        )
        self.optim(params)

        # zero gradients
        pyro.infer.util.zero_grads(params)

        # return per-parameter mean squared gradients to user
        return squared_gradients