# Source code for pyro.infer.mcmc.util

import functools
import warnings
from collections import OrderedDict, defaultdict
from functools import partial, reduce
from itertools import product
import traceback as tb

import torch
from torch.distributions import biject_to
from opt_einsum import shared_intermediates

import pyro
import pyro.poutine as poutine
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape, logsumexp
from pyro.infer import config_enumerate
from pyro.infer.util import is_validation_enabled
from pyro.ops import stats
from pyro.ops.contract import contract_to_tensor
from pyro.ops.integrator import potential_grad
from pyro.poutine.subsample_messenger import _Subsample
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape, ignore_jit_warnings


class TraceTreeEvaluator(object):
    """
    Computes the log probability density of a trace (of a model with
    tree structure) that possibly contains discrete sample sites
    enumerated in parallel. This will be deprecated in favor of
    :class:`~pyro.infer.mcmc.util.TraceEinsumEvaluator`.

    :param model_trace: execution trace from a static model.
    :param bool has_enumerable_sites: whether the trace contains any
        discrete enumerable sites.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts.
    """
    def __init__(self,
                 model_trace,
                 has_enumerable_sites=False,
                 max_plate_nesting=None):
        self.has_enumerable_sites = has_enumerable_sites
        self.max_plate_nesting = max_plate_nesting
        # To be populated using the model trace once.
        self._log_probs = defaultdict(list)
        self._log_prob_shapes = defaultdict(tuple)
        self._children = defaultdict(list)
        self._enum_dims = {}
        self._plate_dims = {}
        self._parse_model_structure(model_trace)

    @staticmethod
    def _link_ordinals(ordinals):
        """
        Build the parent -> children map of the ordinal tree.

        Each non-empty ordinal (a frozenset of plate frames) is attached to
        the nearest preceding strict subset of it, where nodes are ordered by
        size. The empty ordinal ``frozenset()`` is always included as the
        root, so top-level ordinals are reachable even when no site lives
        outside every plate.

        :param ordinals: iterable of frozensets (e.g. keys of ``_log_probs``).
        :return: ``defaultdict(list)`` mapping each ordinal to its children.
        """
        children = defaultdict(list)
        # ``<`` on sets is only a partial order (subset), so ``sorted`` on raw
        # frozensets does not guarantee that ancestors come first. Sorting by
        # size is consistent with strict-subset, and the sort is stable, so
        # ordinals of equal size keep their original order.
        nodes = [frozenset()] + sorted((o for o in ordinals if o), key=len)
        for i, child_node in enumerate(nodes):
            for j in range(i - 1, -1, -1):
                cur_node = nodes[j]
                if cur_node < child_node:
                    children[cur_node].append(child_node)
                    break  # at most 1 parent.
        return children

    def _parse_model_structure(self, model_trace):
        """
        Compute the ordinal tree and the per-ordinal plate/enum dims once,
        from ``model_trace``. No-op when there are no enumerable sites.

        :raises ValueError: if the model has enumerable sites but
            ``max_plate_nesting`` is None.
        """
        if not self.has_enumerable_sites:
            return
        if self.max_plate_nesting is None:
            raise ValueError("Finite value required for `max_plate_nesting` when model "
                             "has discrete (enumerable) sites.")
        self._compute_log_prob_terms(model_trace)
        # 1. Infer model structure - compute parent-child relationship.
        self._children = self._link_ordinals(self._log_probs.keys())
        # 2. Populate `plate_dims` and `enum_dims` to be evaluated/
        #    enumerated out at each ordinal.
        self._populate_cache(frozenset(), frozenset(), set())

    def _populate_cache(self, ordinal, parent_ordinal, parent_enum_dims):
        """
        For each ordinal, populate the `plate` and `enum` dims to be
        evaluated or enumerated out.

        Plate dims are those of the frames this ordinal adds on top of its
        parent. Enum dims are batch dims to the left of the first
        ``max_plate_nesting`` dims with size > 1, excluding any already
        enumerated at an ancestor.
        """
        log_prob_shape = self._log_prob_shapes[ordinal]
        plate_dims = sorted([frame.dim for frame in ordinal - parent_ordinal])
        enum_dims = set((i for i in range(-len(log_prob_shape), -self.max_plate_nesting)
                         if log_prob_shape[i] > 1))
        self._plate_dims[ordinal] = plate_dims
        self._enum_dims[ordinal] = set(enum_dims - parent_enum_dims)
        for c in self._children[ordinal]:
            self._populate_cache(c, ordinal, enum_dims)

    def _compute_log_prob_terms(self, model_trace):
        """
        Computes the conditional probabilities for each of the sites
        in the model trace, and stores the result in `self._log_probs`.
        The broadcast shape of each ordinal's terms is recorded only once,
        on the first call.
        """
        model_trace.compute_log_prob()
        self._log_probs = defaultdict(list)
        ordering = {name: frozenset(site["cond_indep_stack"])
                    for name, site in model_trace.nodes.items()
                    if site["type"] == "sample"}
        # Collect log prob terms per independence context.
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                if is_validation_enabled():
                    check_site_shape(site, self.max_plate_nesting)
                self._log_probs[ordering[name]].append(site["log_prob"])
        if not self._log_prob_shapes:
            for ordinal, terms in self._log_probs.items():
                self._log_prob_shapes[ordinal] = broadcast_shape(*(t.shape for t in terms))

    def _reduce(self, ordinal, agg_log_prob=None):
        """
        Reduce the log prob terms for the given ordinal:
          - taking log_sum_exp of factors in enum dims (i.e.
            adding up the probability terms).
          - summing up the dims within `max_plate_nesting`.
            (i.e. multiplying probs within independent batches).

        :param ordinal: node (ordinal)
        :param torch.Tensor agg_log_prob: aggregated `log_prob`
            terms from the downstream nodes; a zero scalar tensor if None.
        :return: `log_prob` with marginalized `plate` and `enum`
            dims.
        """
        if agg_log_prob is None:
            agg_log_prob = torch.tensor(0.)
        log_prob = sum(self._log_probs[ordinal]) + agg_log_prob
        for enum_dim in self._enum_dims[ordinal]:
            log_prob = logsumexp(log_prob, dim=enum_dim, keepdim=True)
        for marginal_dim in self._plate_dims[ordinal]:
            log_prob = log_prob.sum(dim=marginal_dim, keepdim=True)
        return log_prob

    def _aggregate_log_probs(self, ordinal):
        """
        Aggregate the `log_prob` terms using depth first search.
        """
        if not self._children[ordinal]:
            return self._reduce(ordinal)
        agg_log_prob = sum(map(self._aggregate_log_probs, self._children[ordinal]))
        return self._reduce(ordinal, agg_log_prob)

    def log_prob(self, model_trace):
        """
        Returns the log pdf of `model_trace` by appropriately handling
        enumerated log prob factors.

        :return: log pdf of the trace.
        """
        with shared_intermediates():
            if not self.has_enumerable_sites:
                return model_trace.log_prob_sum()
            self._compute_log_prob_terms(model_trace)
            return self._aggregate_log_probs(ordinal=frozenset()).sum()


