Source code for pyro.infer.resampler

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Optional

import torch

import pyro
import pyro.poutine as poutine
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample


class Resampler:
    """Resampler for interactive tuning of generative models.

    This is typically used when performing prior predictive checks as an
    early step of Bayesian workflow.

    This is intended as a computational cache to speed up the interactive
    tuning of the parameters of prior distributions based on samples from a
    downstream simulation. The idea is that the simulation can be expensive,
    but that when one slightly tweaks parameters of the parameter distribution
    then one can reuse most of the previous samples via importance resampling.

    :param callable guide: A pyro model that takes no arguments. The guide
        should be diffuse, covering more space than the subsequent ``model``
        passed to :meth:`sample`. Must be vectorizable via ``pyro.plate``.
    :param callable simulator: An optional larger pyro model with a superset
        of the guide's latent variables. Must be vectorizable via
        ``pyro.plate``.
    :param int num_guide_samples: Number of initial samples to draw from the
        guide. This should be much larger than the ``num_samples`` requested
        in subsequent calls to :meth:`sample`.
    :param int max_plate_nesting: The maximum plate nesting in the model.
        If absent this will be guessed by running the guide (or the
        simulator, when one is given).
    """

    def __init__(
        self,
        guide: Callable,
        simulator: Optional[Callable] = None,
        *,
        num_guide_samples: int,
        max_plate_nesting: Optional[int] = None,
    ):
        super().__init__()
        if max_plate_nesting is None:
            max_plate_nesting = _guess_max_plate_nesting(
                guide if simulator is None else simulator
            )
        # The particle plate is placed one dimension to the left of every
        # plate dimension the model itself may use.
        self._particle_dim = -1 - max_plate_nesting
        # Lazily created cache of Gumbel noise used by stable sampling.
        self._gumbels: Optional[torch.Tensor] = None

        # Draw samples from the initial guide.
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(guide).get_trace()
            self._old_logp = _log_prob_sum(trace, num_guide_samples)

            # Use the same ``is None`` test as above so that a callable that
            # happens to be falsy (e.g. an empty container-like module) is
            # still treated as a simulator.
            if simulator is not None:
                # Draw extended samples from the simulator.
                trace = poutine.trace(poutine.replay(simulator, trace)).get_trace()

        self._samples = {
            name: site["value"]
            for name, site in trace.nodes.items()
            if site["type"] == "sample" and not site_is_subsample(site)
        }

    @torch.no_grad()
    def sample(
        self, model: Callable, num_samples: int, stable: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Draws a set of at most ``num_samples`` many model samples,
        optionally extended by the ``simulator``.

        Internally this importance resamples the samples generated by the
        ``guide`` in ``.__init__()``, and does not rerun the ``guide`` or
        ``simulator``. If the original guide samples poorly cover the model
        distribution, samples will show low diversity.

        :param callable model: A model with the same latent variables as the
            original ``guide``. Must be vectorizable via ``pyro.plate``.
        :param int num_samples: The number of samples to draw.
        :param bool stable: Whether to use piecewise-constant multinomial
            sampling. Set to True for visualization, False for Monte Carlo
            integration. Defaults to True.
        :returns: A dictionary of stacked samples.
        :rtype: Dict[str, torch.Tensor]
        """
        num_guide_samples = len(self._old_logp)
        with pyro.plate("particles", num_guide_samples, dim=self._particle_dim):
            trace = poutine.trace(poutine.condition(model, self._samples)).get_trace()
        new_logp = _log_prob_sum(trace, num_guide_samples)
        # Log importance weights of the cached guide samples under ``model``.
        logits = new_logp - self._old_logp
        i = self._categorical_sample(logits, num_samples, stable)
        samples = {k: v[i] for k, v in self._samples.items()}
        return samples

    def _categorical_sample(
        self, logits: torch.Tensor, num_samples: int, stable: bool
    ) -> torch.Tensor:
        """Draws ``num_samples`` indices with probability proportional to
        ``exp(logits)``.

        :param torch.Tensor logits: A 1-dimensional tensor of unnormalized
            log weights.
        :param int num_samples: The number of indices to draw.
        :param bool stable: If True, reuse cached Gumbel noise so repeated
            calls with similar ``logits`` pick similar indices; if False,
            draw fresh multinomial samples with replacement.
        :returns: A tensor of ``num_samples`` indices into ``logits``.
        :rtype: torch.Tensor
        """
        if not stable:
            # Normalize in log space: exponentiating raw log weights can
            # underflow to all zeros (or overflow to inf) when their
            # magnitude is large, which makes ``torch.multinomial`` fail.
            probs = torch.softmax(logits, dim=-1)
            return torch.multinomial(probs, num_samples, replacement=True)

        # Implement stable categorical sampling via the Gumbel-max trick.
        if self._gumbels is None or len(self._gumbels) < num_samples:
            # gumbel ~ -log(-log(uniform(0,1)))
            tiny = torch.finfo(logits.dtype).tiny
            self._gumbels = logits.new_empty(num_samples, len(logits)).uniform_()
            # Clamp before each log so that neither log(0) nor log of a
            # non-positive value can occur.
            self._gumbels.clamp_(min=tiny).log_().neg_().clamp_(min=tiny).log_().neg_()
        return self._gumbels[:num_samples].add(logits).max(-1).indices
def _log_prob_sum(trace: Trace, batch_size: int) -> torch.Tensor:
    """Computes vectorized log_prob_sum batched over the leftmost dimension.

    :param Trace trace: A trace whose sample sites all have a leftmost
        ``log_prob`` dimension of size ``batch_size``.
    :param int batch_size: The size of the leftmost (particle) dimension.
    :returns: A tensor of shape ``(batch_size,)`` holding, for each batch
        element, the sum of ``log_prob`` over all sample sites. If the trace
        has no sample sites this is a tensor of zeros.
    :rtype: torch.Tensor
    :raises ValueError: If a sample site's ``log_prob`` does not have
        ``batch_size`` as its leftmost dimension.
    """
    trace.compute_log_prob()
    total: Optional[torch.Tensor] = None
    for site in trace.nodes.values():
        if site["type"] != "sample":
            continue
        logp = site["log_prob"]
        # Raise explicitly rather than assert, so the check survives
        # ``python -O`` and reports the offending shapes.
        if logp.shape[:1] != (batch_size,):
            raise ValueError(
                "Expected log_prob with leftmost dimension "
                f"{batch_size}, but got shape {tuple(logp.shape)}"
            )
        # Flatten everything right of the batch dimension and sum it out.
        term = logp.reshape(batch_size, -1).sum(-1)
        total = term if total is None else total + term
    if total is None:
        # No sample sites: still return a tensor so callers can rely on
        # ``len()`` and tensor arithmetic.
        return torch.zeros(batch_size)
    return total


def _guess_max_plate_nesting(model: Callable) -> int:
    """Guesses the maximum plate nesting of ``model`` by running it once.

    The model is executed a single time without gradient tracking, wrapped
    in ``poutine.block()`` and ``poutine.mask(mask=False)``, and the
    vectorized frames in each site's ``cond_indep_stack`` are inspected.

    :param callable model: A pyro model that takes no arguments.
    :returns: The largest ``-dim`` over all vectorized plate frames seen in
        the trace, or 0 if there are none.
    :rtype: int
    """
    with torch.no_grad(), poutine.block(), poutine.mask(mask=False):
        trace = poutine.trace(model).get_trace()
    # Seed the set with 0 so that a model without plates yields 0.
    plate_nesting = {0}.union(
        -f.dim
        for site in trace.nodes.values()
        for f in site.get("cond_indep_stack", [])
        if f.vectorized
    )
    return max(plate_nesting)