Source code for pyro.infer.mcmc.hmc

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict

import torch

import pyro
import pyro.distributions as dist
from pyro.distributions.testing.fakes import NonreparameterizedNormal
from pyro.distributions.util import scalar_like
from pyro.infer.autoguide import init_to_uniform
from pyro.infer.mcmc.adaptation import WarmupAdapter
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
from pyro.ops.integrator import _EXCEPTION_HANDLERS, potential_grad, velocity_verlet
from pyro.util import optional, torch_isnan


class HMC(MCMCKernel):
    r"""
    Simple Hamiltonian Monte Carlo kernel, where ``step_size`` and ``num_steps``
    need to be explicitly specified by the user.

    **References**

    [1] `MCMC Using Hamiltonian Dynamics`,
    Radford M. Neal

    :param model: Python callable containing Pyro primitives.
    :param potential_fn: Python callable that computes the potential energy,
        given a dict of real-support parameters as input.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param float trajectory_length: Length of an MCMC trajectory. If not
        specified, it will be set to ``step_size x num_steps``. In case
        ``num_steps`` is not specified, it will be set to :math:`2\pi`.
    :param int num_steps: The number of discrete steps over which to simulate
        Hamiltonian dynamics. The state at the end of the trajectory is
        returned as the proposal. This value is always equal to
        ``int(trajectory_length / step_size)``.
    :param bool adapt_step_size: A flag to decide whether to adapt ``step_size``
        during the warm-up phase using the Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide whether to adapt the mass
        matrix during the warm-up phase using the Welford scheme.
    :param bool full_mass: A flag to decide whether the mass matrix is dense or
        diagonal.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on the max number of nested
        :func:`pyro.plate` contexts. This is required if the model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param dict jit_options: A dictionary of optional arguments for the
        :func:`torch.jit.trace` function.
    :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.
    :param float target_accept_prob: Increasing this value will lead to a
        smaller step size, hence the sampling will be slower and more robust.
        Defaults to 0.8.
    :param callable init_strategy: A per-site initialization function.
        See the :ref:`autoguide-initialization` section for available functions.
    :param float min_stepsize: Lower bound on stepsize in adaptation strategy.
    :param float max_stepsize: Upper bound on stepsize in adaptation strategy.

    .. note:: Internally, the mass matrix will be ordered according to the order
        of the names of latent variables, not the order of their appearance in
        the model.

    Example:

        >>> true_coefs = torch.tensor([1., 2., 3.])
        >>> data = torch.randn(2000, 3)
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
        >>>
        >>> def model(data):
        ...     coefs_mean = torch.zeros(dim)
        ...     coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
        ...     y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        ...     return y
        >>>
        >>> hmc_kernel = HMC(model, step_size=0.0855, num_steps=4)
        >>> mcmc = MCMC(hmc_kernel, num_samples=500, warmup_steps=100)
        >>> mcmc.run(data)
        >>> mcmc.get_samples()['beta'].mean(0)  # doctest: +SKIP
        tensor([ 0.9819,  1.9258,  2.9737])
    """

    def __init__(
        self,
        model=None,
        potential_fn=None,
        step_size=1,
        trajectory_length=None,
        num_steps=None,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        transforms=None,
        max_plate_nesting=None,
        jit_compile=False,
        jit_options=None,
        ignore_jit_warnings=False,
        target_accept_prob=0.8,
        init_strategy=init_to_uniform,
        *,
        min_stepsize: float = 1e-10,
        max_stepsize: float = 1e10,
    ):
        if not ((model is None) ^ (potential_fn is None)):
            raise ValueError("Only one of `model` or `potential_fn` must be specified.")
        # NB: deprecating args - model, transforms
        self.model = model
        self.transforms = transforms
        self._max_plate_nesting = max_plate_nesting
        self._jit_compile = jit_compile
        self._jit_options = jit_options
        self._ignore_jit_warnings = ignore_jit_warnings
        self._init_strategy = init_strategy
        self._min_stepsize = min_stepsize
        self._max_stepsize = max_stepsize

        self.potential_fn = potential_fn

        if trajectory_length is not None:
            self.trajectory_length = trajectory_length
        elif num_steps is not None:
            self.trajectory_length = step_size * num_steps
        else:
            self.trajectory_length = 2 * math.pi  # from Stan
        # The following parameter is used in the find_reasonable_step_size method.
        # In the NUTS paper, this threshold is set to a fixed log(0.5).
        # After https://github.com/stan-dev/stan/pull/356, it is set to a fixed log(0.8).
        self._direction_threshold = math.log(0.8)  # from Stan
        self._max_sliced_energy = 1000
        self._reset()
        self._adapter = WarmupAdapter(
            step_size,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            target_accept_prob=target_accept_prob,
            dense_mass=full_mass,
        )
        super().__init__()

    def _kinetic_energy(self, r_unscaled):
        energy = 0.0
        for site_names, value in r_unscaled.items():
            energy = energy + value.dot(value)
        return 0.5 * energy

    def _reset(self):
        self._t = 0
        self._accept_cnt = 0
        self._mean_accept_prob = 0.0
        self._divergences = []
        self._prototype_trace = None
        self._initial_params = None
        self._z_last = None
        self._potential_energy_last = None
        self._z_grads_last = None
        self._warmup_steps = None

    def _find_reasonable_step_size(self, z):
        step_size = self.step_size

        # We are going to find a step_size which makes accept_prob (Metropolis correction)
        # near the target_accept_prob. If accept_prob := exp(-delta_energy) is small,
        # then we have to decrease step_size; otherwise, increase step_size.
        try:
            potential_energy = self.potential_fn(z)
        # handle exceptions as defined in the exception registry
        except Exception as e:
            if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
                # skip finding reasonable step size
                return step_size
            else:
                raise e
        r, r_unscaled = self._sample_r(name="r_presample_0")
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy
        # This is required so as to avoid issues with autograd when the model
        # contains transforms with cache_size > 0 (https://github.com/pyro-ppl/pyro/issues/2292)
        z = {k: v.clone() for k, v in z.items()}
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size
        )
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        delta_energy = energy_new - energy_current
        # direction=1 means keep increasing step_size, otherwise decrease step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating the log prob
        # of a value simulated using a large step size for a constrained sample site).
        direction = 1 if self._direction_threshold < -delta_energy else -1

        # define scale for step_size: 2 for increasing, 1/2 for decreasing
        step_size_scale = 2**direction
        direction_new = direction
        # keep scaling step_size until accept_prob crosses its target
        t = 0
        while (
            direction_new == direction
            and self._min_stepsize < step_size < self._max_stepsize
        ):
            t += 1
            step_size = step_size_scale * step_size
            r, r_unscaled = self._sample_r(name="r_presample_{}".format(t))
            energy_current = self._kinetic_energy(r_unscaled) + potential_energy
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z,
                r,
                self.potential_fn,
                self.mass_matrix_adapter.kinetic_grad,
                step_size,
            )
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new
            delta_energy = energy_new - energy_current
            direction_new = 1 if self._direction_threshold < -delta_energy else -1
        step_size = max(step_size, self._min_stepsize)
        step_size = min(step_size, self._max_stepsize)
        return step_size

    def _sample_r(self, name):
        r_unscaled = {}
        options = {
            "dtype": self._potential_energy_last.dtype,
            "device": self._potential_energy_last.device,
        }
        for site_names, size in self.mass_matrix_adapter.mass_matrix_size.items():
            # we want to sample from a Normal distribution using the `sample` method
            # rather than the `rsample` method because the former is a bit faster
            r_unscaled[site_names] = pyro.sample(
                "{}_{}".format(name, site_names),
                NonreparameterizedNormal(
                    torch.zeros(size, **options), torch.ones(size, **options)
                ),
            )
        r = self.mass_matrix_adapter.scale(r_unscaled, r_prototype=self.initial_params)
        return r, r_unscaled

    @property
    def mass_matrix_adapter(self):
        return self._adapter.mass_matrix_adapter

    @mass_matrix_adapter.setter
    def mass_matrix_adapter(self, value):
        self._adapter.mass_matrix_adapter = value

    @property
    def inverse_mass_matrix(self):
        return self.mass_matrix_adapter.inverse_mass_matrix

    @property
    def step_size(self):
        return self._adapter.step_size

    @property
    def num_steps(self):
        return max(1, int(self.trajectory_length / self.step_size))

    @property
    def initial_params(self):
        return self._initial_params

    @initial_params.setter
    def initial_params(self, params):
        self._initial_params = params

    def _initialize_model_properties(self, model_args, model_kwargs):
        init_params, potential_fn, transforms, trace = initialize_model(
            self.model,
            model_args,
            model_kwargs,
            transforms=self.transforms,
            max_plate_nesting=self._max_plate_nesting,
            jit_compile=self._jit_compile,
            jit_options=self._jit_options,
            skip_jit_warnings=self._ignore_jit_warnings,
            init_strategy=self._init_strategy,
            initial_params=self._initial_params,
        )
        self.potential_fn = potential_fn
        self.transforms = transforms
        self._initial_params = init_params
        self._prototype_trace = trace

    def _initialize_adapter(self):
        if self._adapter.dense_mass is False:
            dense_sites_list = []
        elif self._adapter.dense_mass is True:
            dense_sites_list = [tuple(sorted(self.initial_params))]
        else:
            msg = "full_mass should be a list of tuples of site names."
            dense_sites_list = self._adapter.dense_mass
            assert isinstance(dense_sites_list, list), msg
            for dense_sites in dense_sites_list:
                assert dense_sites and isinstance(dense_sites, tuple), msg
                for name in dense_sites:
                    assert isinstance(name, str) and name in self.initial_params, msg
        dense_sites_set = set().union(*dense_sites_list)
        diag_sites = tuple(
            sorted(
                [name for name in self.initial_params if name not in dense_sites_set]
            )
        )
        assert len(diag_sites) + sum([len(sites) for sites in dense_sites_list]) == len(
            self.initial_params
        ), "Site names specified in full_mass are duplicated."

        mass_matrix_shape = OrderedDict()
        for dense_sites in dense_sites_list:
            size = sum([self.initial_params[site].numel() for site in dense_sites])
            mass_matrix_shape[dense_sites] = (size, size)

        if diag_sites:
            size = sum([self.initial_params[site].numel() for site in diag_sites])
            mass_matrix_shape[diag_sites] = (size,)

        options = {
            "dtype": self._potential_energy_last.dtype,
            "device": self._potential_energy_last.device,
        }
        self._adapter.configure(
            self._warmup_steps,
            mass_matrix_shape=mass_matrix_shape,
            find_reasonable_step_size_fn=self._find_reasonable_step_size,
            options=options,
        )

        if self._adapter.adapt_step_size:
            self._adapter.reset_step_size_adaptation(self._initial_params)
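    # A hedged sketch of the structured ``full_mass`` format handled above (the
    # site names ``x``, ``y``, ``w`` are hypothetical). Besides the documented
    # bool, ``full_mass`` may be a list of tuples of site names: each tuple gets
    # one dense mass-matrix block, and all remaining sites share a diagonal block:
    #
    #     >>> kernel = HMC(model, full_mass=[("x", "y")])  # dense block over x, y
    #
    # With sites x, y, w of sizes 2, 3, 4, the ``mass_matrix_shape`` computed
    # above would be OrderedDict([(("x", "y"), (5, 5)), (("w",), (4,))]).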
    def setup(self, warmup_steps, *args, **kwargs):
        self._warmup_steps = warmup_steps
        if self.model is not None:
            self._initialize_model_properties(args, kwargs)
        if self.initial_params:
            z = {k: v.detach() for k, v in self.initial_params.items()}
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
        else:
            z_grads, potential_energy = {}, self.potential_fn(self.initial_params)
        self._cache(self.initial_params, potential_energy, z_grads)
        if self.initial_params:
            self._initialize_adapter()
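    # A minimal sketch of driving this kernel with a raw ``potential_fn``
    # instead of a model (the standard-normal potential and the site name
    # ``x`` are illustrative, not part of the API):
    #
    #     >>> def potential_fn(z):  # negative log density of N(0, 1), up to a constant
    #     ...     return 0.5 * z["x"].pow(2).sum()
    #     >>> kernel = HMC(potential_fn=potential_fn, step_size=0.1, num_steps=10)
    #     >>> kernel.initial_params = {"x": torch.zeros(1)}
    #     >>> kernel.setup(warmup_steps=100)
    #
    # After ``setup``, repeated ``kernel.sample(params)`` calls advance the
    # chain; in practice the ``MCMC`` wrapper runs this loop.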
    def cleanup(self):
        self._reset()
    def _cache(self, z, potential_energy, z_grads=None):
        self._z_last = z
        self._potential_energy_last = potential_energy
        self._z_grads_last = z_grads
    def clear_cache(self):
        self._z_last = None
        self._potential_energy_last = None
        self._z_grads_last = None
    def _fetch_from_cache(self):
        return self._z_last, self._potential_energy_last, self._z_grads_last
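    # The (z, potential_energy, z_grads) triple cached above lets consecutive
    # ``sample`` calls reuse the accepted state's potential energy and
    # gradients rather than re-evaluating ``potential_fn`` at the start of
    # every step.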
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.0
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z,
                r,
                self.potential_fn,
                self.mass_matrix_adapter.kinetic_grad,
                self.step_size,
                self.num_steps,
                z_grads=z_grads,
            )
            # apply Metropolis correction.
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_proposal = (
                self._kinetic_energy(r_new_unscaled) + potential_energy_new
            )
        delta_energy = energy_proposal - energy_current
        # handle the NaN case, which may arise for a diverging trajectory
        # when using a large step size.
        delta_energy = (
            scalar_like(delta_energy, float("inf"))
            if torch_isnan(delta_energy)
            else delta_energy
        )
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        accept_prob = (-delta_energy).exp().clamp(max=1.0)
        rand = pyro.sample(
            "rand_t={}".format(self._t),
            dist.Uniform(scalar_like(accept_prob, 0.0), scalar_like(accept_prob, 1.0)),
        )
        accepted = False
        if rand < accept_prob:
            accepted = True
            z = z_new
            z_grads = z_grads_new
            self._cache(z, potential_energy_new, z_grads)

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        return z.copy()
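    # The Metropolis correction above in numbers: delta_energy = 0.1 accepts
    # the proposal with probability exp(-0.1) ~= 0.905; delta_energy <= 0
    # accepts with probability 1; a NaN energy (diverging trajectory) is
    # mapped to delta_energy = inf, i.e. accept_prob = 0.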
    def logging(self):
        return OrderedDict(
            [
                ("step size", "{:.2e}".format(self.step_size)),
                ("acc. prob", "{:.3f}".format(self._mean_accept_prob)),
            ]
        )
    def diagnostics(self):
        return {
            "divergences": self._divergences,
            "acceptance rate": self._accept_cnt / (self._t - self._warmup_steps),
        }
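# Illustrative shape of the dict returned by ``diagnostics`` (the values here
# are hypothetical): {'divergences': [12, 173], 'acceptance rate': 0.83}.
# Divergence entries are step indices counted from the end of warmup.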