# Source code for pyro.infer.trace_elbo

from __future__ import absolute_import, division, print_function

import weakref

import pyro
import pyro.ops.jit
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import get_importance_trace
from pyro.infer.util import MultiFrameTensor, get_plate_stacks, is_validation_enabled, torch_item
from pyro.util import check_if_enumerated, warn_if_nan

def _compute_log_r(model_trace, guide_trace):
    """
    Accumulate per-site log weights ``log p - log q``, grouped by plate stack.

    For every ``"sample"`` site in ``model_trace`` the model's ``log_prob`` is
    taken. For non-observed (latent) sites, the guide's ``log_prob`` at the
    same site name is subtracted. Each resulting term is detached from the
    autograd graph. It is then added to a
    :class:`~pyro.infer.util.MultiFrameTensor` under that site's plate stack,
    as returned by :func:`~pyro.infer.util.get_plate_stacks`.

    :param model_trace: model trace whose ``"sample"`` sites are scanned; each
        such site must already carry a ``"log_prob"`` entry.
    :param guide_trace: guide trace; must contain a node (with ``"log_prob"``)
        for every non-observed sample site of ``model_trace``, otherwise a
        ``KeyError`` is raised.
    :returns: the accumulated, detached log weights.
    :rtype: MultiFrameTensor
    """
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            # NOTE: detached so that callers can use log_r as a constant
            # weight (presumably the score-function coefficient of the
            # surrogate ELBO) without gradients flowing through it.
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r


class Trace_ELBO(ELBO):
    """
    A trace implementation of ELBO-based SVI. The estimator is constructed
    along the lines of references [1] and [2]. There are no restrictions on the
    dependency structure of the model or the guide. The gradient estimator
    includes partial Rao-Blackwellization for reducing the variance of the
    estimator when non-reparameterizable random variables are present. The
    Rao-Blackwellization is partial in that it only uses conditional
    independence information that is marked by :class:`~pyro.plate` contexts.
    For more fine-grained Rao-Blackwellization, see
    :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`.

    References

    [1] Automated Variational Inference in Probabilistic Programming,
        David Wingate, Theo Weber

    [2] Black Box Variational Inference,
        Rajesh Ranganath, Sean Gerrish, David M. Blei
    """

    # NOTE(review): only ``_get_trace`` is present in this copy of the class,
    # while module imports such as ``torch_item``, ``warn_if_nan`` and
    # ``is_identically_zero`` are otherwise unused here. The loss methods
    # appear to have been lost when this file was extracted; confirm against
    # upstream pyro.infer.trace_elbo.

    def _get_trace(self, model, guide, *args, **kwargs):
        """
        Returns a single trace from the guide, and the model that is run
        against it.

        :param model: the model callable.
        :param guide: the guide callable.
        :param args: positional arguments forwarded to model and guide.
        :param kwargs: keyword arguments forwarded to model and guide.
        :returns: a ``(model_trace, guide_trace)`` pair, as produced by
            :func:`~pyro.infer.enum.get_importance_trace` in ``"flat"`` mode.
        """
        model_trace, guide_trace = get_importance_trace(
            "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
        if is_validation_enabled():
            # Only under validation: inspect the guide for sites configured
            # for enumeration (presumably unsupported by this flat estimator).
            check_if_enumerated(guide_trace)
        return model_trace, guide_trace


class JitTrace_ELBO(Trace_ELBO):
    """
    Like :class:`Trace_ELBO` but uses :func:`pyro.ops.jit.compile` to compile
    :meth:`loss_and_grads`.

    This works only for a limited set of models:

    -   Models must have static structure.
    -   Models must not depend on any global data (except the param store).
    -   All model inputs that are tensors must be passed in via ``*args``.
    -   All model inputs that are *not* tensors must be passed in via
        ``**kwargs``, and compilation will be triggered once per unique
        ``**kwargs``.
    """

    # NOTE(review): the docstring refers to a compiled ``loss_and_grads``, but
    # no methods are defined in this copy of the class, and the module's
    # ``weakref`` / ``pyro.ops.jit`` imports are unused here. The JIT
    # overrides appear to have been lost when this file was extracted; confirm
    # against upstream pyro.infer.trace_elbo.