# Source code for pyro.infer.mcmc.util

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
import warnings
from collections import OrderedDict, defaultdict
from functools import partial, reduce
from itertools import product
import traceback as tb

import torch
from torch.distributions import biject_to
from opt_einsum import shared_intermediates

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape, logsumexp
from pyro.infer import config_enumerate
from pyro.infer.autoguide.initialization import InitMessenger, init_to_uniform
from pyro.infer.util import is_validation_enabled
from pyro.ops import stats
from pyro.ops.contract import contract_to_tensor
from pyro.ops.integrator import potential_grad
from pyro.poutine.subsample_messenger import _Subsample
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_site_shape, ignore_jit_warnings


class TraceTreeEvaluator:
    """
    Computes the log probability density of a trace (of a model with
    tree structure) that possibly contains discrete sample sites
    enumerated in parallel. This will be deprecated in favor of
    :class:`~pyro.infer.mcmc.util.TraceEinsumEvaluator`.

    :param model_trace: execution trace from a static model.
    :param bool has_enumerable_sites: whether the trace contains any
        discrete enumerable sites.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts.
    """
    def __init__(self,
                 model_trace,
                 has_enumerable_sites=False,
                 max_plate_nesting=None):
        self.has_enumerable_sites = has_enumerable_sites
        self.max_plate_nesting = max_plate_nesting
        # To be populated using the model trace once.
        self._log_probs = defaultdict(list)
        self._log_prob_shapes = defaultdict(tuple)
        self._children = defaultdict(list)
        self._enum_dims = {}
        self._plate_dims = {}
        self._parse_model_structure(model_trace)

    def _parse_model_structure(self, model_trace):
        """
        Build the per-ordinal caches (children, plate dims, enum dims)
        from ``model_trace``. Does nothing when the model has no
        enumerable sites.

        :raises ValueError: if the model has enumerable sites but
            ``max_plate_nesting`` was not provided.
        """
        if not self.has_enumerable_sites:
            return
        if self.max_plate_nesting is None:
            raise ValueError("Finite value required for `max_plate_nesting` when model "
                             "has discrete (enumerable) sites.")
        self._compute_log_prob_terms(model_trace)
        # 1. Infer model structure - compute parent-child relationship.
        self._link_children()
        # 2. Populate `plate_dims` and `enum_dims` to be evaluated/
        #    enumerated out at each ordinal.
        self._populate_cache(frozenset(), frozenset(), set())

    def _link_children(self):
        """
        Record, for every ordinal in ``self._log_probs``, the closest
        enclosing ordinal as its parent in ``self._children``.
        """
        # `<` on frozensets is only a partial order (proper subset), so
        # sorting on it directly gives an undefined result. Sorting by size
        # is a valid linear extension: a proper subset is always smaller,
        # hence every ancestor precedes its descendants.
        sorted_ordinals = sorted(self._log_probs, key=len)
        for i, child_node in enumerate(sorted_ordinals):
            # Walk backwards so the largest (i.e. closest) ancestor wins.
            for cur_node in reversed(sorted_ordinals[:i]):
                if cur_node < child_node:
                    self._children[cur_node].append(child_node)
                    break  # at most 1 parent.

    def _populate_cache(self, ordinal, parent_ordinal, parent_enum_dims):
        """
        For each ordinal, populate the `plate` and `enum` dims to be
        evaluated or enumerated out.
        """
        log_prob_shape = self._log_prob_shapes[ordinal]
        # Plate dims introduced at this ordinal (not already in the parent).
        plate_dims = sorted(frame.dim for frame in ordinal - parent_ordinal)
        # Non-trivial dims to the left of the `max_plate_nesting` batch dims.
        enum_dims = {i for i in range(-len(log_prob_shape), -self.max_plate_nesting)
                     if log_prob_shape[i] > 1}
        self._plate_dims[ordinal] = plate_dims
        self._enum_dims[ordinal] = set(enum_dims - parent_enum_dims)
        for c in self._children[ordinal]:
            self._populate_cache(c, ordinal, enum_dims)

    def _compute_log_prob_terms(self, model_trace):
        """
        Computes the conditional probabilities for each of the sites
        in the model trace, and stores the result in `self._log_probs`.
        """
        model_trace.compute_log_prob()
        self._log_probs = defaultdict(list)
        # Collect log prob terms per independence context.
        for name, site in model_trace.nodes.items():
            if site["type"] != "sample":
                continue
            if is_validation_enabled():
                check_site_shape(site, self.max_plate_nesting)
            ordinal = frozenset(site["cond_indep_stack"])
            self._log_probs[ordinal].append(site["log_prob"])
        # Shapes are cached only the first time this method runs.
        if not self._log_prob_shapes:
            for ordinal, log_probs in self._log_probs.items():
                self._log_prob_shapes[ordinal] = broadcast_shape(*(t.shape for t in log_probs))

    # NOTE: the default tensor is created once at definition time; it is only
    # read (added to a fresh sum) and never modified in place.
    def _reduce(self, ordinal, agg_log_prob=torch.tensor(0.)):
        """
        Reduce the log prob terms for the given ordinal:
          - taking log_sum_exp of factors in enum dims (i.e.
            adding up the probability terms).
          - summing up the dims within `max_plate_nesting`.
            (i.e. multiplying probs within independent batches).

        :param ordinal: node (ordinal)
        :param torch.Tensor agg_log_prob: aggregated `log_prob`
            terms from the downstream nodes.
        :return: `log_prob` with marginalized `plate` and `enum`
            dims.
        """
        log_prob = sum(self._log_probs[ordinal]) + agg_log_prob
        for enum_dim in self._enum_dims[ordinal]:
            log_prob = logsumexp(log_prob, dim=enum_dim, keepdim=True)
        for marginal_dim in self._plate_dims[ordinal]:
            log_prob = log_prob.sum(dim=marginal_dim, keepdim=True)
        return log_prob

    def _aggregate_log_probs(self, ordinal):
        """
        Aggregate the `log_prob` terms using depth first search.
        """
        if not self._children[ordinal]:
            return self._reduce(ordinal)
        agg_log_prob = sum(map(self._aggregate_log_probs, self._children[ordinal]))
        return self._reduce(ordinal, agg_log_prob)

    def log_prob(self, model_trace):
        """
        Returns the log pdf of `model_trace` by appropriately handling
        enumerated log prob factors.

        :return: log pdf of the trace.
        """
        with shared_intermediates():
            if not self.has_enumerable_sites:
                return model_trace.log_prob_sum()
            self._compute_log_prob_terms(model_trace)
            return self._aggregate_log_probs(ordinal=frozenset()).sum()


