# Source code for pyro.infer.tracegraph_elbo

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import weakref
from collections import defaultdict
from operator import itemgetter

import torch

import pyro
import pyro.ops.jit
from pyro.distributions.util import detach, is_identically_zero
from pyro.infer import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import (
    MultiFrameTensor,
    get_plate_stacks,
    is_validation_enabled,
    torch_backward,
    torch_item,
)
from pyro.ops.provenance import detach_provenance, get_provenance, track_provenance
from pyro.poutine.messenger import Messenger
from pyro.poutine.subsample_messenger import _Subsample
from pyro.util import check_if_enumerated, warn_if_nan


def _get_baseline_options(site):
    """
    Extracts baseline options from ``site["infer"]["baseline"]``.
    """
    # XXX default for baseline_beta currently set here
    options_dict = site["infer"].get("baseline", {}).copy()
    options_tuple = (
        options_dict.pop("nn_baseline", None),
        options_dict.pop("nn_baseline_input", None),
        options_dict.pop("use_decaying_avg_baseline", False),
        options_dict.pop("baseline_beta", 0.90),
        options_dict.pop("baseline_value", None),
    )
    if options_dict:
        raise ValueError(
            "Unrecognized baseline options: {}".format(options_dict.keys())
        )
    return options_tuple


def _construct_baseline(node, guide_site, downstream_cost):
    """
    Assemble the baseline (and its regression loss) for one guide site.

    The baseline options are read from ``guide_site`` via
    :func:`_get_baseline_options`. Up to two contributions are summed:

    - a decaying average of past downstream costs, kept in the param store
      under ``"__baseline_avg_downstream_cost_" + node`` and updated here;
    - either the output of ``nn_baseline`` applied to the detached
      ``nn_baseline_input``, or the user-supplied ``baseline_value``.

    :param str node: name of the guide site.
    :param dict guide_site: the guide trace site for ``node``.
    :param downstream_cost: downstream cost tensor for the site.
    :returns: tuple ``(use_baseline, baseline_loss, baseline)``.
    :raises ValueError: if a baseline is used and its shape differs from
        the shape of ``downstream_cost``.
    """
    # XXX should the average baseline be in the param store as below?
    (
        nn_baseline,
        nn_baseline_input,
        use_decaying_avg_baseline,
        baseline_beta,
        baseline_value,
    ) = _get_baseline_options(guide_site)

    has_nn = nn_baseline is not None
    has_value = baseline_value is not None
    use_baseline = has_nn or use_decaying_avg_baseline or has_value

    assert not (
        has_nn and has_value
    ), "cannot use baseline_value and nn_baseline simultaneously"

    baseline = 0.0
    baseline_loss = 0.0

    if use_decaying_avg_baseline:
        cost_shape = downstream_cost.shape
        store_key = "__baseline_avg_downstream_cost_" + node
        # the running average is bookkeeping only; keep it off the autograd tape
        with torch.no_grad():
            previous_avg = pyro.param(
                store_key, torch.zeros(cost_shape, device=guide_site["value"].device)
            )
            updated_avg = (
                1 - baseline_beta
            ) * downstream_cost + baseline_beta * previous_avg
        pyro.get_param_store()[store_key] = updated_avg
        # the baseline uses the average from *before* this step's update
        baseline += previous_avg

    if has_nn or has_value:
        if has_nn:
            # block nn_baseline_input gradients except in baseline loss
            baseline += nn_baseline(detach(nn_baseline_input))
        else:
            # it's on the user to make sure baseline_value tape only points to baseline params
            baseline += baseline_value
        # accumulate baseline loss
        baseline_loss += torch.pow(downstream_cost.detach() - baseline, 2.0).sum()

    if use_baseline and downstream_cost.shape != baseline.shape:
        raise ValueError(
            "Expected baseline at site {} to be {} instead got {}".format(
                node, downstream_cost.shape, baseline.shape
            )
        )

    return use_baseline, baseline_loss, baseline


