# Source code for pyro.infer.discrete

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
import warnings
from collections import OrderedDict

from opt_einsum import shared_intermediates

import pyro.ops.packed as packed
from pyro import poutine
from pyro.infer.traceenum_elbo import TraceEnum_ELBO
from pyro.ops.contract import contract_tensor_tree
from pyro.ops.einsum.adjoint import require_backward
from pyro.ops.rings import MapRing, SampleRing
from pyro.poutine.enum_messenger import EnumMessenger
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.util import prune_subsample_sites
from pyro.util import jit_iter

# Ring class used by _make_ring(), keyed by the ``temperature`` argument:
# 0 selects MapRing ("map") and 1 selects SampleRing ("sample").
_RINGS = {0: MapRing, 1: SampleRing}

def _make_ring(temperature, cache, dim_to_size):
    """
    Instantiate the ring registered in ``_RINGS`` for ``temperature``.

    :param temperature: Key into ``_RINGS``; 0 selects ``MapRing`` and
        1 selects ``SampleRing``.
    :param cache: Forwarded to the ring constructor as ``cache``.
    :param dim_to_size: Forwarded to the ring constructor as ``dim_to_size``.
    :return: A new instance of the selected ring class.
    :raises ValueError: If the lookup/construction raises ``KeyError``,
        i.e. ``temperature`` is not a supported key.
    """
    try:
        return _RINGS[temperature](cache=cache, dim_to_size=dim_to_size)
    except KeyError as e:
        # Translate the lookup failure into a user-facing error while keeping
        # the original KeyError as the cause.
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now") from e

class SamplePosteriorMessenger(ReplayMessenger):
    """
    Acts like :class:`ReplayMessenger` but additionally replays
    ``cond_indep_stack``.
    """

    def _pyro_sample(self, msg):
        """
        Handle a sample site message.

        Sites whose ``infer`` dict requests ``"parallel"`` enumeration are
        passed to :meth:`ReplayMessenger._pyro_sample`. Independently of
        that, any site present in ``self.trace`` has its
        ``cond_indep_stack`` overwritten with the one stored in the trace.

        :param dict msg: The sample site message; modified in place.
        """
        if msg["infer"].get("enumerate") == "parallel":
            super()._pyro_sample(msg)
        if msg["name"] in self.trace:
            msg["cond_indep_stack"] = self.trace.nodes[msg["name"]]["cond_indep_stack"]

def _sample_posterior(
    model, first_available_dim, temperature, strict_enumeration_warning, *args, **kwargs
):
    """
    Trace ``model`` under enumeration and sample from the resulting trace.

    For internal use by ``infer_discrete``.

    :param model: Callable to trace; invoked with ``*args, **kwargs``.
    :param first_available_dim: Passed to ``EnumMessenger``.
    :param temperature: Passed through to ``_sample_posterior_from_trace``.
    :param strict_enumeration_warning: Passed through to
        ``_sample_posterior_from_trace``.
    :return: The result of ``_sample_posterior_from_trace``.
    """
    # Create an enumerated trace.
    with poutine.block(), EnumMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    # _sample_posterior_from_trace() reads node["packed"][...] and
    # enum_trace.plate_to_symbol, so log-probs must be computed and the
    # tensors packed before handing the trace over.
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()

    return _sample_posterior_from_trace(
        model, enum_trace, temperature, strict_enumeration_warning, *args, **kwargs
    )

