# Source code for pyro.infer.predictive

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from dataclasses import dataclass, fields
from functools import reduce
from typing import Callable, List, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites


def _guess_max_plate_nesting(model, args, kwargs):
    """
    Guesses ``max_plate_nesting`` by running the model once without
    enumeration. This optimistically assumes static model structure.
    """
    # Trace the model silently (block suppresses any enclosing handlers).
    with poutine.block():
        trace = poutine.trace(model).get_trace(*args, **kwargs)

    # Collect the (negative) dims of every vectorized plate frame seen at
    # any sample site.
    plate_dims = []
    for node in trace.nodes.values():
        if node["type"] != "sample":
            continue
        for frame in node["cond_indep_stack"]:
            if frame.vectorized:
                plate_dims.append(frame.dim)

    # Plate dims are negative; the deepest one determines the nesting level.
    return -min(plate_dims) if plate_dims else 0


@dataclass(frozen=True, eq=False)
class _predictiveResults:
    """
    Return value of call to ``_predictive`` and ``_predictive_sequential``.

    ``samples`` maps site names to sampled values; ``trace`` is a single
    vectorized trace (parallel path) or a list of per-sample traces
    (sequential path).
    """

    # Dict of site name -> sampled values, shaped per ``return_site_shapes``.
    samples: dict
    # One vectorized Trace (parallel) or a list of Traces (sequential).
    trace: Union[Trace, List[Trace]]


def _predictive_sequential(
    model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes
):
    """
    Runs ``model`` once per posterior sample, conditioning on the i-th slice of
    each entry in ``posterior_samples``, and stacks the requested sites.
    """
    traces = []
    per_run_values = []
    for i in range(num_samples):
        # Condition on the i-th posterior draw of every latent site.
        conditioned = poutine.condition(
            model, {name: values[i] for name, values in posterior_samples.items()}
        )
        run_trace = poutine.trace(conditioned).get_trace(*model_args, **model_kwargs)
        traces.append(run_trace)
        per_run_values.append(
            {name: run_trace.nodes[name]["value"] for name in return_site_shapes}
        )

    # Stack the per-run values along a new leading sample dimension and
    # reshape each site to its expected output shape.
    stacked = {
        name: torch.stack([run[name] for run in per_run_values]).reshape(shape)
        for name, shape in return_site_shapes.items()
    }
    return _predictiveResults(trace=traces, samples=stacked)


# Name of the outermost plate used by ``_predictive`` to vectorize draws
# across predictive samples.
_predictive_vectorize_plate_name = "_num_predictive_samples"


def _predictive(
    model,
    posterior_samples,
    num_samples,
    return_sites=(),
    parallel=False,
    model_args=(),
    model_kwargs=None,
    mask=True,
):
    """
    Core predictive sampler shared by :class:`Predictive` and
    :class:`WeighedPredictive`.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: posterior latent samples keyed by site name;
        each value's leading dimension indexes the sample.
    :param int num_samples: number of predictive draws.
    :param return_sites: sites to return. A non-empty collection selects
        exactly those sites; ``None`` selects all sites; ``()`` (default)
        selects all sites not already in ``posterior_samples``.
    :param bool parallel: if True, run one vectorized trace under an outermost
        plate; otherwise run the model ``num_samples`` times sequentially.
    :param tuple model_args: positional arguments forwarded to ``model``.
    :param dict model_kwargs: keyword arguments forwarded to ``model``;
        defaults to ``{}``.
    :param bool mask: if True, mask out all sample-site log-probs while
        tracing (predictions only, no log-prob accounting).
    :return: :class:`_predictiveResults` holding ``samples`` and ``trace``.
    """
    # Fix: use a None sentinel instead of a mutable ``{}`` default, which is
    # shared across calls and could be mutated by downstream code.
    if model_kwargs is None:
        model_kwargs = {}

    model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    # Outermost plate used to vectorize over predictive samples.
    vectorize = pyro.plate(
        _predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
    )
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs)
    )

    # Insert singleton dims so each posterior sample broadcasts correctly
    # against the model's plates under the extra leading sample dimension.
    reshaped_samples = {}
    for name, sample in posterior_samples.items():
        sample_shape = sample.shape[1:]
        sample = sample.reshape(
            (num_samples,)
            + (1,) * (max_plate_nesting - len(sample_shape))
            + sample_shape
        )
        reshaped_samples[name] = sample

    # Compute the expected output shape for every site we plan to return.
    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (
            (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]["value"].shape
        )
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site; a non-tensor return value gets shape None and is
    # passed through unchanged below.
    if return_sites is not None and "_RETURN" in return_sites:
        value = model_trace.nodes["_RETURN"]["value"]
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes["_RETURN"] = shape

    if not parallel:
        return _predictive_sequential(
            model,
            posterior_samples,
            model_args,
            model_kwargs,
            num_samples,
            return_site_shapes,
        )

    trace = poutine.trace(
        poutine.condition(vectorize(model), reshaped_samples)
    ).get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]["value"]
        if site == "_RETURN" and shape is None:
            predictions[site] = value
            continue
        if value.numel() < reduce((lambda x, y: x * y), shape):
            # Site was broadcast inside the plate; expand to the full shape.
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return _predictiveResults(trace=trace, samples=predictions)