def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
    """
    Recursively compute downstream costs for all guide sample sites.

    Downstream costs are accumulated for every sample site in the guide
    (even though ultimately they are only needed for non-reparameterizable
    sites):

    1. downstream costs are used for rao-blackwellization;
    2. model observe sites (as well as terms that arise from the model and
       guide having different dependency structures) are handled through
       the model-children pass further below.

    :param model_trace: model trace with ``log_prob`` computed at each site.
    :param guide_trace: guide trace with ``log_prob`` computed at each site.
    :param non_reparam_nodes: iterable of guide site names whose costs are
        completed with model-only terms and reduced with ``sum_to``.
    :returns: tuple ``(downstream_costs, downstream_guide_cost_nodes)``, both
        dicts keyed by guide site name.
    """
    guide_sample_sites = [
        name
        for name in guide_trace.topological_sort(reverse=True)
        if guide_trace.nodes[name]["type"] == "sample"
    ]
    topo_rank = {name: rank for rank, name in enumerate(guide_sample_sites)}

    downstream_guide_cost_nodes = {}
    downstream_costs = {}
    stacks = get_plate_stacks(model_trace)

    def _log_ratio_term(name):
        # (plate stack, log p - log q) contribution of a single site
        return (
            stacks[name],
            model_trace.nodes[name]["log_prob"] - guide_trace.nodes[name]["log_prob"],
        )

    for node in guide_sample_sites:
        downstream_costs[node] = MultiFrameTensor(_log_ratio_term(node))
        already_summed = {node}
        reachable = {node}
        downstream_guide_cost_nodes[node] = reachable
        # make more efficient by ordering children appropriately (higher children first)
        ordered_children = sorted(
            guide_trace.successors(node), key=lambda child: -topo_rank[child]
        )
        for child in ordered_children:
            child_cost_nodes = downstream_guide_cost_nodes[child]
            reachable.update(child_cost_nodes)
            if already_summed.isdisjoint(child_cost_nodes):  # avoid duplicates
                downstream_costs[node].add(*downstream_costs[child].items())
                # XXX already_summed logic could be more fine-grained, possibly leading
                # to speed-ups in case there are many duplicates
                already_summed.update(child_cost_nodes)
        # include terms we missed because we had to avoid duplicates
        for missing_node in reachable - already_summed:
            downstream_costs[node].add(_log_ratio_term(missing_node))

    # finish assembling complete downstream costs
    # (the above computation may be missing terms from model)
    for site in non_reparam_nodes:
        site_cost_nodes = downstream_guide_cost_nodes[site]
        children_in_model = set()
        for node in site_cost_nodes:
            children_in_model.update(model_trace.successors(node))
        # remove terms accounted for above
        children_in_model.difference_update(site_cost_nodes)
        for child in children_in_model:
            assert model_trace.nodes[child]["type"] == "sample"
            downstream_costs[site].add(
                (stacks[child], model_trace.nodes[child]["log_prob"])
            )
            site_cost_nodes.add(child)

    # reduce each requested cost down to the plate stack of its own guide site
    for site in non_reparam_nodes:
        target_stack = guide_trace.nodes[site]["cond_indep_stack"]
        downstream_costs[site] = downstream_costs[site].sum_to(target_stack)

    return downstream_costs, downstream_guide_cost_nodes


