Source code for pyro.infer.mcmc.nuts

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple

import pyro
import pyro.distributions as dist
from pyro.distributions.util import scalar_like
from pyro.infer.autoguide import init_to_uniform
from pyro.infer.mcmc.hmc import HMC
from pyro.ops.integrator import potential_grad, velocity_verlet
from pyro.util import optional, torch_isnan


def _logaddexp(x, y):
    minval, maxval = (x, y) if x < y else (y, x)
    return (minval - maxval).exp().log1p() + maxval
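
# Note: for scalar tensors, _logaddexp(x, y) above matches torch.logaddexp(x, y);
# subtracting the max before exponentiating avoids underflow, e.g. (a quick
# illustrative check, assuming torch is available):
#
#     x, y = torch.tensor(-1000.0), torch.tensor(-1001.0)
#     (x.exp() + y.exp()).log()    # -> -inf (both exp() calls underflow to 0)
#     _logaddexp(x, y)             # -> tensor(-999.6867)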


# sum_accept_probs and num_proposals are used to calculate
#   the accept_prob statistic for the Dual Averaging scheme;
# z_left_grads and z_right_grads are kept to avoid recomputing
#   the gradients at the left and right leaves;
# r_sum is used to check the turning condition;
# z_proposal_pe and z_proposal_grads are used to cache the
#   potential energy and potential energy gradient values for
#   the proposal trace;
# weight is the number of valid points when we use slice sampling,
#   and the log sum of the (unnormalized) probabilities of valid
#   points when we use multinomial sampling.
_TreeInfo = namedtuple(
    "TreeInfo",
    [
        "z_left",
        "r_left",
        "r_left_unscaled",
        "z_left_grads",
        "z_right",
        "r_right",
        "r_right_unscaled",
        "z_right_grads",
        "z_proposal",
        "z_proposal_pe",
        "z_proposal_grads",
        "r_sum",
        "weight",
        "turning",
        "diverging",
        "sum_accept_probs",
        "num_proposals",
    ],
)
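
# A minimal sketch (placeholder names, mirroring _build_basetree below) of the
# _TreeInfo produced for a single leaf: the left and right leaves coincide,
# the proposal is the leaf itself, and exactly one proposal has been made:
#
#     leaf = _TreeInfo(z, r, r_unscaled, z_grads,    # left leaf
#                      z, r, r_unscaled, z_grads,    # right leaf (same point)
#                      z, pe, z_grads,               # proposal + cached PE/grads
#                      r_unscaled,                   # r_sum for the turning check
#                      weight, False, False,         # weight; not turning/diverging
#                      accept_prob, 1)               # one proposal so far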


class NUTS(HMC):
    """
    No-U-Turn Sampler kernel, which provides an efficient and convenient way
    to run Hamiltonian Monte Carlo. The number of steps taken by the
    integrator is dynamically adjusted on each call to ``sample`` to ensure
    an optimal length for the Hamiltonian trajectory [1]. As such, the samples
    generated will typically have lower autocorrelation than those generated
    by the :class:`~pyro.infer.mcmc.HMC` kernel. Optionally, the NUTS kernel
    also provides the ability to adapt step size during the warmup phase.

    Refer to the `baseball example
    <https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py>`_
    to see how to do Bayesian inference in Pyro using NUTS.

    **References**

    [1] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian
    Monte Carlo`, Matthew D. Hoffman, and Andrew Gelman.

    [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`,
    Michael Betancourt

    [3] `Slice Sampling`,
    Radford M. Neal

    :param model: Python callable containing Pyro primitives.
    :param potential_fn: Python callable that computes the potential energy,
        taking a dict of real-support parameters as input.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide whether to adapt step_size
        during the warm-up phase using the Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide whether to adapt the mass
        matrix during the warm-up phase using the Welford scheme.
    :param bool full_mass: A flag to decide whether the mass matrix is dense
        or diagonal.
    :param bool use_multinomial_sampling: A flag to decide whether to sample
        candidates along the trajectory using "multinomial sampling" or
        "slice sampling". Slice sampling is used in the original NUTS paper
        [1], while multinomial sampling is suggested in [2]. By default, this
        flag is set to True. If it is set to `False`, NUTS uses slice
        sampling.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space.
        The transform should be invertible and implement
        `log_abs_det_jacobian`. If not specified and the model has sites with
        constrained support, automatic transformations will be applied, as
        specified in :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on the maximum number of
        nested :func:`pyro.plate` contexts. This is required if the model
        contains discrete sample sites that can be enumerated over in
        parallel.
    :param bool jit_compile: Optional parameter denoting whether to use the
        PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param dict jit_options: A dictionary containing optional arguments for
        the :func:`torch.jit.trace` function.
    :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.
    :param float target_accept_prob: Target acceptance probability of the
        step size adaptation scheme. Increasing this value will lead to a
        smaller step size, so the sampling will be slower but more robust.
        Defaults to 0.8.
    :param int max_tree_depth: Max depth of the binary tree created during
        the doubling scheme of the NUTS sampler. Defaults to 10.
    :param callable init_strategy: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.

    Example:

        >>> true_coefs = torch.tensor([1., 2., 3.])
        >>> data = torch.randn(2000, 3)
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
        >>>
        >>> def model(data):
        ...     coefs_mean = torch.zeros(dim)
        ...     coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
        ...     y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        ...     return y
        >>>
        >>> nuts_kernel = NUTS(model, adapt_step_size=True)
        >>> mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=300)
        >>> mcmc.run(data)
        >>> mcmc.get_samples()['beta'].mean(0)  # doctest: +SKIP
        tensor([ 0.9221,  1.9464,  2.9228])
    """

    def __init__(
        self,
        model=None,
        potential_fn=None,
        step_size=1,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        transforms=None,
        max_plate_nesting=None,
        jit_compile=False,
        jit_options=None,
        ignore_jit_warnings=False,
        target_accept_prob=0.8,
        max_tree_depth=10,
        init_strategy=init_to_uniform,
    ):
        super().__init__(
            model,
            potential_fn,
            step_size,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            full_mass=full_mass,
            transforms=transforms,
            max_plate_nesting=max_plate_nesting,
            jit_compile=jit_compile,
            jit_options=jit_options,
            ignore_jit_warnings=ignore_jit_warnings,
            target_accept_prob=target_accept_prob,
            init_strategy=init_strategy,
        )
        self.use_multinomial_sampling = use_multinomial_sampling
        self._max_tree_depth = max_tree_depth
        # There are three conditions under which the doubling process stops:
        # + The tree is becoming too big.
        # + The trajectory is making a U-turn.
        # + The probability of the states is becoming negligible: p(z, r) << u,
        #   where u is the "slice" variable introduced in the `self.sample(...)` method.
        # Denoting E_p = -log p(z, r) and E_u = -log u, the third condition is
        # equivalent to
        #     sliced_energy := E_p - E_u > some constant =: max_sliced_energy.
        # This also explains the notion of "diverging" in the implementation:
        # when the energy E_p diverges from E_u too much, we stop doubling.
        # Here, as suggested in [1], we set max_sliced_energy = 1000.
        self._max_sliced_energy = 1000

    def _is_turning(self, r_left_unscaled, r_right_unscaled, r_sum):
        # We follow the strategy in Section A.4.2 of [2] for this implementation.
        left_angle = 0.0
        right_angle = 0.0
        for site_names, value in r_sum.items():
            rho = (
                value
                - (r_left_unscaled[site_names] + r_right_unscaled[site_names]) / 2
            )
            left_angle += r_left_unscaled[site_names].dot(rho)
            right_angle += r_right_unscaled[site_names].dot(rho)
        return (left_angle <= 0) or (right_angle <= 0)
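
    # A minimal sketch (made-up 1-D momenta, identity mass matrix) of the
    # criterion in _is_turning above. For a two-leaf trajectory,
    # r_sum = r_left + r_right, so rho = r_sum - (r_left + r_right) / 2:
    #
    #     # endpoints moving the same way: rho = 1, both angles are 1 > 0
    #     self._is_turning({"x": tensor([1.])}, {"x": tensor([1.])},
    #                      {"x": tensor([2.])})   # -> False
    #     # endpoints moving against each other: rho = 0, both angles are 0
    #     self._is_turning({"x": tensor([1.])}, {"x": tensor([-1.])},
    #                      {"x": tensor([0.])})   # -> True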

    def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = velocity_verlet(
            z,
            r,
            self.potential_fn,
            self.mass_matrix_adapter.kinetic_grad,
            step_size,
            z_grads=z_grads,
        )
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_new = potential_energy + self._kinetic_energy(r_new_unscaled)
        # handle the NaN case
        energy_new = (
            scalar_like(energy_new, float("inf"))
            if torch_isnan(energy_new)
            else energy_new
        )
        sliced_energy = energy_new + log_slice
        diverging = sliced_energy > self._max_sliced_energy
        delta_energy = energy_new - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if self.use_multinomial_sampling:
            tree_weight = -sliced_energy
        else:
            # As part of the slice sampling process (see below), along the
            # trajectory we eliminate states for which p(z, r) < u, i.e. dE > 0.
            # Due to this elimination (and the stop-doubling conditions), the
            # weight of the binary tree might not equal 2^tree_depth.
            tree_weight = scalar_like(
                sliced_energy, 1.0 if sliced_energy <= 0 else 0.0
            )

        r_sum = r_new_unscaled
        return _TreeInfo(
            z_new,
            r_new,
            r_new_unscaled,
            z_grads,
            z_new,
            r_new,
            r_new_unscaled,
            z_grads,
            z_new,
            potential_energy,
            z_grads,
            r_sum,
            tree_weight,
            False,
            diverging,
            accept_prob,
            1,
        )
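
    # A worked example (made-up energies) of the divergence check in
    # _build_basetree above: with energy_new = 1105.0 and log_slice = -100.0
    # (i.e. E_u = 100), sliced_energy = E_p - E_u = 1005.0 exceeds
    # self._max_sliced_energy = 1000, so the new leaf is flagged as diverging
    # and the doubling process stops.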

    def _build_tree(
        self, z, r, z_grads, log_slice, direction, tree_depth, energy_current
    ):
        if tree_depth == 0:
            return self._build_basetree(
                z, r, z_grads, log_slice, direction, energy_current
            )

        # build the first half of the tree
        half_tree = self._build_tree(
            z, r, z_grads, log_slice, direction, tree_depth - 1, energy_current
        )
        z_proposal = half_tree.z_proposal
        z_proposal_pe = half_tree.z_proposal_pe
        z_proposal_grads = half_tree.z_proposal_grads

        # Check conditions to stop doubling. If one of them is met,
        # there is no need to build the other half of the tree.
        if half_tree.turning or half_tree.diverging:
            return half_tree

        # Else, build the remaining half of the tree.
        # If we are going to the right, start from the right leaf of the first half.
        if direction == 1:
            z = half_tree.z_right
            r = half_tree.r_right
            z_grads = half_tree.z_right_grads
        else:  # otherwise, start from the left leaf of the first half
            z = half_tree.z_left
            r = half_tree.r_left
            z_grads = half_tree.z_left_grads
        other_half_tree = self._build_tree(
            z, r, z_grads, log_slice, direction, tree_depth - 1, energy_current
        )

        if self.use_multinomial_sampling:
            tree_weight = _logaddexp(half_tree.weight, other_half_tree.weight)
        else:
            tree_weight = half_tree.weight + other_half_tree.weight
        sum_accept_probs = (
            half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
        )
        num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
        r_sum = {
            site_names: half_tree.r_sum[site_names]
            + other_half_tree.r_sum[site_names]
            for site_names in self.inverse_mass_matrix
        }

        # The probability that the proposal belongs to a given half of the tree
        # is computed from the weights of the two halves.
        if self.use_multinomial_sampling:
            other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
        else:
            # For the special case where the weights of both halves are 0,
            # we choose the proposal from the first half
            # (either is fine, because the probability of picking it at the end is 0!).
            other_half_tree_prob = (
                other_half_tree.weight / tree_weight
                if tree_weight > 0
                else scalar_like(tree_weight, 0.0)
            )
        is_other_half_tree = pyro.sample(
            "is_other_half_tree", dist.Bernoulli(probs=other_half_tree_prob)
        )

        if is_other_half_tree == 1:
            z_proposal = other_half_tree.z_proposal
            z_proposal_pe = other_half_tree.z_proposal_pe
            z_proposal_grads = other_half_tree.z_proposal_grads

        # leaves of the full tree are determined by the direction
        if direction == 1:
            z_left = half_tree.z_left
            r_left = half_tree.r_left
            r_left_unscaled = half_tree.r_left_unscaled
            z_left_grads = half_tree.z_left_grads
            z_right = other_half_tree.z_right
            r_right = other_half_tree.r_right
            r_right_unscaled = other_half_tree.r_right_unscaled
            z_right_grads = other_half_tree.z_right_grads
        else:
            z_left = other_half_tree.z_left
            r_left = other_half_tree.r_left
            r_left_unscaled = other_half_tree.r_left_unscaled
            z_left_grads = other_half_tree.z_left_grads
            z_right = half_tree.z_right
            r_right = half_tree.r_right
            r_right_unscaled = half_tree.r_right_unscaled
            z_right_grads = half_tree.z_right_grads

        # We have already checked whether the first half of the tree is turning.
        # Now, we check whether the other half or the full tree is turning.
        turning = other_half_tree.turning or self._is_turning(
            r_left_unscaled, r_right_unscaled, r_sum
        )
        # Divergence is checked on the second half of the tree
        # (the first half has already been checked).
        diverging = other_half_tree.diverging

        return _TreeInfo(
            z_left,
            r_left,
            r_left_unscaled,
            z_left_grads,
            z_right,
            r_right,
            r_right_unscaled,
            z_right_grads,
            z_proposal,
            z_proposal_pe,
            z_proposal_grads,
            r_sum,
            tree_weight,
            turning,
            diverging,
            sum_accept_probs,
            num_proposals,
        )
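
    # A worked example (made-up weights) of the proposal swap in _build_tree
    # above, under multinomial sampling: if half_tree.weight = log(3) and
    # other_half_tree.weight = log(1), then tree_weight = log(4) and the other
    # half's proposal is chosen with probability exp(log(1) - log(4)) = 0.25,
    # i.e. in proportion to its total (unnormalized) probability mass.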

    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.0
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Ideally, following a symplectic integrator trajectory, the energy is
        # constant. In that case, we could sample the proposal uniformly, and
        # there would be no need for a "slice" variable. However, this is not
        # the case in practice: numerical errors accumulate during the
        # computation. To deal with this problem, as in [1], we introduce an
        # auxiliary "slice" variable (denoted by u).
        # The sampling process proceeds as follows:
        # first, sample u from the initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        # then sample a state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about the slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead
        # of slice sampling, see [2].
        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`,
            # we sample log_slice directly using `energy`, so as to avoid potential
            # underflow or overflow issues ([2]).
            slice_exp_term = pyro.sample(
                "slicevar_exp_t={}".format(self._t),
                dist.Exponential(scalar_like(energy_current, 1.0)),
            )
            log_slice = -energy_current - slice_exp_term
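
        # A short derivation (added here for clarity) of the reparameterization
        # above: if u ~ Uniform(0, exp(-E_0)) then u = exp(-E_0) * v with
        # v ~ Uniform(0, 1), and -log(v) ~ Exponential(1); hence
        #     log_slice = log(u) = -E_0 - Exponential(1),
        # which is exactly how `slicevar_exp_t` enters, without ever computing
        # exp(-E_0).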

        z_left = z_right = z
        r_left = r_right = r
        r_left_unscaled = r_right_unscaled = r_unscaled
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_unscaled
        sum_accept_probs = 0.0
        num_proposals = 0
        tree_weight = scalar_like(
            energy_current, 0.0 if self.use_multinomial_sampling else 1.0
        )

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                direction = pyro.sample(
                    "direction_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)),
                )
                direction = int(direction.item())
                if direction == 1:
                    # go to the right, start from the right leaf of the current tree
                    new_tree = self._build_tree(
                        z_right,
                        r_right,
                        z_right_grads,
                        log_slice,
                        direction,
                        tree_depth,
                        energy_current,
                    )
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    r_right_unscaled = new_tree.r_right_unscaled
                    z_right_grads = new_tree.z_right_grads
                else:
                    # go to the left, start from the left leaf of the current tree
                    new_tree = self._build_tree(
                        z_left,
                        r_left,
                        z_left_grads,
                        log_slice,
                        direction,
                        tree_depth,
                        energy_current,
                    )
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    r_left_unscaled = new_tree.r_left_unscaled
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                if new_tree.diverging:
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                tree_depth += 1

                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample(
                    "rand_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Uniform(
                        scalar_like(new_tree_prob, 0.0),
                        scalar_like(new_tree_prob, 1.0),
                    ),
                )
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    z_grads = new_tree.z_proposal_grads
                    self._cache(z, new_tree.z_proposal_pe, z_grads)

                r_sum = {
                    site_names: r_sum[site_names] + new_tree.r_sum[site_names]
                    for site_names in r_unscaled
                }
                if self._is_turning(r_left_unscaled, r_right_unscaled, r_sum):
                    # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        return z.copy()
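
# A worked example (made-up numbers) of the statistic computed at the end of
# sample() above: if one pass through the doubling loop accumulates
# sum_accept_probs = 6.4 over num_proposals = 8 leaves, the step-size adapter
# receives accept_prob = 0.8; since this matches the default
# target_accept_prob, Dual Averaging would leave the step size roughly
# unchanged at that iteration.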