Source code for pyro.contrib.funsor.infer.trace_elbo

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import contextlib

import funsor

from pyro.contrib.funsor import to_data, to_funsor
from pyro.contrib.funsor.handlers import enum, plate, replay, trace
from pyro.contrib.funsor.infer import config_enumerate
from pyro.distributions.util import copy_docs_from
from pyro.infer import Trace_ELBO as _OrigTrace_ELBO

from .elbo import ELBO, Jit_ELBO
from .traceenum_elbo import terms_from_trace


[docs]@copy_docs_from(_OrigTrace_ELBO) class Trace_ELBO(ELBO):
[docs] def differentiable_loss(self, model, guide, *args, **kwargs): with enum(), ( plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack() ): guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace( *args, **kwargs ) model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs) model_terms = terms_from_trace(model_tr) guide_terms = terms_from_trace(guide_tr) log_measures = guide_terms["log_measures"] + model_terms["log_measures"] log_factors = model_terms["log_factors"] + [ -f for f in guide_terms["log_factors"] ] plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"] measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"] elbo = funsor.Integrate( sum(log_measures, to_funsor(0.0)), sum(log_factors, to_funsor(0.0)), measure_vars, ) elbo = elbo.reduce(funsor.ops.add, plate_vars) return -to_data(elbo)
[docs]class JitTrace_ELBO(Jit_ELBO, Trace_ELBO): pass