def _sample_posterior_from_trace(
    model, enum_trace, temperature, strict_enumeration_warning, *args, **kwargs
):
    """
    Run forward-backward over an enumerated trace, then replay ``model``
    against the collapsed trace that results.

    Side effects: ``node["packed"]["masked_log_prob"]`` is added to sample
    nodes of ``enum_trace`` that lack it, and a ``_pyro_backward_result``
    attribute is set on the log-prob tensors of non-observed sample sites
    that take part in the contraction.

    :param model: Callable invoked as ``model(*args, **kwargs)`` under a
        ``SamplePosteriorMessenger``.
    :param enum_trace: Enumerated trace. Its sample nodes must carry a
        ``"packed"`` dict with ``"unscaled_log_prob"`` and ``"mask"`` (or an
        existing ``"masked_log_prob"``); the trace must expose
        ``plate_to_symbol`` and ``symbol_to_dim``. An optional
        ``_sharing_cache`` attribute is reused if present.
    :param temperature: 0 or 1; selects the ring via ``_make_ring``.
    :param bool strict_enumeration_warning: Whether to warn when no
        enumerated terms were collected.
    :return: Whatever ``model(*args, **kwargs)`` returns during replay.
    :raises ValueError: From ``_make_ring`` on an unsupported temperature.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of symbols of the vectorized plates
            # (of size > 1) enclosing this site.
            ordinal = frozenset(
                plate_to_symbol[f.name]
                for f in node["cond_indep_stack"]
                if f.vectorized and f.size > 1
            )
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"]
                )
            log_prob = node["packed"]["masked_log_prob"]
            # Every non-plate dim of this term must be summed out.
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                # Terms sharing no dim with sum_dims are left out of the
                # contraction entirely.
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    if strict_enumeration_warning and not enum_terms:
        warnings.warn(
            "infer_discrete found no sample sites configured for enumeration. "
            "If you want to enumerate sites, you need to @config_enumerate or set "
            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}.'
        )

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(
            log_probs, sum_dims, ring=ring
        )  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Record the first ordinal at which each query's backward result
        # stops being the pending sentinel.
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if (
                query not in query_to_ordinal
                and query._pyro_backward_result is not pending
            ):
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: keep non-vectorized / size-1
                # frames, and vectorized plates only if they belong to the
                # ordinal recorded for this query.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f
                    for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1)
                    or plate_to_symbol[f.name] in ordinal
                )

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(
                        node["value"], node["infer"]["_dim_to_symbol"]
                    )
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            # Each slice of `sample` drops its leading dim.
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(
                        new_value, enum_trace.symbol_to_dim
                    )

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)

def infer_discrete(
    fn=None, first_available_dim=None, temperature=1, *, strict_enumeration_warning=True
):
    """
    A poutine that samples discrete sites marked with
    ``site["infer"]["enumerate"] = "parallel"`` from the posterior,
    conditioned on observations.

    Example::

        @infer_discrete(first_available_dim=-1, temperature=0)
        @config_enumerate
        def viterbi_decoder(data, hidden_dim=10):
            transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
            means = torch.arange(float(hidden_dim))
            states = [0]
            for t in pyro.markov(range(len(data))):
                states.append(pyro.sample("states_{}".format(t),
                                          dist.Categorical(transition[states[-1]])))
                pyro.sample("obs_{}".format(t),
                            dist.Normal(means[states[-1]], 1.),
                            obs=data[t])
            return states  # returns maximum likelihood states

    .. warning:: The ``log_prob``s of the inferred model's trace are not
        meaningful, and may be changed in a future release.

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param int first_available_dim: The first tensor dimension (counting
        from the right) that is available for parallel enumeration. This
        dimension and all dimensions left may be used internally by Pyro.
        This should be a negative integer.
    :param int temperature: Either 1 (sample via forward-filter
        backward-sample) or 0 (optimize via Viterbi-like MAP inference).
        Defaults to 1 (sample).
    :param bool strict_enumeration_warning: Whether to warn in case no
        enumerated sample sites are found. Defaults to True.
    :return: A ``functools.partial``. With ``fn`` given it wraps
        ``_sample_posterior``; without ``fn`` it wraps ``infer_discrete``
        itself so the call can be used as a decorator.
    """
    assert first_available_dim < 0, first_available_dim
    if fn is None:  # support use as a decorator
        # Forward every option, including the keyword-only
        # strict_enumeration_warning, so the decorator form honours it
        # instead of silently reverting to the default.
        return functools.partial(
            infer_discrete,
            first_available_dim=first_available_dim,
            temperature=temperature,
            strict_enumeration_warning=strict_enumeration_warning,
        )
    return functools.partial(
        _sample_posterior,
        fn,
        first_available_dim,
        temperature,
        strict_enumeration_warning,
    )
class TraceEnumSample_ELBO(TraceEnum_ELBO):
    """
    This extends :class:`TraceEnum_ELBO` to make it cheaper to sample from
    discrete latent states during SVI.

    The following are equivalent but the first is cheaper, sharing work
    between the computations of ``loss`` and ``z``::

        # Version 1.
        elbo = TraceEnumSample_ELBO(max_plate_nesting=1)
        loss = elbo.loss(*args, **kwargs)
        z = elbo.sample_saved()

        # Version 2.
        elbo = TraceEnum_ELBO(max_plate_nesting=1)
        loss = elbo.loss(*args, **kwargs)
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        z = infer_discrete(poutine.replay(model, guide_trace),
                           first_available_dim=-2)(*args, **kwargs)

    """

    def _get_trace(self, model, guide, args, kwargs):
        """
        Obtain the model and guide traces from the parent class, mark the
        model's latent sample sites for the backward pass, and remember the
        inputs for a later :meth:`sample_saved` call.

        :return: The ``(model_trace, guide_trace)`` pair from the parent.
        """
        model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)

        # Mark all sample sites with require_backward to gather enumerated
        # sites and adjust cond_indep_stack of all sample sites.
        for node in model_trace.nodes.values():
            if node["type"] == "sample" and not node["is_observed"]:
                log_prob = node["packed"]["unscaled_log_prob"]
                require_backward(log_prob)

        # Saved so sample_saved() can reuse these traces and arguments.
        self._saved_state = model, model_trace, guide_trace, args, kwargs
        return model_trace, guide_trace

    def sample_saved(self):
        """
        Generate latent samples while reusing work from SVI.step().
        """
        # _saved_state is only assigned in _get_trace(), so that must have
        # run before this method is called.
        model, model_trace, guide_trace, args, kwargs = self._saved_state
        model = poutine.replay(model, guide_trace)
        temperature = 1
        return _sample_posterior_from_trace(
            model,
            model_trace,
            temperature,
            self.strict_enumeration_warning,
            *args,
            **kwargs
        )