def _compute_elbo(model_trace, guide_trace):
    """
    Compute the ELBO and the surrogate loss for one (model, guide) trace pair.

    In ref [1], section 3.2, the part of the surrogate loss computed here is
    \\sum{cost}, which in this case is the ELBO. Instead of using the ELBO,
    this implementation uses a surrogate ELBO which modifies some entropy
    terms depending on the parameterization. This reduces the variance of the
    gradient under some conditions.

    :param model_trace: model trace whose sample sites carry ``log_prob``
        and ``log_prob_sum``.
    :param guide_trace: guide trace whose sample sites carry ``log_prob``,
        ``log_prob_sum`` and ``score_parts``.
    :returns: tuple ``(elbo, surrogate_loss)``, each passed through
        ``detach_provenance``; ``surrogate_loss`` is the negated surrogate
        ELBO plus the accumulated baseline loss.
    """
    elbo = 0.0
    surrogate_elbo = 0.0
    baseline_loss = 0.0
    # mapping from non-reparameterizable sample sites to cost terms influenced by each of them
    downstream_costs = defaultdict(MultiFrameTensor)

    # Bring log p(x, z|...) terms into both the ELBO and the surrogate
    for site in model_trace.nodes.values():
        if site["type"] != "sample":
            continue
        log_p_sum = site["log_prob_sum"]
        elbo += log_p_sum
        surrogate_elbo += log_p_sum
        # add the log_prob to each non-reparam sample site upstream
        for upstream in get_provenance(log_p_sum):
            downstream_costs[upstream].add((site["cond_indep_stack"], site["log_prob"]))

    # Bring log q(z|...) terms into the ELBO, and effective terms into the
    # surrogate. Depending on the parameterization of a site, its log q(z|...)
    # cost term may not contribute (in expectation) to the gradient. To reduce
    # the variance under some conditions, the default entropy terms from
    # site[`score_parts`] are used.
    for site in guide_trace.nodes.values():
        if site["type"] != "sample":
            continue
        log_q_sum = site["log_prob_sum"]
        elbo -= log_q_sum
        entropy_term = site["score_parts"].entropy_term
        # For fully reparameterized terms, this entropy_term is log q(z|...)
        # For fully non-reparameterized terms, it is zero
        if not is_identically_zero(entropy_term):
            surrogate_elbo -= entropy_term.sum()
        # add the -log_prob to each non-reparam sample site upstream
        for upstream in get_provenance(log_q_sum):
            downstream_costs[upstream].add(
                (site["cond_indep_stack"], -site["log_prob"])
            )

    # construct all the reinforce-like terms.
    # we include only downstream costs to reduce variance
    # optionally include baselines to further reduce variance
    for node, cost_terms in downstream_costs.items():
        guide_site = guide_trace.nodes[node]
        downstream_cost = cost_terms.sum_to(guide_site["cond_indep_stack"])
        score_function = guide_site["score_parts"].score_function

        use_baseline, baseline_loss_term, baseline = _construct_baseline(
            node, guide_site, downstream_cost
        )

        if use_baseline:
            downstream_cost = downstream_cost - baseline
            baseline_loss = baseline_loss + baseline_loss_term

        # the cost acts as a fixed weight on the score function
        surrogate_elbo += (score_function * downstream_cost.detach()).sum()

    surrogate_loss = -surrogate_elbo + baseline_loss
    return detach_provenance(elbo), detach_provenance(surrogate_loss)


class TrackNonReparam(Messenger):
    """
    Track non-reparameterizable sample sites.

    **References:**

    1. *Nonstandard Interpretations of Probabilistic Programs for Efficient Inference*,
        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

    **Example:**

    .. doctest::

       >>> import torch
       >>> import pyro
       >>> import pyro.distributions as dist
       >>> from pyro.infer.tracegraph_elbo import TrackNonReparam
       >>> from pyro.ops.provenance import get_provenance
       >>> from pyro.poutine import trace

       >>> def model():
       ...     probs_a = torch.tensor([0.3, 0.7])
       ...     probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
       ...     probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
       ...     a = pyro.sample("a", dist.Categorical(probs_a))
       ...     b = pyro.sample("b", dist.Categorical(probs_b[a]))
       ...     pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

       >>> with TrackNonReparam():
       ...     model_tr = trace(model).get_trace()
       >>> model_tr.compute_log_prob()

       >>> print(get_provenance(model_tr.nodes["a"]["log_prob"]))  # doctest: +SKIP
       frozenset({'a'})
       >>> print(get_provenance(model_tr.nodes["b"]["log_prob"]))  # doctest: +SKIP
       frozenset({'b', 'a'})
       >>> print(get_provenance(model_tr.nodes["c"]["log_prob"]))  # doctest: +SKIP
       frozenset({'b', 'a'})
    """

    def _pyro_post_sample(self, msg):
        """
        Tag the sampled value of a non-reparameterizable site with its name.

        A site is tagged only when it is a ``"sample"`` message, its ``fn``
        is not a ``_Subsample``, it is not observed, and its ``fn`` does not
        report ``has_rsample``. The value in ``msg`` is replaced in place by
        the provenance-tracked value.

        :param dict msg: the sample message being post-processed.
        """
        if (
            msg["type"] == "sample"
            and not isinstance(msg["fn"], _Subsample)
            and not msg["is_observed"]
            and not getattr(msg["fn"], "has_rsample", False)
        ):
            # provenance is the singleton set holding this site's own name
            provenance = frozenset({msg["name"]})
            msg["value"] = track_provenance(msg["value"], provenance)