class TraceEinsumEvaluator:
    """
    Evaluates the log probability density of a trace (of a model with
    tree structure) that may contain discrete sample sites enumerated
    in parallel. Enumerated dimensions are marginalized out with
    optimized `einsum` operations through
    :func:`~pyro.ops.contract.contract_to_tensor`.

    :param model_trace: execution trace from a static model.
    :param bool has_enumerable_sites: whether the trace contains any
        discrete enumerable sites.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts.
    """
    def __init__(self, model_trace, has_enumerable_sites=False, max_plate_nesting=None):
        self.has_enumerable_sites = has_enumerable_sites
        self.max_plate_nesting = max_plate_nesting
        # Filled in once from the prototype trace by `_populate_cache`.
        self._enum_dims = set()
        self.ordering = {}
        self._populate_cache(model_trace)

    def _populate_cache(self, model_trace):
        """
        Record, for every non-subsample sample site, its ordinal (the set
        of packed symbols of its vectorized plates), and record the set of
        enumeration symbols of the trace.

        :raises ValueError: if the model has enumerable sites but
            ``max_plate_nesting`` was not provided.
        """
        if not self.has_enumerable_sites:
            return
        if self.max_plate_nesting is None:
            raise ValueError("Finite value required for `max_plate_nesting` when model "
                             "has discrete (enumerable) sites.")
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        plate_symbols = model_trace.plate_to_symbol
        for site_name, site in model_trace.nodes.items():
            if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
                continue
            if is_validation_enabled():
                check_site_shape(site, self.max_plate_nesting)
            vectorized_frames = (frame for frame in site["cond_indep_stack"] if frame.vectorized)
            self.ordering[site_name] = frozenset(plate_symbols[frame.name]
                                                 for frame in vectorized_frames)
        # Every packed symbol that is not a plate symbol is an enumeration dim.
        all_symbols = set(model_trace.symbol_to_dim)
        self._enum_dims = all_symbols - set(plate_symbols.values())

    def _get_log_factors(self, model_trace):
        """
        Group the packed `log_prob` tensors of the trace by ordinal.

        :return: an ``OrderedDict`` mapping each ordinal to the list of
            packed `log_prob` terms of the sites with that ordinal.
        """
        model_trace.compute_log_prob()
        model_trace.pack_tensors()
        factors = OrderedDict()
        for site_name, site in model_trace.nodes.items():
            if site["type"] != "sample" or isinstance(site["fn"], _Subsample):
                continue
            if is_validation_enabled():
                check_site_shape(site, self.max_plate_nesting)
            ordinal = self.ordering[site_name]
            if ordinal not in factors:
                factors[ordinal] = []
            factors[ordinal].append(site["packed"]["log_prob"])
        return factors

    def log_prob(self, model_trace):
        """
        Returns the log pdf of `model_trace`, marginalizing out any
        enumerated log prob factors.

        :return: log pdf of the trace.
        """
        if self.has_enumerable_sites:
            log_factors = self._get_log_factors(model_trace)
            with shared_intermediates() as cache:
                return contract_to_tensor(log_factors, self._enum_dims, cache=cache)
        return model_trace.log_prob_sum()


