# Source code for pyro.infer.trace_elbo

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import weakref

import pyro
import pyro.ops.jit
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import (
    MultiFrameTensor,
    get_plate_stacks,
    is_validation_enabled,
    torch_item,
)
from pyro.util import check_if_enumerated, warn_if_nan


def _compute_log_r(model_trace, guide_trace):
    # Accumulate the detached per-site learning signal log(p/q), keyed by each
    # site's plate stack; this is the "reward" used by the score-function
    # (REINFORCE) term in the surrogate loss below.
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r
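

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the detached log_r
# above serves as the "reward" in a score-function (REINFORCE) estimator.
# The hypothetical helper below demonstrates the underlying identity
#     grad E_q[f(z)] = E_q[f(z) * grad log q(z)]
# on a toy Normal q, mirroring how the surrogate loss multiplies a detached
# reward by a differentiable log-probability.
# ---------------------------------------------------------------------------
def _example_score_function_estimator(num_samples=10000):
    import torch

    loc = torch.tensor(0.5, requires_grad=True)
    q = torch.distributions.Normal(loc, 1.0)
    z = q.sample((num_samples,))  # sampling is non-differentiable
    f = (z ** 2).detach()  # "reward", detached like log_r above
    surrogate = (f * q.log_prob(z)).mean()  # score-function surrogate
    surrogate.backward()
    # loc.grad now estimates d/d(loc) E[z^2] = 2 * loc, i.e. roughly 1.0
    return loc.grad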


class Trace_ELBO(ELBO):
    """
    A trace implementation of ELBO-based SVI. The estimator is constructed
    along the lines of references [1] and [2]. There are no restrictions on the
    dependency structure of the model or the guide. The gradient estimator includes
    partial Rao-Blackwellization for reducing the variance of the estimator when
    non-reparameterizable random variables are present. The Rao-Blackwellization is
    partial in that it only uses conditional independence information that is marked
    by :class:`~pyro.plate` contexts. For more fine-grained Rao-Blackwellization,
    see :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`.

    References

    [1] Automated Variational Inference in Probabilistic Programming,
        David Wingate, Theo Weber

    [2] Black Box Variational Inference,
        Rajesh Ranganath, Sean Gerrish, David M. Blei
    """

    def _get_trace(self, model, guide, args, kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, args, kwargs
        )
        if is_validation_enabled():
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace

    def loss(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Evaluates the ELBO with an estimator that uses num_particles many
        samples/particles.
        """
        elbo = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            elbo_particle = torch_item(model_trace.log_prob_sum()) - torch_item(
                guide_trace.log_prob_sum()
            )
            elbo += elbo_particle / self.num_particles

        loss = -elbo
        warn_if_nan(loss, "loss")
        return loss

    def _differentiable_loss_particle(self, model_trace, guide_trace):
        elbo_particle = 0
        surrogate_elbo_particle = 0
        log_r = None

        # compute elbo and surrogate elbo
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample":
                elbo_particle = elbo_particle + torch_item(site["log_prob_sum"])
                surrogate_elbo_particle = surrogate_elbo_particle + site["log_prob_sum"]

        for name, site in guide_trace.nodes.items():
            if site["type"] == "sample":
                log_prob, score_function_term, entropy_term = site["score_parts"]
                elbo_particle = elbo_particle - torch_item(site["log_prob_sum"])

                if not is_identically_zero(entropy_term):
                    # differentiable log q contribution from (partially)
                    # reparameterized sites
                    surrogate_elbo_particle = (
                        surrogate_elbo_particle - entropy_term.sum()
                    )

                if not is_identically_zero(score_function_term):
                    # non-reparameterized sites contribute a REINFORCE term;
                    # summing the detached log_r only over this site's plate
                    # stack gives the partial Rao-Blackwellization
                    if log_r is None:
                        log_r = _compute_log_r(model_trace, guide_trace)
                    site = log_r.sum_to(site["cond_indep_stack"])
                    surrogate_elbo_particle = (
                        surrogate_elbo_particle + (site * score_function_term).sum()
                    )

        return -elbo_particle, -surrogate_elbo_particle

    def differentiable_loss(self, model, guide, *args, **kwargs):
        """
        Computes the surrogate loss that can be differentiated with autograd
        to produce gradient estimates for the model and guide parameters.
        """
        loss = 0.0
        surrogate_loss = 0.0
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
                model_trace, guide_trace
            )
            surrogate_loss += surrogate_loss_particle / self.num_particles
            loss += loss_particle / self.num_particles
        warn_if_nan(surrogate_loss, "loss")
        return loss + (surrogate_loss - torch_item(surrogate_loss))

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        :returns: returns an estimate of the ELBO
        :rtype: float

        Computes the ELBO as well as the surrogate ELBO that is used to form the
        gradient estimator. Performs backward on the latter. ``num_particles``
        many samples are used to form the estimators.
        """
        loss = 0.0
        # grab a trace from the generator
        for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
            loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
                model_trace, guide_trace
            )
            loss += loss_particle / self.num_particles

            # collect parameters to train from model and guide
            trainable_params = any(
                site["type"] == "param"
                for trace in (model_trace, guide_trace)
                for site in trace.nodes.values()
            )

            if trainable_params and getattr(
                surrogate_loss_particle, "requires_grad", False
            ):
                surrogate_loss_particle = surrogate_loss_particle / self.num_particles
                surrogate_loss_particle.backward(retain_graph=self.retain_graph)
        warn_if_nan(loss, "loss")
        return loss
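

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal SVI
# loop with Trace_ELBO on a toy Normal-Normal model. The model, guide, and
# parameter names below are hypothetical.
# ---------------------------------------------------------------------------
def _example_trace_elbo_usage(num_steps=100):
    import torch
    import pyro.distributions as dist
    from pyro.infer import SVI
    from pyro.optim import Adam

    data = torch.tensor([0.5, 1.2, 0.9])

    def model(data):
        # latent mean with a standard normal prior
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        # mean-field normal posterior over the latent mean
        q_loc = pyro.param("q_loc", torch.tensor(0.0))
        q_scale = pyro.param(
            "q_scale", torch.tensor(1.0), constraint=dist.constraints.positive
        )
        pyro.sample("loc", dist.Normal(q_loc, q_scale))

    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO(num_particles=4))
    for _ in range(num_steps):
        loss = svi.step(data)  # Monte Carlo estimate of -ELBO
    return loss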


class JitTrace_ELBO(Trace_ELBO):
    """
    Like :class:`Trace_ELBO` but uses :func:`pyro.ops.jit.compile` to
    compile :meth:`loss_and_grads`.

    This works only for a limited set of models:

    -   Models must have static structure.
    -   Models must not depend on any global data (except the param store).
    -   All model inputs that are tensors must be passed in via ``*args``.
    -   All model inputs that are *not* tensors must be passed in via
        ``**kwargs``, and compilation will be triggered once per unique
        ``**kwargs``.
    """

    def loss_and_surrogate_loss(self, model, guide, *args, **kwargs):
        kwargs["_pyro_model_id"] = id(model)
        kwargs["_pyro_guide_id"] = id(guide)
        if getattr(self, "_loss_and_surrogate_loss", None) is None:
            # build a closure for loss_and_surrogate_loss
            weakself = weakref.ref(self)

            @pyro.ops.jit.trace(
                ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options
            )
            def loss_and_surrogate_loss(*args, **kwargs):
                kwargs.pop("_pyro_model_id")
                kwargs.pop("_pyro_guide_id")
                self = weakself()
                loss = 0.0
                surrogate_loss = 0.0
                # The loop below inlines _differentiable_loss_particle, keeping
                # both the loss and the surrogate loss as tensors (no torch_item
                # calls) so the whole computation can be captured by the jit
                # tracer.
                for model_trace, guide_trace in self._get_traces(
                    model, guide, args, kwargs
                ):
                    elbo_particle = 0
                    surrogate_elbo_particle = 0
                    log_r = None

                    # compute elbo and surrogate elbo
                    for name, site in model_trace.nodes.items():
                        if site["type"] == "sample":
                            elbo_particle = elbo_particle + site["log_prob_sum"]
                            surrogate_elbo_particle = (
                                surrogate_elbo_particle + site["log_prob_sum"]
                            )

                    for name, site in guide_trace.nodes.items():
                        if site["type"] == "sample":
                            log_prob, score_function_term, entropy_term = site[
                                "score_parts"
                            ]

                            elbo_particle = elbo_particle - site["log_prob_sum"]

                            if not is_identically_zero(entropy_term):
                                surrogate_elbo_particle = (
                                    surrogate_elbo_particle - entropy_term.sum()
                                )

                            if not is_identically_zero(score_function_term):
                                if log_r is None:
                                    log_r = _compute_log_r(model_trace, guide_trace)
                                site = log_r.sum_to(site["cond_indep_stack"])
                                surrogate_elbo_particle = (
                                    surrogate_elbo_particle
                                    + (site * score_function_term).sum()
                                )

                    loss = loss - elbo_particle / self.num_particles
                    surrogate_loss = (
                        surrogate_loss - surrogate_elbo_particle / self.num_particles
                    )

                return loss, surrogate_loss

            self._loss_and_surrogate_loss = loss_and_surrogate_loss

        return self._loss_and_surrogate_loss(*args, **kwargs)

    def differentiable_loss(self, model, guide, *args, **kwargs):
        loss, surrogate_loss = self.loss_and_surrogate_loss(
            model, guide, *args, **kwargs
        )
        warn_if_nan(loss, "loss")
        return loss + (surrogate_loss - surrogate_loss.detach())

    def loss_and_grads(self, model, guide, *args, **kwargs):
        loss, surrogate_loss = self.loss_and_surrogate_loss(
            model, guide, *args, **kwargs
        )
        surrogate_loss.backward()
        loss = loss.item()
        warn_if_nan(loss, "loss")
        return loss
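

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): differentiable_loss
# returns a tensor, so gradients can be driven by a plain torch optimizer
# instead of pyro.optim. The model and guide are assumed to be defined as in
# the hypothetical sketch following Trace_ELBO above.
# ---------------------------------------------------------------------------
def _example_differentiable_loss(model, guide, data, num_steps=100):
    import torch

    elbo = Trace_ELBO(num_particles=4)
    # one throwaway evaluation registers the guide's params in the param store
    elbo.loss(model, guide, data)
    params = [p.unconstrained() for p in pyro.get_param_store().values()]
    optimizer = torch.optim.Adam(params, lr=0.01)
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = elbo.differentiable_loss(model, guide, data)
        loss.backward()
        optimizer.step()
    return torch_item(loss)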