class TraceGraph_ELBO(ELBO):
    """
    A TraceGraph implementation of ELBO-based SVI. The gradient estimator
    is constructed along the lines of reference [1] specialized to the case
    of the ELBO. It supports arbitrary dependency structure for the model
    and guide as well as baselines for non-reparameterizable random variables.
    Fine-grained conditional dependency information as recorded in the
    :class:`~pyro.poutine.trace.Trace` is used to reduce the variance of the
    gradient estimator. In particular provenance tracking [3] is used to find
    the ``cost`` terms that depend on each non-reparameterizable sample site.

    References

    [1] `Gradient Estimation Using Stochastic Computation Graphs`,
        John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

    [2] `Neural Variational Inference and Learning in Belief Networks`
        Andriy Mnih, Karol Gregor

    [3] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
    """

    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        # run both programs under TrackNonReparam so that values sampled at
        # non-reparameterizable sites carry provenance
        with TrackNonReparam():
            model_trace, guide_trace = get_importance_trace(
                "dense", self.max_plate_nesting, model, guide, args, kwargs
            )
        if is_validation_enabled():
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace

    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the negative ELBO (the loss)
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many
        samples/particles and returns its negation.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(
                guide_trace.log_prob_sum()
            )
            elbo += elbo_particle / float(self.num_particles)

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: an estimate of the negative ELBO (the loss)
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form
        the gradient estimator. Performs backward on the latter.
        Num_particle many samples are used to form the estimators.
        If baselines are present, a baseline loss is also constructed and
        differentiated.
        """
        elbo, surrogate_loss = self._loss_and_surrogate_loss(model, guide, args, kwargs)

        torch_backward(surrogate_loss, retain_graph=self.retain_graph)

        elbo = torch_item(elbo)
        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss

    def _loss_and_surrogate_loss(self, model, guide, args, kwargs):
        """
        Average the per-particle ELBO and surrogate loss over all traces.

        :returns: tuple ``(elbo, surrogate_loss)``. Despite the method name,
            the first element is the ELBO itself; callers negate it to
            obtain the loss.
        """
        elbo = 0.0
        surrogate_loss = 0.0

        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle, surrogate_particle = self._loss_and_surrogate_loss_particle(
                model_trace, guide_trace
            )
            elbo += elbo_particle
            surrogate_loss += surrogate_particle

        elbo /= self.num_particles
        surrogate_loss /= self.num_particles

        return elbo, surrogate_loss

    def _loss_and_surrogate_loss_particle(self, model_trace, guide_trace):
        """
        Compute ``(elbo, surrogate_loss)`` for a single trace pair by
        delegating to :func:`_compute_elbo`.
        """
        elbo, surrogate_loss = _compute_elbo(model_trace, guide_trace)

        return elbo, surrogate_loss
class JitTraceGraph_ELBO(TraceGraph_ELBO):
    """
    Like :class:`TraceGraph_ELBO` but uses :func:`torch.jit.trace` to
    compile :meth:`loss_and_grads`.

    This works only for a limited set of models:

    -   Models must have static structure.
    -   Models must not depend on any global data (except the param store).
    -   All model inputs that are tensors must be passed in via ``*args``.
    -   All model inputs that are *not* tensors must be passed in via
        ``**kwargs``, and compilation will be triggered once per unique
        ``**kwargs``.
    """

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        Compute the loss and run backward on the surrogate loss through a
        jit-traced version of ``_loss_and_surrogate_loss``.

        The traced closure is built on first use and cached on the instance
        as ``_jit_loss_and_surrogate_loss``. It captures the ``model`` and
        ``guide`` of that first call.

        :returns: the negated ELBO estimate
        :rtype: float
        """
        # NOTE(review): the ids are injected as kwargs and popped again inside
        # the closure; presumably they make the jit wrapper treat a different
        # model/guide as a distinct kwargs signature -- confirm in pyro.ops.jit
        kwargs["_pyro_model_id"] = id(model)
        kwargs["_pyro_guide_id"] = id(guide)
        if getattr(self, "_jit_loss_and_surrogate_loss", None) is None:
            # build a closure for loss_and_surrogate_loss
            # (a weak reference, so the cached closure holds no strong
            # reference back to this instance)
            weakself = weakref.ref(self)

            @pyro.ops.jit.trace(
                ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options
            )
            def jit_loss_and_surrogate_loss(*args, **kwargs):
                kwargs.pop("_pyro_model_id")
                kwargs.pop("_pyro_guide_id")
                self = weakself()
                return self._loss_and_surrogate_loss(model, guide, args, kwargs)

            self._jit_loss_and_surrogate_loss = jit_loss_and_surrogate_loss

        elbo, surrogate_loss = self._jit_loss_and_surrogate_loss(*args, **kwargs)

        surrogate_loss.backward(
            retain_graph=self.retain_graph
        )  # triggers jit compilation

        loss = -elbo.item()
        warn_if_nan(loss, "loss")
        return loss