def _guess_max_plate_nesting(model, args, kwargs):
    """
    Guesses max_plate_nesting by running the model once
    without enumeration. This optimistically assumes static model
    structure.

    :return: the negated smallest ``dim`` among the vectorized plate
        frames of the sample sites, or 0 if there are none.
    """
    with poutine.block():
        model_trace = poutine.trace(model).get_trace(*args, **kwargs)

    plate_dims = []
    for site in model_trace.nodes.values():
        if site["type"] != "sample":
            continue
        for frame in site["cond_indep_stack"]:
            if frame.vectorized:
                plate_dims.append(frame.dim)

    if not plate_dims:
        return 0
    return -min(plate_dims)


class _PEMaker:
    def __init__(self, model, model_args, model_kwargs, trace_prob_evaluator, transforms):
        self.model = model
        self.model_args = model_args
        self.model_kwargs = model_kwargs
        self.trace_prob_evaluator = trace_prob_evaluator
        self.transforms = transforms
        self._compiled_fn = None

    def _potential_fn(self, params):
        params_constrained = {k: self.transforms[k].inv(v) for k, v in params.items()}
        cond_model = poutine.condition(self.model, params_constrained)
        model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
                                                          **self.model_kwargs)
        log_joint = self.trace_prob_evaluator.log_prob(model_trace)
        for name, t in self.transforms.items():
            log_joint = log_joint - torch.sum(
                t.log_abs_det_jacobian(params_constrained[name], params[name]))
        return -log_joint

    def _potential_fn_jit(self, skip_jit_warnings, jit_options, params):
        if not params:
            return self._potential_fn(params)
        names, vals = zip(*sorted(params.items()))

        if self._compiled_fn:
            return self._compiled_fn(*vals)

        with pyro.validation_enabled(False):
            tmp = []
            for _, v in pyro.get_param_store().named_parameters():
                if v.requires_grad:
                    v.requires_grad_(False)
                    tmp.append(v)

            def _pe_jit(*zi):
                params = dict(zip(names, zi))
                return self._potential_fn(params)

            if skip_jit_warnings:
                _pe_jit = ignore_jit_warnings()(_pe_jit)
            self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)

            result = self._compiled_fn(*vals)
            for v in tmp:
                v.requires_grad_(True)
            return result

    def get_potential_fn(self, jit_compile=False, skip_jit_warnings=True, jit_options=None):
        if jit_compile:
            jit_options = {"check_trace": False} if jit_options is None else jit_options
            return partial(self._potential_fn_jit, skip_jit_warnings, jit_options)
        return self._potential_fn


