Source code for pyro.infer.mcmc.nuts

from collections import namedtuple

import torch

import pyro
import pyro.distributions as dist
from pyro.distributions.util import scalar_like
from pyro.infer.mcmc.hmc import HMC
from pyro.ops.integrator import velocity_verlet
from pyro.util import optional, torch_isnan


def _logaddexp(x, y):
    minval, maxval = (x, y) if x < y else (y, x)
    return (minval - maxval).exp().log1p() + maxval


# sum_accept_probs and num_proposals are used to calculate
# the statistic accept_prob for Dual Averaging scheme;
# z_left_grads and z_right_grads are kept to avoid recalculating
# grads at left and right leaves;
# r_sum is used to check turning condition;
# z_proposal_pe and z_proposal_grads are used to cache the
#   potential energy and potential energy gradient values for
#   the proposal trace.
# weight is the number of valid points in case we use slice sampling
#   and is the log sum of (unnormalized) probabilites of valid points
#   when we use multinomial sampling
_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "z_left_grads",
                                    "z_right", "r_right", "z_right_grads",
                                    "z_proposal", "z_proposal_pe", "z_proposal_grads",
                                    "r_sum", "weight", "turning", "diverging",
                                    "sum_accept_probs", "num_proposals"])


[docs]class NUTS(HMC): """ No-U-Turn Sampler kernel, which provides an efficient and convenient way to run Hamiltonian Monte Carlo. The number of steps taken by the integrator is dynamically adjusted on each call to ``sample`` to ensure an optimal length for the Hamiltonian trajectory [1]. As such, the samples generated will typically have lower autocorrelation than those generated by the :class:`~pyro.infer.mcmc.HMC` kernel. Optionally, the NUTS kernel also provides the ability to adapt step size during the warmup phase. Refer to the `baseball example <https://github.com/uber/pyro/blob/dev/examples/baseball.py>`_ to see how to do Bayesian inference in Pyro using NUTS. **References** [1] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo`, Matthew D. Hoffman, and Andrew Gelman. [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`, Michael Betancourt [3] `Slice Sampling`, Radford M. Neal :param model: Python callable containing Pyro primitives. :param potential_fn: Python callable calculating potential energy with input is a dict of real support parameters. :param float step_size: Determines the size of a single step taken by the verlet integrator while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set to 1. :param bool adapt_step_size: A flag to decide if we want to adapt step_size during warm-up phase using Dual Averaging scheme. :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass matrix during warm-up phase using Welford scheme. :param bool full_mass: A flag to decide if mass matrix is dense or diagonal. :param bool use_multinomial_sampling: A flag to decide if we want to sample candidates along its trajectory using "multinomial sampling" or using "slice sampling". Slice sampling is used in the original NUTS paper [1], while multinomial sampling is suggested in [2]. By default, this flag is set to True. If it is set to `False`, NUTS uses slice sampling. :param dict transforms: Optional dictionary that specifies a transform for a sample site with constrained support to unconstrained space. The transform should be invertible, and implement `log_abs_det_jacobian`. If not specified and the model has sites with constrained support, automatic transformations will be applied, as specified in :mod:`torch.distributions.constraint_registry`. :param int max_plate_nesting: Optional bound on max number of nested :func:`pyro.plate` contexts. This is required if model contains discrete sample sites that can be enumerated over in parallel. :param bool jit_compile: Optional parameter denoting whether to use the PyTorch JIT to trace the log density computation, and use this optimized executable trace in the integrator. :param dict jit_options: A dictionary contains optional arguments for :func:`torch.jit.trace` function. :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT tracer when ``jit_compile=True``. Default is False. :param float target_accept_prob: Target acceptance probability of step size adaptation scheme. Increasing this value will lead to a smaller step size, so the sampling will be slower but more robust. Default to 0.8. :param int max_tree_depth: Max depth of the binary tree created during the doubling scheme of NUTS sampler. Default to 10. Example: >>> true_coefs = torch.tensor([1., 2., 3.]) >>> data = torch.randn(2000, 3) >>> dim = 3 >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample() >>> >>> def model(data): ... coefs_mean = torch.zeros(dim) ... coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3))) ... y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels) ... return y >>> >>> nuts_kernel = NUTS(model, adapt_step_size=True) >>> mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=300) >>> mcmc.run(data) >>> mcmc.get_samples()['beta'].mean(0) # doctest: +SKIP tensor([ 0.9221, 1.9464, 2.9228]) """ def __init__(self, model=None, potential_fn=None, step_size=1, adapt_step_size=True, adapt_mass_matrix=True, full_mass=False, use_multinomial_sampling=True, transforms=None, max_plate_nesting=None, jit_compile=False, jit_options=None, ignore_jit_warnings=False, target_accept_prob=0.8, max_tree_depth=10): super(NUTS, self).__init__(model, potential_fn, step_size, adapt_step_size=adapt_step_size, adapt_mass_matrix=adapt_mass_matrix, full_mass=full_mass, transforms=transforms, max_plate_nesting=max_plate_nesting, jit_compile=jit_compile, jit_options=jit_options, ignore_jit_warnings=ignore_jit_warnings, target_accept_prob=target_accept_prob) self.use_multinomial_sampling = use_multinomial_sampling self._max_tree_depth = max_tree_depth # There are three conditions to stop doubling process: # + Tree is becoming too big. # + The trajectory is making a U-turn. # + The probability of the states becoming negligible: p(z, r) << u, # here u is the "slice" variable introduced at the `self.sample(...)` method. # Denote E_p = -log p(z, r), E_u = -log u, the third condition is equivalent to # sliced_energy := E_p - E_u > some constant =: max_sliced_energy. # This also suggests the notion "diverging" in the implemenation: # when the energy E_p diverges from E_u too much, we stop doubling. # Here, as suggested in [1], we set dE_max = 1000. self._max_sliced_energy = 1000 def _is_turning(self, r_left, r_right, r_sum): # We follow the strategy in Section A.4.2 of [2] for this implementation. r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)]) r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)]) r_sum = r_sum - (r_left_flat + r_right_flat) / 2 if self.inverse_mass_matrix.dim() == 2: if (self.inverse_mass_matrix.matmul(r_left_flat).dot(r_sum) > 0 and self.inverse_mass_matrix.matmul(r_right_flat).dot(r_sum) > 0): return False else: if (self.inverse_mass_matrix.mul(r_left_flat).dot(r_sum) > 0 and self.inverse_mass_matrix.mul(r_right_flat).dot(r_sum) > 0): return False return True def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): step_size = self.step_size if direction == 1 else -self.step_size z_new, r_new, z_grads, potential_energy = velocity_verlet( z, r, self.potential_fn, self.inverse_mass_matrix, step_size, z_grads=z_grads) r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)]) energy_new = potential_energy + self._kinetic_energy(r_new) # handle the NaN case energy_new = scalar_like(energy_new, float("inf")) if torch_isnan(energy_new) else energy_new sliced_energy = energy_new + log_slice diverging = (sliced_energy > self._max_sliced_energy) delta_energy = energy_new - energy_current accept_prob = (-delta_energy).exp().clamp(max=1.0) if self.use_multinomial_sampling: tree_weight = -sliced_energy else: # As a part of the slice sampling process (see below), along the trajectory # we eliminate states which p(z, r) < u, or dE > 0. # Due to this elimination (and stop doubling conditions), # the weight of binary tree might not equal to 2^tree_depth. tree_weight = scalar_like(sliced_energy, 1. if sliced_energy <= 0 else 0.) return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, potential_energy, z_grads, r_new_flat, tree_weight, False, diverging, accept_prob, 1) def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current): if tree_depth == 0: return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current) # build the first half of tree half_tree = self._build_tree(z, r, z_grads, log_slice, direction, tree_depth-1, energy_current) z_proposal = half_tree.z_proposal z_proposal_pe = half_tree.z_proposal_pe z_proposal_grads = half_tree.z_proposal_grads # Check conditions to stop doubling. If we meet that condition, # there is no need to build the other tree. if half_tree.turning or half_tree.diverging: return half_tree # Else, build remaining half of tree. # If we are going to the right, start from the right leaf of the first half. if direction == 1: z = half_tree.z_right r = half_tree.r_right z_grads = half_tree.z_right_grads else: # otherwise, start from the left leaf of the first half z = half_tree.z_left r = half_tree.r_left z_grads = half_tree.z_left_grads other_half_tree = self._build_tree(z, r, z_grads, log_slice, direction, tree_depth-1, energy_current) if self.use_multinomial_sampling: tree_weight = _logaddexp(half_tree.weight, other_half_tree.weight) else: tree_weight = half_tree.weight + other_half_tree.weight sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs num_proposals = half_tree.num_proposals + other_half_tree.num_proposals r_sum = half_tree.r_sum + other_half_tree.r_sum # The probability of that proposal belongs to which half of tree # is computed based on the weights of each half. if self.use_multinomial_sampling: other_half_tree_prob = (other_half_tree.weight - tree_weight).exp() else: # For the special case that the weights of each half are both 0, # we choose the proposal from the first half # (any is fine, because the probability of picking it at the end is 0!). other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0 else scalar_like(tree_weight, 0.)) is_other_half_tree = pyro.sample("is_other_half_tree", dist.Bernoulli(probs=other_half_tree_prob)) if is_other_half_tree == 1: z_proposal = other_half_tree.z_proposal z_proposal_pe = other_half_tree.z_proposal_pe z_proposal_grads = other_half_tree.z_proposal_grads # leaves of the full tree are determined by the direction if direction == 1: z_left = half_tree.z_left r_left = half_tree.r_left z_left_grads = half_tree.z_left_grads z_right = other_half_tree.z_right r_right = other_half_tree.r_right z_right_grads = other_half_tree.z_right_grads else: z_left = other_half_tree.z_left r_left = other_half_tree.r_left z_left_grads = other_half_tree.z_left_grads z_right = half_tree.z_right r_right = half_tree.r_right z_right_grads = half_tree.z_right_grads # We already check if first half tree is turning. Now, we check # if the other half tree or full tree are turning. turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum) # The divergence is checked by the second half tree (the first half is already checked). diverging = other_half_tree.diverging return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal, z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging, sum_accept_probs, num_proposals)
[docs] def sample(self, params): z, potential_energy, z_grads = self._fetch_from_cache() # recompute PE when cache is cleared if z is None: z = params potential_energy = self.potential_fn(z) self._cache(z, potential_energy) # return early if no sample sites elif len(z) == 0: self._t += 1 self._mean_accept_prob = 1. if self._t > self._warmup_steps: self._accept_cnt += 1 return z r, r_flat = self._sample_r(name="r_t={}".format(self._t)) energy_current = self._kinetic_energy(r) + potential_energy # Ideally, following a symplectic integrator trajectory, the energy is constant. # In that case, we can sample the proposal uniformly, and there is no need to use "slice". # However, it is not the case for real situation: there are errors during the computation. # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted # by u). # The sampling process goes as follows: # first sampling u from initial state (z_0, r_0) according to # u ~ Uniform(0, p(z_0, r_0)), # then sampling state (z, r) from the integrator trajectory according to # (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}). # # For more information about slice sampling method, see [3]. # For another version of NUTS which uses multinomial sampling instead of slice sampling, # see [2]. if self.use_multinomial_sampling: log_slice = -energy_current else: # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can # sample log_slice directly using `energy`, so as to avoid potential underflow or # overflow issues ([2]). slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t), dist.Exponential(scalar_like(energy_current, 1.))) log_slice = -energy_current - slice_exp_term z_left = z_right = z r_left = r_right = r z_left_grads = z_right_grads = z_grads accepted = False r_sum = r_flat sum_accept_probs = 0. num_proposals = 0 tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.) # Temporarily disable distributions args checking as # NaNs are expected during step size adaptation. with optional(pyro.validation_enabled(False), self._t < self._warmup_steps): # doubling process, stop when turning or diverging tree_depth = 0 while tree_depth < self._max_tree_depth: direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth), dist.Bernoulli(probs=scalar_like(tree_weight, 0.5))) direction = int(direction.item()) if direction == 1: # go to the right, start from the right leaf of current tree new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice, direction, tree_depth, energy_current) # update leaf for the next doubling process z_right = new_tree.z_right r_right = new_tree.r_right z_right_grads = new_tree.z_right_grads else: # go the the left, start from the left leaf of current tree new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice, direction, tree_depth, energy_current) z_left = new_tree.z_left r_left = new_tree.r_left z_left_grads = new_tree.z_left_grads sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs num_proposals = num_proposals + new_tree.num_proposals # stop doubling if new_tree.diverging: if self._t >= self._warmup_steps: self._divergences.append(self._t - self._warmup_steps) break if new_tree.turning: break tree_depth += 1 if self.use_multinomial_sampling: new_tree_prob = (new_tree.weight - tree_weight).exp() else: new_tree_prob = new_tree.weight / tree_weight rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth), dist.Uniform(scalar_like(new_tree_prob, 0.), scalar_like(new_tree_prob, 1.))) if rand < new_tree_prob: accepted = True z = new_tree.z_proposal self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads) r_sum = r_sum + new_tree.r_sum if self._is_turning(r_left, r_right, r_sum): # stop doubling break else: # update tree_weight if self.use_multinomial_sampling: tree_weight = _logaddexp(tree_weight, new_tree.weight) else: tree_weight = tree_weight + new_tree.weight accept_prob = sum_accept_probs / num_proposals self._t += 1 if self._t > self._warmup_steps: n = self._t - self._warmup_steps if accepted: self._accept_cnt += 1 else: n = self._t self._adapter.step(self._t, z, accept_prob) self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n return z.copy()