# Source code for pyro.infer.mcmc.nuts

from __future__ import absolute_import, division, print_function

from collections import namedtuple

import torch

import pyro
import pyro.distributions as dist
from pyro.distributions.util import logsumexp
from pyro.infer.mcmc.hmc import HMC
from pyro.ops.integrator import velocity_verlet
from pyro.util import optional, torch_isnan

# Record describing a (sub)tree built during NUTS trajectory doubling.
#
# Field groups:
#   z_left / r_left / z_left_grads, z_right / r_right / z_right_grads:
#       state, momentum and potential-energy gradients at the two leaves;
#       the gradients are kept so they need not be recomputed when the
#       tree is extended from either end.
#   z_proposal / z_proposal_pe / z_proposal_grads:
#       the candidate state chosen from this tree, together with its cached
#       potential energy and potential-energy gradients.
#   r_sum:
#       sum of the (flattened) momenta in the tree, used by the U-turn check.
#   weight:
#       under slice sampling, the number of valid points in the tree; under
#       multinomial sampling, the log-sum of the (unnormalized) probabilities
#       of the valid points.
#   turning / diverging:
#       stop-doubling flags.
#   sum_accept_probs / num_proposals:
#       accumulated statistics from which the Dual Averaging step-size
#       adaptation computes its acceptance probability.
_TreeInfo = namedtuple(
    "TreeInfo",
    "z_left r_left z_left_grads "
    "z_right r_right z_right_grads "
    "z_proposal z_proposal_pe z_proposal_grads "
    "r_sum weight turning diverging "
    "sum_accept_probs num_proposals",
)