def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn,
                               prototype_params, max_tries_initial_params=100, num_chains=1,
                               init_strategy=init_to_uniform, trace=None):
    """
    Search for initial parameters at which both the potential energy and
    its gradient are finite.

    The model is run under ``InitMessenger(init_strategy)``; each run's
    site values are transformed with ``transforms`` and checked with
    ``potential_grad``. A supplied ``trace`` is used for the first attempt
    only.

    :return: ``prototype_params`` unchanged if it is empty; otherwise a
        dict of valid parameters (stacked along a new leading dimension
        when ``num_chains > 1``).
    :raises ValueError: if not enough valid points are found within
        ``num_chains * max_tries_initial_params`` attempts.
    """
    # For empty models, exit early
    if not prototype_params:
        return prototype_params

    site_names = list(prototype_params)
    accepted = defaultdict(list)
    num_found = 0
    model = InitMessenger(init_strategy)(model)
    for _ in range(num_chains * max_tries_initial_params):
        if trace is None:
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
        site_values = {name: trace.nodes[name]["value"].detach() for name in site_names}
        # The trace is consumed; later attempts must draw a fresh one.
        trace = None
        candidate = {name: transforms[name](value) for name, value in site_values.items()}
        pe_grad, pe = potential_grad(potential_fn, candidate)

        if not (torch.isfinite(pe) and all(torch.all(torch.isfinite(g)) for g in pe_grad.values())):
            continue

        for name, value in candidate.items():
            accepted[name].append(value)
        num_found += 1
        if num_found == num_chains:
            if num_chains == 1:
                return {name: values[0] for name, values in accepted.items()}
            return {name: torch.stack(values) for name, values in accepted.items()}
    raise ValueError("Model specification seems incorrect - cannot find valid initial params.")


def initialize_model(model, model_args=(), model_kwargs=None, transforms=None, max_plate_nesting=None,
                     jit_compile=False, jit_options=None, skip_jit_warnings=False, num_chains=1,
                     init_strategy=init_to_uniform, initial_params=None):
    """
    Given a Python callable with Pyro primitives, generates the following model-specific
    properties needed for inference using HMC/NUTS kernels:

    - initial parameters to be sampled using a HMC kernel,
    - a potential function whose input is a dict of parameters in unconstrained space,
    - transforms to transform latent sites of `model` to unconstrained space,
    - a prototype trace to be used in MCMC to consume traces from sampled parameters.

    :param model: a Pyro model which contains Pyro primitives.
    :param tuple model_args: optional args taken by `model`.
    :param dict model_kwargs: optional kwargs taken by `model`. ``None``
        (the default) is treated as an empty dict.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. This is required if model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param dict jit_options: A dictionary contains optional arguments for
        :func:`torch.jit.trace` function.
    :param bool skip_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.
    :param int num_chains: Number of parallel chains. If `num_chains > 1`,
        each tensor in the generated `initial_params` is stacked along a new
        leading dimension of size `num_chains`.
    :param callable init_strategy: A per-site initialization function.
        See :ref:`autoguide-initialization` section for available functions.
    :param dict initial_params: dict containing initial tensors in unconstrained
        space to initiate the markov chain. When given, it is returned as is.
    :returns: a tuple of (`initial_params`, `potential_fn`, `transforms`, `prototype_trace`)
    """
    if model_kwargs is None:
        model_kwargs = {}
    # XXX `transforms` domains are sites' supports
    # FIXME: find a good pattern to deal with `transforms` arg
    if transforms is None:
        automatic_transform_enabled = True
        transforms = {}
    else:
        automatic_transform_enabled = False
    if max_plate_nesting is None:
        max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    # No-op if model does not have any discrete latents.
    model = poutine.enum(config_enumerate(model),
                         first_available_dim=-1 - max_plate_nesting)
    prototype_model = poutine.trace(InitMessenger(init_strategy)(model))
    model_trace = prototype_model.get_trace(*model_args, **model_kwargs)
    has_enumerable_sites = False
    prototype_samples = {}
    for name, node in model_trace.iter_stochastic_nodes():
        fn = node["fn"]
        if isinstance(fn, _Subsample):
            if fn.subsample_size is not None and fn.subsample_size < fn.size:
                raise NotImplementedError("HMC/NUTS does not support model with subsample sites.")
            continue
        if node["fn"].has_enumerate_support:
            has_enumerable_sites = True
            continue
        # we need to detach here because this sample can be a leaf variable,
        # so we can't change its requires_grad flag to calculate its grad in
        # velocity_verlet
        prototype_samples[name] = node["value"].detach()
        if automatic_transform_enabled:
            transforms[name] = biject_to(node["fn"].support).inv

    trace_prob_evaluator = TraceEinsumEvaluator(model_trace,
                                                has_enumerable_sites,
                                                max_plate_nesting)

    pe_maker = _PEMaker(model, model_args, model_kwargs, trace_prob_evaluator, transforms)

    if initial_params is None:
        prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
        # Note that we deliberately do not exercise jit compilation here so as to
        # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).
        # We pass model_trace merely for computational savings.
        initial_params = _find_valid_initial_params(model, model_args, model_kwargs, transforms,
                                                    pe_maker.get_potential_fn(), prototype_params,
                                                    num_chains=num_chains, init_strategy=init_strategy,
                                                    trace=model_trace)
    potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
    return initial_params, potential_fn, transforms, model_trace