class TraceEinsumEvaluator(object):
    """
    Evaluates the log density of a trace from a tree-structured model that
    may contain discrete sample sites enumerated in parallel. Enumerated
    dimensions are summed out with optimized ``einsum`` contractions via
    :func:`~pyro.ops.contract.contract_to_tensor`.

    :param model_trace: execution trace from a static model.
    :param bool has_enumerable_sites: whether the trace contains any
        discrete enumerable sites.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts.
    """
    def __init__(self,
                 model_trace,
                 has_enumerable_sites=False,
                 max_plate_nesting=None):
        self.has_enumerable_sites = has_enumerable_sites
        self.max_plate_nesting = max_plate_nesting
        # Both caches below are filled in from ``model_trace`` exactly once.
        self._enum_dims = set()
        self.ordering = {}
        self._populate_cache(model_trace)

    def _checked_sample_sites(self, model_trace):
        """
        Yield ``(name, site)`` for every sample site of ``model_trace`` whose
        ``fn`` is not a ``_Subsample``, checking the site's shape first when
        validation is enabled.
        """
        for site_name, site in model_trace.nodes.items():
            if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
                continue
            if is_validation_enabled():
                check_site_shape(site, self.max_plate_nesting)
            yield site_name, site

    def _populate_cache(self, model_trace):
        """
        Record, for each sample site, the frozenset of symbols of the
        vectorized plates it sits under, and collect the packed symbols that
        denote enumeration rather than plates. No-op without enumerable sites.

        :raises ValueError: if the model has enumerable sites but
            ``max_plate_nesting`` is None.
        """
        if not self.has_enumerable_sites:
            return
        if self.max_plate_nesting is None:
            raise ValueError("Finite value required for `max_plate_nesting` when model "
                             "has discrete (enumerable) sites.")
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        plate_symbols = model_trace.plate_to_symbol
        for site_name, site in self._checked_sample_sites(model_trace):
            vectorized_frames = (frame for frame in site["cond_indep_stack"] if frame.vectorized)
            self.ordering[site_name] = frozenset(plate_symbols[frame.name]
                                                 for frame in vectorized_frames)
        # Any packed symbol that is not a plate symbol is an enumeration symbol.
        self._enum_dims = set(model_trace.symbol_to_dim).difference(
            model_trace.plate_to_symbol.values())

    def _get_log_factors(self, model_trace):
        """
        Group the packed ``log_prob`` tensors of ``model_trace`` by the
        ordinal recorded for each site, preserving trace order.
        """
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        factors = OrderedDict()
        for site_name, site in self._checked_sample_sites(model_trace):
            ordinal = self.ordering[site_name]
            factors.setdefault(ordinal, []).append(site["packed"]["log_prob"])
        return factors

    def log_prob(self, model_trace):
        """
        Returns the log pdf of `model_trace`, contracting out enumerated
        dimensions when the model has enumerable sites.

        :return: log pdf of the trace.
        """
        if self.has_enumerable_sites:
            factors = self._get_log_factors(model_trace)
            with shared_intermediates() as cache:
                return contract_to_tensor(factors, self._enum_dims, cache=cache)
        return model_trace.log_prob_sum()