class NUTS(HMC):
    """
    No-U-Turn Sampler kernel, which provides an efficient and convenient way
    to run Hamiltonian Monte Carlo. The number of steps taken by the
    integrator is dynamically adjusted on each call to ``sample`` to ensure
    an optimal length for the Hamiltonian trajectory [1]. As such, the samples
    generated will typically have lower autocorrelation than those generated
    by the :class:`~pyro.infer.mcmc.HMC` kernel. Optionally, the NUTS kernel
    also provides the ability to adapt step size during the warmup phase.

    Refer to the `baseball example <http://pyro.ai/examples/baseball.html>`_
    to see how to do Bayesian inference in Pyro using NUTS.

    **References**

    [1] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo`,
        Matthew D. Hoffman, and Andrew Gelman.
    [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`,
        Michael Betancourt
    [3] `Slice Sampling`,
        Radford M. Neal

    :param model: Python callable containing Pyro primitives.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool full_mass: A flag to decide if mass matrix is dense or diagonal.
    :param bool use_multinomial_sampling: A flag to decide if we want to sample
        candidates along its trajectory using "multinomial sampling" or using
        "slice sampling". Slice sampling is used in the original NUTS paper [1],
        while multinomial sampling is suggested in [2]. By default, this flag is
        set to True. If it is set to `False`, NUTS uses slice sampling.
    :param dict transforms: Optional dictionary that specifies a transform for
        a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. This is required if model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param bool ignore_jit_warnings: Forwarded unchanged to
        :class:`~pyro.infer.mcmc.HMC`; presumably suppresses warnings from the
        JIT tracer when ``jit_compile=True`` (see HMC for details).

    Example:

        >>> true_coefs = torch.tensor([1., 2., 3.])
        >>> data = torch.randn(2000, 3)
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
        >>>
        >>> def model(data):
        ...     coefs_mean = torch.zeros(dim)
        ...     coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
        ...     y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        ...     return y
        >>>
        >>> nuts_kernel = NUTS(model, adapt_step_size=True)
        >>> mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=300).run(data)
        >>> posterior = mcmc_run.marginal('beta').empirical['beta']
        >>> posterior.mean  # doctest: +SKIP
        tensor([ 0.9221,  1.9464,  2.9228])
    """

    def __init__(self, model, step_size=1, adapt_step_size=True, adapt_mass_matrix=True,
                 full_mass=False, use_multinomial_sampling=True, transforms=None,
                 max_plate_nesting=None, jit_compile=False, ignore_jit_warnings=False):
        super(NUTS, self).__init__(model, step_size, adapt_step_size=adapt_step_size,
                                   adapt_mass_matrix=adapt_mass_matrix, full_mass=full_mass,
                                   transforms=transforms, max_plate_nesting=max_plate_nesting,
                                   jit_compile=jit_compile, ignore_jit_warnings=ignore_jit_warnings)
        self.use_multinomial_sampling = use_multinomial_sampling
        # Maximum number of doublings per transition (from Stan): the final
        # trajectory contains at most 2 ** self._max_tree_depth states.
        self._max_tree_depth = 10  # from Stan
        # There are three conditions to stop doubling process:
        #     + Tree is becoming too big.
        #     + The trajectory is making a U-turn.
        #     + The probability of the states becoming negligible: p(z, r) << u,
        # here u is the "slice" variable introduced at the `self.sample(...)` method.
        # Denote E_p = -log p(z, r), E_u = -log u, the third condition is equivalent to
        #     sliced_energy := E_p - E_u > some constant =: max_sliced_energy.
        # This also suggests the notion "diverging" in the implementation:
        #     when the energy E_p diverges from E_u too much, we stop doubling.
        # Here, as suggested in [1], we set dE_max = 1000.
        self._max_sliced_energy = 1000

    def _is_turning(self, r_left, r_right, r_sum):
        """
        Return True if the trajectory spanned by the leaves with momenta
        ``r_left``/``r_right`` (dicts keyed by site name) and total flattened
        momentum ``r_sum`` satisfies the U-turn stopping criterion.
        """
        # We follow the strategy in Section A.4.2 of [2] for this implementation.
        # Flatten per-site momenta in a fixed (sorted) site order so they line
        # up with `r_sum` and with the rows/columns of the inverse mass matrix.
        r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)])
        r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)])
        # TODO: change `(a * b).sum()` to `torch.dot(a, b)` for pytorch 1.0
        if self.full_mass:
            if (((r_sum - r_left_flat) * (self.inverse_mass_matrix.matmul(r_left_flat))).sum() > 0 and
                    ((r_sum - r_right_flat) * (self.inverse_mass_matrix.matmul(r_right_flat))).sum() > 0):
                return False
        else:
            # Diagonal mass matrix: elementwise product instead of matmul.
            if ((self.inverse_mass_matrix * (r_sum - r_left_flat) * r_left_flat).sum() > 0 and
                    (self.inverse_mass_matrix * (r_sum - r_right_flat) * r_right_flat).sum() > 0):
                return False
        return True

    def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
        """
        Take a single velocity-verlet step from ``(z, r)`` in ``direction``
        (1 = forward, otherwise backward) and wrap the new state as a depth-0
        :class:`_TreeInfo` whose left leaf, right leaf and proposal are all
        that new state.
        """
        step_size = self.step_size if direction == 1 else -self.step_size
        z_new, r_new, z_grads, potential_energy = velocity_verlet(
            z, r, self._potential_energy, self.inverse_mass_matrix, step_size, z_grads=z_grads)
        r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)])
        energy_new = potential_energy + self._kinetic_energy(r_new)
        # handle the NaN case: treat it as infinite energy so the state is
        # rejected and flagged as diverging below.
        energy_new = energy_new.new_tensor(float("inf")) if torch_isnan(energy_new) else energy_new
        sliced_energy = energy_new + log_slice
        diverging = (sliced_energy > self._max_sliced_energy)
        delta_energy = energy_new - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if self.use_multinomial_sampling:
            tree_weight = -sliced_energy
        else:
            # As a part of the slice sampling process (see below), along the trajectory
            #     we eliminate states which p(z, r) < u, or dE > 0.
            # Due to this elimination (and stop doubling conditions),
            #     the weight of binary tree might not equal to 2^tree_depth.
            tree_weight = (sliced_energy.new_ones(()) if sliced_energy <= 0
                           else sliced_energy.new_zeros(()))

        return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, potential_energy,
                         z_grads, r_new_flat, tree_weight, False, diverging, accept_prob, 1)

    def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
        """
        Recursively build a tree of depth ``tree_depth`` by extending the
        trajectory from ``(z, r)`` in ``direction``. If the first half turns or
        diverges, that half is returned unchanged; otherwise both halves are
        merged and the proposal is drawn from one of them with probability
        proportional to their weights.
        """
        if tree_depth == 0:
            return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

        # build the first half of tree
        half_tree = self._build_tree(z, r, z_grads, log_slice, direction,
                                     tree_depth-1, energy_current)
        z_proposal = half_tree.z_proposal
        z_proposal_pe = half_tree.z_proposal_pe
        z_proposal_grads = half_tree.z_proposal_grads

        # Check conditions to stop doubling. If we meet that condition,
        #     there is no need to build the other tree.
        if half_tree.turning or half_tree.diverging:
            return half_tree

        # Else, build remaining half of tree.
        # If we are going to the right, start from the right leaf of the first half.
        if direction == 1:
            z = half_tree.z_right
            r = half_tree.r_right
            z_grads = half_tree.z_right_grads
        else:  # otherwise, start from the left leaf of the first half
            z = half_tree.z_left
            r = half_tree.r_left
            z_grads = half_tree.z_left_grads
        other_half_tree = self._build_tree(z, r, z_grads, log_slice, direction,
                                           tree_depth-1, energy_current)

        # Weights are log-weights under multinomial sampling and counts under
        # slice sampling, hence logsumexp vs. plain addition.
        if self.use_multinomial_sampling:
            tree_weight = logsumexp(torch.stack([half_tree.weight, other_half_tree.weight]), dim=0)
        else:
            tree_weight = half_tree.weight + other_half_tree.weight
        sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
        num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
        r_sum = half_tree.r_sum + other_half_tree.r_sum

        # The probability of that proposal belongs to which half of tree
        #     is computed based on the weights of each half.
        if self.use_multinomial_sampling:
            other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
        else:
            # For the special case that the weights of each half are both 0,
            #     we choose the proposal from the first half
            #     (any is fine, because the probability of picking it at the end is 0!).
            other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0
                                    else tree_weight.new_zeros(()))
        is_other_half_tree = pyro.sample("is_other_half_tree",
                                         dist.Bernoulli(probs=other_half_tree_prob))

        if is_other_half_tree == 1:
            z_proposal = other_half_tree.z_proposal
            z_proposal_pe = other_half_tree.z_proposal_pe
            z_proposal_grads = other_half_tree.z_proposal_grads

        # leaves of the full tree are determined by the direction
        if direction == 1:
            z_left = half_tree.z_left
            r_left = half_tree.r_left
            z_left_grads = half_tree.z_left_grads
            z_right = other_half_tree.z_right
            r_right = other_half_tree.r_right
            z_right_grads = other_half_tree.z_right_grads
        else:
            z_left = other_half_tree.z_left
            r_left = other_half_tree.r_left
            z_left_grads = other_half_tree.z_left_grads
            z_right = half_tree.z_right
            r_right = half_tree.r_right
            z_right_grads = half_tree.z_right_grads

        # We already check if first half tree is turning. Now, we check
        #     if the other half tree or full tree are turning.
        turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)

        # The divergence is checked by the second half tree (the first half is already checked).
        diverging = other_half_tree.diverging

        return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
                         z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging,
                         sum_accept_probs, num_proposals)

    def sample(self, trace):
        """
        Perform one NUTS transition starting from the latent values in
        ``trace`` and return the trace built by ``self._get_trace`` for the
        selected state (mapped back to constrained space).

        Side effects: advances ``self._t``, increments ``self._accept_cnt``
        when a new proposal is accepted, updates the potential-energy cache,
        and steps the adapter during warmup.
        """
        z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)}
        potential_energy, z_grads = self._fetch_from_cache()
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])

        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        energy_current = ((self._kinetic_energy(r) + potential_energy) if potential_energy is not None
                          else self._energy(z, r))

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, it is not the case for real situation: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #     first sampling u from initial state (z_0, r_0) according to
        #         u ~ Uniform(0, p(z_0, r_0)),
        #     then sampling state (z, r) from the integrator trajectory according to
        #         (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                         dist.Exponential(energy_current.new_tensor(1.)))
            log_slice = -energy_current - slice_exp_term

        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_flat
        # The initial state alone is a tree of weight 1 (log-weight 0).
        if self.use_multinomial_sampling:
            tree_weight = energy_current.new_zeros(())
        else:
            tree_weight = energy_current.new_ones(())

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging.
            # Subtrees of depth 0 .. max_tree_depth - 1 are appended, so the
            # final trajectory holds at most 2 ** max_tree_depth states
            # (Stan's `max_treedepth` semantics). Previously this looped to
            # max_tree_depth inclusive, allowing one extra doubling.
            for tree_depth in range(self._max_tree_depth):
                direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                        dist.Bernoulli(probs=torch.ones(1) * 0.5))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                if new_tree.turning or new_tree.diverging:  # stop doubling
                    break

                # Biased progressive sampling: move to the new subtree's
                # proposal with probability min(1, w_new / w_old).
                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                                   dist.Uniform(new_tree_prob.new_tensor(0.),
                                                new_tree_prob.new_tensor(1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(new_tree.z_proposal_pe, new_tree.z_proposal_grads)

                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = logsumexp(torch.stack([tree_weight, new_tree.weight]), dim=0)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        if self._t < self._warmup_steps:
            accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
            self._adapter.step(self._t, z, accept_prob)

        if accepted:
            self._accept_cnt += 1

        self._t += 1
        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)