def _safe(fn): """ Safe version of utilities in the :mod:`pyro.ops.stats` module. Wrapped functions return `NaN` tensors instead of throwing exceptions. :param fn: stats function from :mod:`pyro.ops.stats` module. """ @functools.wraps(fn) def wrapped(sample, *args, **kwargs): try: val = fn(sample, *args, **kwargs) except Exception: warnings.warn(tb.format_exc()) val = torch.full(sample.shape[2:], float("nan"), dtype=sample.dtype, device=sample.device) return val return wrapped
def diagnostics(samples, group_by_chain=True):
    """
    Gets diagnostics statistics such as effective sample size and
    split Gelman-Rubin using the samples drawn from the posterior
    distribution.

    :param dict samples: dictionary of samples keyed by site name.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary of diagnostic stats for each sample site; each
        value is an ``OrderedDict`` with keys ``"n_eff"`` and ``"r_hat"``.
    """
    site_diagnostics = {}
    for site, support in samples.items():
        if not group_by_chain:
            # Add a singleton chain dimension.
            support = support.unsqueeze(0)
        site_stats = OrderedDict()
        site_stats["n_eff"] = _safe(stats.effective_sample_size)(support)
        site_stats["r_hat"] = stats.split_gelman_rubin(support)
        site_diagnostics[site] = site_stats
    return site_diagnostics