def _guess_max_plate_nesting(model, args, kwargs):
    """
    Estimate ``max_plate_nesting`` from one (blocked, non-enumerated) run of
    ``model``. This optimistically assumes the model structure is static.

    :return: ``-min(dim)`` over the vectorized plate frames of all sample
        sites, or 0 when no sample site is inside a vectorized plate.
    """
    with poutine.block():
        trace = poutine.trace(model).get_trace(*args, **kwargs)

    plate_dims = [
        frame.dim
        for node in trace.nodes.values()
        if node["type"] == "sample"
        for frame in node["cond_indep_stack"]
        if frame.vectorized
    ]
    if not plate_dims:
        return 0
    return -min(plate_dims)


class _PEMaker(object):
    """
    Builds the potential energy function (negated log joint density in
    unconstrained space) of a model.

    :param model: Python callable containing Pyro primitives.
    :param tuple model_args: positional args passed to ``model``.
    :param dict model_kwargs: keyword args passed to ``model``.
    :param trace_prob_evaluator: object whose ``log_prob(trace)`` returns the
        log density of a conditioned model trace.
    :param dict transforms: per-site transforms; ``transforms[k].inv`` maps an
        unconstrained value back to the site's constrained value.
    """
    def __init__(self, model, model_args, model_kwargs, trace_prob_evaluator, transforms):
        self.model = model
        self.model_args = model_args
        self.model_kwargs = model_kwargs
        self.trace_prob_evaluator = trace_prob_evaluator
        self.transforms = transforms
        # Created lazily by ``_potential_fn_jit`` on its first non-empty call.
        self._compiled_fn = None

    def _potential_fn(self, params):
        """
        :param dict params: unconstrained values keyed by site name.
        :return: the negated log joint of the model conditioned on the
            constrained values, corrected by each transform's log-abs-det
            Jacobian.
        """
        params_constrained = {k: self.transforms[k].inv(v) for k, v in params.items()}
        cond_model = poutine.condition(self.model, params_constrained)
        model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
                                                          **self.model_kwargs)
        log_joint = self.trace_prob_evaluator.log_prob(model_trace)
        for name, t in self.transforms.items():
            log_joint = log_joint - torch.sum(
                t.log_abs_det_jacobian(params_constrained[name], params[name]))
        return -log_joint

    def _potential_fn_jit(self, skip_jit_warnings, jit_options, params):
        """
        JIT-traced variant of :meth:`_potential_fn`. The function is traced
        on the first call with non-empty ``params`` and reused afterwards;
        its positional inputs are the param values sorted by site name.
        """
        if not params:
            return self._potential_fn(params)
        names, vals = zip(*sorted(params.items()))

        # NOTE(review): the cached trace assumes later calls use the same site
        # names as the first call.
        if self._compiled_fn is not None:
            return self._compiled_fn(*vals)

        with pyro.validation_enabled(False):
            def _pe_jit(*zi):
                params = dict(zip(names, zi))
                return self._potential_fn(params)

            if skip_jit_warnings:
                _pe_jit = ignore_jit_warnings()(_pe_jit)
            self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
            return self._compiled_fn(*vals)

    def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_options=None):
        """
        :param bool jit_compile: whether to return the JIT-traced potential.
        :param bool skip_jit_warnings: ignore JIT tracer warnings when tracing.
        :param dict jit_options: kwargs for :func:`torch.jit.trace`; defaults
            to ``{"check_trace": False}`` when ``jit_compile`` is set.
        :return: a callable taking a dict of unconstrained params.
        """
        if jit_compile:
            jit_options = {"check_trace": False} if jit_options is None else jit_options
            return partial(self._potential_fn_jit, skip_jit_warnings, jit_options)
        return self._potential_fn