class Predictive(torch.nn.Module):
    """
    EXPERIMENTAL class used to construct predictive distribution. The
    predictive distribution is obtained by running the `model` conditioned on
    latent samples from `posterior_samples`. If a `guide` is provided, then
    posterior samples from all the latent sites are also returned.

    .. warning::
        The interface for the :class:`Predictive` class is experimental, and
        might change in the future.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param callable guide: optional guide to get posterior samples of sites not
        present in `posterior_samples`.
    :param int num_samples: number of samples to draw from the predictive
        distribution. This argument has no effect if ``posterior_samples`` is
        non-empty, in which case, the leading dimension size of samples in
        ``posterior_samples`` is used.
    :param return_sites: sites to return; by default only sample sites not
        present in `posterior_samples` are returned.
    :type return_sites: list, tuple, or set
    :param bool parallel: predict in parallel by wrapping the existing model in
        an outermost `plate` messenger. Note that this requires that the model
        has all batch dims correctly annotated via :class:`~pyro.plate`.
        Default is `False`.
    """

    def __init__(
        self,
        model,
        posterior_samples=None,
        guide=None,
        num_samples=None,
        return_sites=(),
        parallel=False,
    ):
        super().__init__()
        if posterior_samples is None:
            if num_samples is None:
                raise ValueError(
                    "Either posterior_samples or num_samples must be specified."
                )
            posterior_samples = {}

        # Infer num_samples from the leading dim of each posterior sample,
        # warning (and overriding) when it disagrees with an explicit value.
        for name, sample in posterior_samples.items():
            batch_size = sample.shape[0]
            if num_samples is None:
                num_samples = batch_size
            elif num_samples != batch_size:
                warnings.warn(
                    "Sample's leading dimension size {} is different from the "
                    "provided {} num_samples argument. Defaulting to {}.".format(
                        batch_size, num_samples, batch_size
                    ),
                    UserWarning,
                )
                num_samples = batch_size

        if num_samples is None:
            raise ValueError(
                "No sample sites in posterior samples to infer `num_samples`."
            )

        if guide is not None and posterior_samples:
            raise ValueError(
                "`posterior_samples` cannot be provided with the `guide` argument."
            )

        if return_sites is not None:
            assert isinstance(return_sites, (list, tuple, set))

        self.model = model
        self.posterior_samples = {} if posterior_samples is None else posterior_samples
        self.num_samples = num_samples
        self.guide = guide
        self.return_sites = return_sites
        self.parallel = parallel

    def call(self, *args, **kwargs):
        """
        Method that calls :meth:`forward` and returns parameter values of the
        guide as a `tuple` instead of a `dict`, which is a requirement for JIT
        tracing. Unlike :meth:`forward`, this method can be traced by
        :func:`torch.jit.trace_module`.

        .. warning::
            This method may be removed once PyTorch JIT tracer starts accepting
            `dict` as valid return types. See
            `issue <https://github.com/pytorch/pytorch/issues/27743>`_.
        """
        result = self.forward(*args, **kwargs)
        return tuple(v for _, v in sorted(result.items()))

    def forward(self, *args, **kwargs):
        """
        Returns dict of samples from the predictive distribution. By default,
        only sample sites not contained in `posterior_samples` are returned.
        This can be modified by changing the `return_sites` keyword argument of
        this :class:`Predictive` instance.

        .. note::
            This method is used internally by :class:`~torch.nn.Module`. Users
            should instead use :meth:`~torch.nn.Module.__call__` as in
            ``Predictive(model)(*args, **kwargs)``.

        :param args: model arguments.
        :param kwargs: model keyword arguments.
        """
        posterior_samples = self.posterior_samples
        return_sites = self.return_sites
        if self.guide is not None:
            # return all sites by default if a guide is provided.
            return_sites = None if not return_sites else return_sites
            posterior_samples = _predictive(
                self.guide,
                posterior_samples,
                self.num_samples,
                return_sites=None,
                parallel=self.parallel,
                model_args=args,
                model_kwargs=kwargs,
            ).samples
        return _predictive(
            self.model,
            posterior_samples,
            self.num_samples,
            return_sites=return_sites,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
        ).samples

    def get_samples(self, *args, **kwargs):
        """Deprecated alias of :meth:`forward`."""
        warnings.warn(
            "The method `.get_samples` has been deprecated in favor of `.forward`.",
            DeprecationWarning,
        )
        return self.forward(*args, **kwargs)

    def get_vectorized_trace(self, *args, **kwargs):
        """
        Returns a single vectorized `trace` from the predictive distribution.
        Note that this requires that the model has all batch dims correctly
        annotated via :class:`~pyro.plate`.

        :param args: model arguments.
        :param kwargs: model keyword arguments.
        """
        posterior_samples = self.posterior_samples
        if self.guide is not None:
            posterior_samples = _predictive(
                self.guide,
                posterior_samples,
                self.num_samples,
                parallel=self.parallel,
                model_args=args,
                model_kwargs=kwargs,
            ).samples
        return _predictive(
            self.model,
            posterior_samples,
            self.num_samples,
            parallel=True,
            model_args=args,
            model_kwargs=kwargs,
        ).trace