def summary(samples, prob=0.9, group_by_chain=True): """ Returns a summary table displaying diagnostics of ``samples`` from the posterior. The diagnostics displayed are mean, standard deviation, median, the 90% Credibility Interval, :func:`~pyro.ops.stats.effective_sample_size`, :func:`~pyro.ops.stats.split_gelman_rubin`. :param dict samples: dictionary of samples keyed by site name. :param float prob: the probability mass of samples within the credibility interval. :param bool group_by_chain: If True, each variable in `samples` will be treated as having shape `num_chains x num_samples x sample_shape`. Otherwise, the corresponding shape will be `num_samples x sample_shape` (i.e. without chain dimension). """ if not group_by_chain: samples = {k: v.unsqueeze(0) for k, v in samples.items()} summary_dict = {} for name, value in samples.items(): value_flat = torch.reshape(value, (-1,) + value.shape[2:]) mean = value_flat.mean(dim=0) std = value_flat.std(dim=0) median = value_flat.median(dim=0)[0] hpdi = stats.hpdi(value_flat, prob=prob) n_eff = _safe(stats.effective_sample_size)(value) r_hat = stats.split_gelman_rubin(value) hpd_lower = '{:.1f}%'.format(50 * (1 - prob)) hpd_upper = '{:.1f}%'.format(50 * (1 + prob)) summary_dict[name] = OrderedDict([("mean", mean), ("std", std), ("median", median), (hpd_lower, hpdi[0]), (hpd_upper, hpdi[1]), ("n_eff", n_eff), ("r_hat", r_hat)]) return summary_dict def print_summary(samples, prob=0.9, group_by_chain=True): """ Prints a summary table displaying diagnostics of ``samples`` from the posterior. The diagnostics displayed are mean, standard deviation, median, the 90% Credibility Interval, :func:`~pyro.ops.stats.effective_sample_size`, :func:`~pyro.ops.stats.split_gelman_rubin`. :param dict samples: dictionary of samples keyed by site name. :param float prob: the probability mass of samples within the credibility interval. 
:param bool group_by_chain: If True, each variable in `samples` will be treated as having shape `num_chains x num_samples x sample_shape`. Otherwise, the corresponding shape will be `num_samples x sample_shape` (i.e. without chain dimension). """ if len(samples) == 0: return summary_dict = summary(samples, prob, group_by_chain) row_names = {k: k + '[' + ','.join(map(lambda x: str(x - 1), v.shape[2:])) + ']' for k, v in samples.items()} max_len = max(max(map(lambda x: len(x), row_names.values())), 10) name_format = '{:>' + str(max_len) + '}' header_format = name_format + ' {:>9}' * 7 columns = [''] + list(list(summary_dict.values())[0].keys()) print() print(header_format.format(*columns)) row_format = name_format + ' {:>9.2f}' * 7 for name, stats_dict in summary_dict.items(): shape = stats_dict["mean"].shape if len(shape) == 0: print(row_format.format(name, *stats_dict.values())) else: for idx in product(*map(range, shape)): idx_str = '[{}]'.format(','.join(map(str, idx))) print(row_format.format(name + idx_str, *[v[idx] for v in stats_dict.values()])) print() def _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, sample_sites, return_trace=False): collected = [] samples = [{k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples)] for i in range(num_samples): trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(*model_args, **model_kwargs) if return_trace: collected.append(trace) else: collected.append({site: trace.nodes[site]['value'] for site in sample_sites}) return collected if return_trace else {site: torch.stack([s[site] for s in collected]) for site in sample_sites} def predictive(model, posterior_samples, *args, **kwargs): """ .. warning:: This function is deprecated and will be removed in a future release. Use the :class:`~pyro.infer.predictive.Predictive` class instead. 
Run model by sampling latent parameters from `posterior_samples`, and return values at sample sites from the forward run. By default, only sites not contained in `posterior_samples` are returned. This can be modified by changing the `return_sites` keyword argument. :param model: Python callable containing Pyro primitives. :param dict posterior_samples: dictionary of samples from the posterior. :param args: model arguments. :param kwargs: model kwargs; and other keyword arguments (see below). :Keyword Arguments: * **num_samples** (``int``) - number of samples to draw from the predictive distribution. This argument has no effect if ``posterior_samples`` is non-empty, in which case, the leading dimension size of samples in ``posterior_samples`` is used. * **return_sites** (``list``) - sites to return; by default only sample sites not present in `posterior_samples` are returned. * **return_trace** (``bool``) - whether to return the full trace. Note that this is vectorized over `num_samples`. * **parallel** (``bool``) - predict in parallel by wrapping the existing model in an outermost `plate` messenger. Note that this requires that the model has all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`. :return: dict of samples from the predictive distribution, or a single vectorized `trace` (if `return_trace=True`). """ warnings.warn('The `mcmc.predictive` function is deprecated and will be removed in ' 'a future release. 
Use the `pyro.infer.Predictive` class instead.', FutureWarning) num_samples = kwargs.pop('num_samples', None) return_sites = kwargs.pop('return_sites', None) return_trace = kwargs.pop('return_trace', False) parallel = kwargs.pop('parallel', False) max_plate_nesting = _guess_max_plate_nesting(model, args, kwargs) model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*args, **kwargs)) reshaped_samples = {} for name, sample in posterior_samples.items(): batch_size, sample_shape = sample.shape[0], sample.shape[1:] if num_samples is None: num_samples = batch_size elif num_samples != batch_size: warnings.warn("Sample's leading dimension size {} is different from the " "provided {} num_samples argument. Defaulting to {}." .format(batch_size, num_samples, batch_size), UserWarning) num_samples = batch_size sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape) reshaped_samples[name] = sample if num_samples is None: raise ValueError("No sample sites in model to infer `num_samples`.") return_site_shapes = {} for site in model_trace.stochastic_nodes + model_trace.observation_nodes: site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape if return_sites: if site in return_sites: return_site_shapes[site] = site_shape else: if site not in reshaped_samples: return_site_shapes[site] = site_shape if not parallel: return _predictive_sequential(model, posterior_samples, args, kwargs, num_samples, return_site_shapes.keys(), return_trace) def _vectorized_fn(fn): """ Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize sampling from the posterior predictive. :param fn: arbitrary callable containing Pyro primitives. :return: wrapped callable. 
""" def wrapped_fn(*args, **kwargs): with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1): return fn(*args, **kwargs) return wrapped_fn trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\ .get_trace(*args, **kwargs) if return_trace: return trace predictions = {} for site, shape in return_site_shapes.items(): value = trace.nodes[site]['value'] if value.numel() < reduce((lambda x, y: x * y), shape): predictions[site] = value.expand(shape) else: predictions[site] = value.reshape(shape) return predictions