# TODO: expose init_strategy using separate functions.
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
                     max_tries_initial_params=100, num_chains=1, strategy="uniform"):
    """
    Search for initial unconstrained params at which the potential energy
    and its gradient are finite.

    :param model: Python callable containing Pyro primitives.
    :param tuple model_args: args passed to ``model`` (``"prior"`` strategy).
    :param dict model_kwargs: kwargs passed to ``model`` (``"prior"`` strategy).
    :param dict transforms: per-site transforms to unconstrained space.
    :param potential_fn: callable taking a dict of unconstrained params.
    :param dict prototype_params: example params; their keys and shapes are used.
    :param int max_tries_initial_params: number of candidate draws allowed per
        chain; at most ``num_chains * max_tries_initial_params`` draws in total.
    :param int num_chains: number of valid initial points to find.
    :param str strategy: ``"uniform"`` draws each value from Uniform(-2, 2);
        ``"prior"`` runs ``model`` and transforms its sampled values. Any other
        value re-evaluates the current params unchanged.
    :return: dict of initial params. For ``num_chains > 1`` each value is
        stacked along a new leading chain dimension.
    :raises ValueError: if not enough valid candidates are found.
    """
    params = prototype_params

    # For empty models, exit early
    if not params:
        return params

    params_per_chain = defaultdict(list)
    num_found = 0
    # Bound the total number of draws so that a model without any valid
    # initial point raises below instead of looping forever.
    for _ in range(num_chains * max_tries_initial_params):
        if strategy == "uniform":
            params = {k: dist.Uniform(v.new_full(v.shape, -2), v.new_full(v.shape, 2)).sample()
                      for k, v in params.items()}
        elif strategy == "prior":
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
            samples = {name: trace.nodes[name]["value"].detach() for name in params}
            params = {k: transforms[k](v) for k, v in samples.items()}
        pe_grad, pe = potential_grad(potential_fn, params)

        if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
            for k, v in params.items():
                params_per_chain[k].append(v)
            num_found += 1
            if num_found == num_chains:
                if num_chains == 1:
                    return {k: v[0] for k, v in params_per_chain.items()}
                return {k: torch.stack(v) for k, v in params_per_chain.items()}
    raise ValueError("Model specification seems incorrect - cannot find valid initial params.")