@dataclass(frozen=True, eq=False)
class WeighedPredictiveResults(LogWeightsMixin, CloneMixin):
    """
    Return value of call to instance of :class:`WeighedPredictive`.
    """

    # Predictive samples: dict of site name -> values (tuple when produced
    # via ``WeighedPredictive.call``).
    samples: Union[dict, tuple]
    # Per-sample log importance weights (model_log_prob - guide_log_prob).
    log_weights: torch.Tensor
    # Per-sample guide log-probability.
    guide_log_prob: torch.Tensor
    # Per-sample model log-probability.
    model_log_prob: torch.Tensor
class WeighedPredictive(Predictive):
    """
    Class used to construct a weighed predictive distribution that is based on
    the same initialization interface as :class:`Predictive`.

    The methods `.forward` and `.call` can be called with an additional keyword
    argument ``model_guide`` which is the model used to create and optimize the
    guide (if not provided ``model_guide`` defaults to ``self.model``), and
    they return both samples and log_weights.

    The weights are calculated as the per sample gap between the model_guide
    log-probability and the guide log-probability (a guide must always be
    provided).

    A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)`
    and ``guide`` :math:`q(z)` that has already been fitted to the model given
    observations :math:`p(X_{obs},z)`, both of which are provided at
    initialization of :class:`WeighedPredictive` (same as you would do with
    :class:`Predictive`). When calling an instance of
    :class:`WeighedPredictive` we provide the model given observations
    :math:`p(X_{obs},z)` as the keyword argument ``model_guide``. The resulting
    output would be the usual samples :math:`p(x|z)q(z)` returned by
    :class:`Predictive`, along with per sample weights
    :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into
    :any:`weighed_quantile` in order to obtain the true quantiles of the
    resulting distribution.

    Note that the ``model`` can be more elaborate with sample sites :math:`y`
    that are not observed and are not part of the guide, if the sample sites
    :math:`y` are sampled after the observations and the latent variables
    sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where
    each element in the product represents a set of ``pyro.sample`` statements.
    """

    def call(self, *args, **kwargs):
        """
        Method `.call` that is backwards compatible with the same method found
        in :class:`Predictive` but can be called with an additional keyword
        argument `model_guide` which is the model used to create and optimize
        the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes
        ``.samples`` and per sample weights ``.log_weights``.
        """
        result = self.forward(*args, **kwargs)
        return WeighedPredictiveResults(
            # Fix: ``result`` is a WeighedPredictiveResults dataclass, which
            # has no ``.items()``; the site dict lives in ``result.samples``.
            samples=tuple(v for _, v in sorted(result.samples.items())),
            log_weights=result.log_weights,
            guide_log_prob=result.guide_log_prob,
            model_log_prob=result.model_log_prob,
        )

    def forward(self, *args, **kwargs):
        """
        Method `.forward` that is backwards compatible with the same method
        found in :class:`Predictive` but can be called with an additional
        keyword argument `model_guide` which is the model used to create and
        optimize the guide.

        Returns :class:`WeighedPredictiveResults` which has attributes
        ``.samples`` and per sample weights ``.log_weights``.
        """
        model_guide = kwargs.pop("model_guide", self.model)
        return_sites = self.return_sites
        # return all sites by default if a guide is provided.
        return_sites = None if not return_sites else return_sites
        guide_predictive = _predictive(
            self.guide,
            self.posterior_samples,
            self.num_samples,
            return_sites=None,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        posterior_samples = guide_predictive.samples
        model_predictive = _predictive(
            model_guide,
            posterior_samples,
            self.num_samples,
            return_sites=return_sites,
            parallel=self.parallel,
            model_args=args,
            model_kwargs=kwargs,
            mask=False,
        )
        if not isinstance(guide_predictive.trace, list):
            # Parallel path: one vectorized trace each; sum log-probs within
            # the vectorization plate.
            guide_trace = prune_subsample_sites(guide_predictive.trace)
            model_trace = prune_subsample_sites(model_predictive.trace)
            guide_trace.compute_score_parts()
            model_trace.compute_log_prob()
            guide_trace.pack_tensors()
            model_trace.pack_tensors(guide_trace.plate_to_symbol)
            plate_symbol = guide_trace.plate_to_symbol[_predictive_vectorize_plate_name]
            guide_log_prob = plate_log_prob_sum(guide_trace, plate_symbol)
            model_log_prob = plate_log_prob_sum(model_trace, plate_symbol)
        else:
            # Sequential path: one trace per sample.
            guide_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in guide_predictive.trace
                ]
            )
            model_log_prob = torch.Tensor(
                [
                    trace_element.log_prob_sum()
                    for trace_element in model_predictive.trace
                ]
            )
        return WeighedPredictiveResults(
            samples=(
                # When weighing against a different model_guide, the actual
                # predictive samples still come from self.model.
                _predictive(
                    self.model,
                    posterior_samples,
                    self.num_samples,
                    return_sites=return_sites,
                    parallel=self.parallel,
                    model_args=args,
                    model_kwargs=kwargs,
                ).samples
                if model_guide is not self.model
                else model_predictive.samples
            ),
            log_weights=model_log_prob - guide_log_prob,
            guide_log_prob=guide_log_prob,
            model_log_prob=model_log_prob,
        )
class MHResampler(torch.nn.Module):
    r"""
    Resampler for weighed samples that generates equally weighed samples from
    the distribution specified by the weighed samples ``sampler``.

    The resampling is based on the Metropolis-Hastings algorithm. Given an
    initial sample :math:`x`, subsequent samples are generated by:

    - Sampling from the ``guide`` a new sample candidate :math:`x'` with
      probability :math:`g(x')`.
    - Calculating an acceptance probability
      :math:`A(x', x) = \min\left(1,
      \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)`
      with :math:`P` being the ``model``.
    - With probability :math:`A(x', x)` accepting the new sample candidate
      :math:`x'` as the next sample, otherwise keeping the current sample
      :math:`x` as the next sample.

    The above is the Metropolis-Hastings algorithm with the candidate proposal
    distribution equal to the ``guide`` and independent of the current sample,
    such that :math:`g(x')=g(x' \mid x)`.

    :param callable sampler: When called returns
        :class:`WeighedPredictiveResults`.
    :param slice source_samples_slice: Select source samples for storage
        (default is `slice(0)`, i.e. none).
    :param slice stored_samples_slice: Select output samples for storage
        (default is `slice(0)`, i.e. none).

    The typical use case of :class:`MHResampler` is to convert weighed samples
    generated by :class:`WeighedPredictive` into equally weighed samples from
    the target distribution. Each call returns a new set of samples; the first
    call's samples are distributed according to the ``guide``, and with each
    subsequent call the distribution gets closer to the posterior predictive
    distribution. It may take some experimentation to find how many calls are
    needed in each case to get close enough to the posterior predictive
    distribution.

    .. _mhsampler-behavior:

    **Notes on Sampler Behavior:**

    - If the ``guide`` perfectly tracks the ``model``, this sampler does
      nothing, as the acceptance probability :math:`A(x', x)` is always one.
    - Furthermore, if the guide is approximately separable, i.e.
      :math:`g(z_A, z_B) \approx g_A(z_A) g_B(z_B)`, with :math:`g_A(z_A)`
      perfectly tracking the ``model`` and :math:`g_B(z_B)` poorly tracking
      the ``model``, quantiles of :math:`z_A` calculated from
      :class:`MHResampler` samples will have much lower variance than
      quantiles calculated via :any:`weighed_quantile`, since the effective
      sample size of the :any:`weighed_quantile` calculation is lowered by the
      poor tracking of :math:`g_B(z_B)`, whereas with :class:`MHResampler` the
      poor tracking of :math:`g_B(z_B)` has negligible effect on the effective
      sample size of :math:`z_A` samples.
    """

    def __init__(
        self,
        sampler: Callable,
        source_samples_slice: slice = slice(0),
        stored_samples_slice: slice = slice(0),
    ):
        super().__init__()
        self.sampler = sampler
        # Current set of retained samples; None until the first forward call.
        self.samples = None
        # Per-sample count of accepted MH transitions.
        self.transition_count = torch.tensor(0, dtype=torch.long)
        # Optional history of raw sampler outputs and of retained samples.
        self.source_samples = []
        self.source_samples_slice = source_samples_slice
        self.stored_samples = []
        self.stored_samples_slice = stored_samples_slice

    def forward(self, *args, **kwargs):
        """
        Perform single resampling step.

        Returns :class:`WeighedPredictiveResults`
        """
        with torch.no_grad():
            candidate = self.sampler(*args, **kwargs)
            # Store samples (keep only the configured slice of history).
            self.source_samples.append(candidate)
            self.source_samples = self.source_samples[self.source_samples_slice]
            if self.samples is None:
                # First set of samples: accept unconditionally.
                self.samples = candidate.clone()
                self.transition_count = torch.zeros_like(
                    candidate.log_weights, dtype=torch.long
                )
            else:
                # Apply Metropolis-Hastings algorithm:
                # A = min(1, w_new / w_old), computed in log space.
                accept_prob = torch.clamp(
                    candidate.log_weights - self.samples.log_weights, max=0.0
                ).exp()
                accepted = torch.rand(*accept_prob.shape) <= accept_prob
                self.transition_count[accepted] += 1
                # Overwrite the accepted entries field by field.
                for field_desc in fields(self.samples):
                    current = getattr(self.samples, field_desc.name)
                    incoming = getattr(candidate, field_desc.name)
                    if isinstance(current, dict):
                        for key in current:
                            current[key][accepted] = incoming[key][accepted]
                    else:
                        current[accepted] = incoming[accepted]
            self.stored_samples.append(self.samples.clone())
            self.stored_samples = self.stored_samples[self.stored_samples_slice]
        return self.samples

    def get_min_sample_transition_count(self):
        """
        Return transition count of sample with minimal amount of transitions.
        """
        return self.transition_count.min()

    def get_total_transition_count(self):
        """
        Return total number of transitions.
        """
        return self.transition_count.sum()

    def get_source_samples(self):
        """
        Return source samples that were the input to the Metropolis-Hastings
        algorithm.
        """
        return self.get_samples(self.source_samples)

    def get_stored_samples(self):
        """
        Return stored samples that were the output of the Metropolis-Hastings
        algorithm.
        """
        return self.get_samples(self.stored_samples)

    def get_samples(self, samples):
        """
        Return samples that were sampled during execution of the
        Metropolis-Hastings algorithm, concatenated field by field.
        """
        retval = dict()
        for field_desc in fields(self.samples):
            field_name = field_desc.name
            value = getattr(self.samples, field_desc.name)
            if isinstance(value, dict):
                retval[field_name] = dict()
                for key in value:
                    retval[field_name][key] = torch.cat(
                        [getattr(sample, field_name)[key] for sample in samples]
                    )
            else:
                retval[field_name] = torch.cat(
                    [getattr(sample, field_name) for sample in samples]
                )
        return self.samples.__class__(**retval)