def initialize_model(model, model_args=(), model_kwargs=None, transforms=None, max_plate_nesting=None,
                     jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1):
    """
    Given a Python callable with Pyro primitives, generates the following model-specific
    properties needed for inference using HMC/NUTS kernels:

    - initial parameters to be sampled using a HMC kernel,
    - a potential function whose input is a dict of parameters in unconstrained space,
    - transforms to transform latent sites of `model` to unconstrained space,
    - a prototype trace to be used in MCMC to consume traces from sampled parameters.

    :param model: a Pyro model which contains Pyro primitives.
    :param tuple model_args: optional args taken by `model`.
    :param dict model_kwargs: optional kwargs taken by `model` (empty if None).
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. This is required if model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param dict jit_options: A dictionary contains optional arguments for
        :func:`torch.jit.trace` function.
    :param bool skip_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.
    :param int num_chains: Number of parallel chains. If `num_chains > 1`,
        each value in the returned `initial_params` dict has a leading
        dimension of size `num_chains`.
    :returns: a tuple of (`initial_params`, `potential_fn`, `transforms`, `prototype_trace`)
    """
    # Avoid a shared mutable default argument.
    if model_kwargs is None:
        model_kwargs = {}
    # XXX `transforms` domains are sites' supports
    # FIXME: find a good pattern to deal with `transforms` arg
    if transforms is None:
        automatic_transform_enabled = True
        transforms = {}
    else:
        automatic_transform_enabled = False

    if max_plate_nesting is None:
        max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    model = poutine.enum(config_enumerate(model),
                         first_available_dim=-1 - max_plate_nesting)
    model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    has_enumerable_sites = False
    prototype_samples = {}
    for name, node in model_trace.iter_stochastic_nodes():
        if isinstance(node["fn"], _Subsample):
            continue
        if node["fn"].has_enumerate_support:
            has_enumerable_sites = True
            continue
        # we need to detach here because this sample can be a leaf variable,
        # so we can't change its requires_grad flag to calculate its grad in
        # velocity_verlet
        prototype_samples[name] = node["value"].detach()
        if automatic_transform_enabled:
            transforms[name] = biject_to(node["fn"].support).inv

    trace_prob_evaluator = TraceEinsumEvaluator(model_trace,
                                                has_enumerable_sites,
                                                max_plate_nesting)

    prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
    pe_maker = _PEMaker(model, model_args, model_kwargs, trace_prob_evaluator, transforms)
    # Note that we deliberately do not exercise jit compilation here so as to
    # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
    init_params = _get_init_params(model, model_args, model_kwargs, transforms,
                                   pe_maker.get_potential_fn(), prototype_params,
                                   num_chains=num_chains)
    potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
    return init_params, potential_fn, transforms, model_trace
def _safe(fn):
    """
    Wrap a :mod:`pyro.ops.stats` utility so that failures degrade gracefully:
    if the wrapped call raises, the traceback is emitted as a warning and a
    `NaN` tensor of shape ``sample.shape[2:]`` (same dtype/device as
    ``sample``) is returned instead.

    :param fn: stats function from :mod:`pyro.ops.stats` module.
    """
    @functools.wraps(fn)
    def wrapped(sample, *args, **kwargs):
        try:
            return fn(sample, *args, **kwargs)
        except Exception:
            warnings.warn(tb.format_exc())
        nan_shape = sample.shape[2:]
        return torch.full(nan_shape, float("nan"), dtype=sample.dtype, device=sample.device)
    return wrapped
def diagnostics(samples, group_by_chain=True):
    """
    Gets diagnostics statistics such as effective sample size and
    split Gelman-Rubin using the samples drawn from the posterior
    distribution.

    :param dict samples: dictionary of samples keyed by site name.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary of diagnostic stats for each sample site. A statistic
        whose computation raises is reported (with a warning) as a `NaN`
        tensor instead of aborting the whole call.
    """
    site_diagnostics = {}
    for site, support in samples.items():
        if not group_by_chain:
            support = support.unsqueeze(0)
        site_stats = OrderedDict()
        site_stats["n_eff"] = _safe(stats.effective_sample_size)(support)
        # Wrapped like ``n_eff`` so one failing statistic does not abort the call.
        site_stats["r_hat"] = _safe(stats.split_gelman_rubin)(support)
        site_diagnostics[site] = site_stats
    return site_diagnostics
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Returns a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the credibility interval holding ``prob`` mass (90% by default, via
    :func:`~pyro.ops.stats.hpdi`), :func:`~pyro.ops.stats.effective_sample_size`,
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dict keyed by site name; each value is an ``OrderedDict`` with keys
        ``mean``, ``std``, ``median``, the lower/upper interval bounds (labelled
        by percentile, e.g. ``5.0%`` and ``95.0%``), ``n_eff`` and ``r_hat``.
    """
    if not group_by_chain:
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    # Column labels depend only on ``prob``.
    hpd_lower = '{:.1f}%'.format(50 * (1 - prob))
    hpd_upper = '{:.1f}%'.format(50 * (1 + prob))
    summary_dict = {}
    for name, value in samples.items():
        value_flat = torch.reshape(value, (-1,) + value.shape[2:])
        mean = value_flat.mean(dim=0)
        std = value_flat.std(dim=0)
        median = value_flat.median(dim=0)[0]
        hpdi = stats.hpdi(value_flat, prob=prob)
        n_eff = _safe(stats.effective_sample_size)(value)
        # Wrapped like ``n_eff`` so a failing statistic becomes NaN instead of raising.
        r_hat = _safe(stats.split_gelman_rubin)(value)
        summary_dict[name] = OrderedDict([("mean", mean), ("std", std), ("median", median),
                                          (hpd_lower, hpdi[0]), (hpd_upper, hpdi[1]),
                                          ("n_eff", n_eff), ("r_hat", r_hat)])
    return summary_dict


def print_summary(samples, prob=0.9, group_by_chain=True):
    """
    Prints a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the credibility interval holding ``prob`` mass (90% by default),
    :func:`~pyro.ops.stats.effective_sample_size`, :func:`~pyro.ops.stats.split_gelman_rubin`.
    Nothing is printed when ``samples`` is empty.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    """
    summary_dict = summary(samples, prob, group_by_chain)
    if not summary_dict:
        return
    # Size the name column from each site's statistic shape, which is correct
    # whether or not the samples carry a chain dimension.
    row_names = {k: k + '[' + ','.join(map(lambda x: str(x - 1), v["mean"].shape)) + ']'
                 for k, v in summary_dict.items()}
    max_len = max(max(map(len, row_names.values())), 10)
    name_format = '{:>' + str(max_len) + '}'
    header_format = name_format + ' {:>9}' * 7
    columns = [''] + list(list(summary_dict.values())[0].keys())

    print()
    print(header_format.format(*columns))

    row_format = name_format + ' {:>9.2f}' * 7
    for name, stats_dict in summary_dict.items():
        shape = stats_dict["mean"].shape
        if len(shape) == 0:
            print(row_format.format(name, *stats_dict.values()))
        else:
            for idx in product(*map(range, shape)):
                idx_str = '[{}]'.format(','.join(map(str, idx)))
                print(row_format.format(name + idx_str, *[v[idx] for v in stats_dict.values()]))
    print()


def _predictive_sequential(model, posterior_samples, model_args, model_kwargs,
                           num_samples, sample_sites, return_trace=False):
    """
    Run ``model`` once per posterior draw, conditioned on that draw.

    :return: a list of traces if ``return_trace``, else a dict mapping each
        site in ``sample_sites`` to its values stacked over draws.
    """
    collected = []
    samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)]
    for i in range(num_samples):
        trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(*model_args, **model_kwargs)
        if return_trace:
            collected.append(trace)
        else:
            collected.append({site: trace.nodes[site]['value'] for site in sample_sites})

    return collected if return_trace else {site: torch.stack([s[site] for s in collected])
                                           for site in sample_sites}


def predictive(model, posterior_samples, *args, **kwargs):
    """
    .. warning::
        This function is deprecated and will be removed in a future release.
        Use the :class:`~pyro.infer.predictive.Predictive` class instead.

    Run model by sampling latent parameters from `posterior_samples`, and return
    values at sample sites from the forward run.

    By default, only sites not contained in `posterior_samples` are returned. This
    can be modified by changing the `return_sites` keyword argument.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param args: model arguments.
    :param kwargs: model kwargs; and other keyword arguments (see below).

    :Keyword Arguments:
        * **num_samples** (``int``) - number of samples to draw from the predictive
          distribution. This argument has no effect if ``posterior_samples`` is non-empty,
          in which case, the leading dimension size of samples in ``posterior_samples``
          is used.
        * **return_sites** (``list``) - sites to return; by default only sample sites not
          present in `posterior_samples` are returned.
        * **return_trace** (``bool``) - whether to return the full trace. With
          ``parallel=True`` this is a single trace vectorized over `num_samples`;
          otherwise it is a list of `num_samples` traces.
        * **parallel** (``bool``) - predict in parallel by wrapping the existing model
          in an outermost `plate` messenger. Note that this requires that the model has
          all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.

    :return: dict of samples from the predictive distribution, or the trace(s)
        described under ``return_trace`` (if `return_trace=True`).
    """
    warnings.warn('The `mcmc.predictive` function is deprecated and will be removed in '
                  'a future release. Use the `pyro.infer.Predictive` class instead.',
                  FutureWarning)
    num_samples = kwargs.pop('num_samples', None)
    return_sites = kwargs.pop('return_sites', None)
    return_trace = kwargs.pop('return_trace', False)
    parallel = kwargs.pop('parallel', False)

    max_plate_nesting = _guess_max_plate_nesting(model, args, kwargs)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*args, **kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        batch_size, sample_shape = sample.shape[0], sample.shape[1:]
        if num_samples is None:
            num_samples = batch_size
        elif num_samples != batch_size:
            warnings.warn("Sample's leading dimension size {} is different from the "
                          "provided {} num_samples argument. Defaulting to {}."
                          .format(batch_size, num_samples, batch_size), UserWarning)
            num_samples = batch_size
        # Left-pad with singleton dims so the draw dim sits left of all plate dims.
        sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape))
                                + sample_shape)
        reshaped_samples[name] = sample

    if num_samples is None:
        raise ValueError("No sample sites in model to infer `num_samples`.")

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        else:
            if site not in reshaped_samples:
                return_site_shapes[site] = site_shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, args, kwargs, num_samples,
                                      return_site_shapes.keys(), return_trace)

    def _vectorized_fn(fn):
        """
        Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize
        sampling from the posterior predictive.

        :param fn: arbitrary callable containing Pyro primitives.
        :return: wrapped callable.
        """

        def wrapped_fn(*args, **kwargs):
            with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1):
                return fn(*args, **kwargs)
        return wrapped_fn

    trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
        .get_trace(*args, **kwargs)

    if return_trace:
        return trace

    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        # Sites not batched over draws are broadcast; others are reshaped.
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions