# Source code for pyro.contrib.funsor.infer.elbo

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import weakref

import pyro.ops.jit
from pyro.infer import ELBO as _OrigELBO
from pyro.util import ignore_jit_warnings


class ELBO(_OrigELBO):
    """ELBO base class whose public losses are all derived from one method.

    Subclasses must override :meth:`differentiable_loss`. :meth:`loss` and
    :meth:`loss_and_grads` are thin wrappers around it.
    """

    def _get_trace(self, *args, **kwargs):
        # This override is never expected to be reached; raising makes an
        # unexpected call from the parent-class machinery fail loudly.
        raise ValueError("shouldn't be here!")

    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return a differentiable loss for ``model``/``guide``.

        Abstract here; concrete subclasses must provide it.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Must implement differentiable_loss")

    def loss(self, model, guide, *args, **kwargs):
        """Return the loss as a plain Python number.

        The differentiable loss is detached from the autograd graph before
        its value is read with ``.item()``.
        """
        differentiable = self.differentiable_loss(model, guide, *args, **kwargs)
        detached = differentiable.detach()
        return detached.item()

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """Backpropagate through the differentiable loss and return its value.

        ``.backward()`` is called on the result of
        :meth:`differentiable_loss`. Its scalar value is then returned via
        ``.item()``. This presumes a scalar tensor that requires grad
        (TODO confirm for all subclasses).
        """
        surrogate = self.differentiable_loss(model, guide, *args, **kwargs)
        surrogate.backward()
        return surrogate.item()
class Jit_ELBO(ELBO):
    """Mixin that compiles :meth:`differentiable_loss` with ``pyro.ops.jit.trace``.

    It must precede a concrete ELBO in the MRO. The parent's
    ``differentiable_loss`` is reached through ``super(Jit_ELBO, ...)``.
    """

    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the parent's differentiable loss, evaluated through a jit trace.

        The traced function is built lazily on the first call and reused
        afterwards. The ids of ``model`` and ``guide`` are passed as extra
        kwargs so that the jit cache compiles a separate trace per
        (model, guide) pair. Inside the traced function those ids are
        resolved back to the actual objects, so each pair is traced with
        its own model and guide rather than the pair from the first call.
        """
        model_id, guide_id = id(model), id(guide)
        kwargs["_model_id"] = model_id
        kwargs["_guide_id"] = guide_id
        if getattr(self, "_differentiable_loss", None) is None:
            # Registry mapping (model_id, guide_id) -> (model, guide). Holding
            # strong references keeps these ids from being reused by new
            # objects while they serve as jit-cache keys.
            self._jit_models_and_guides = {}
            # A weak reference avoids a self -> closure -> self cycle.
            weakself = weakref.ref(self)

            # self.ignore_jit_warnings / self.jit_options are presumably set by
            # the pyro ELBO constructor (not visible here).
            @pyro.ops.jit.trace(
                ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options
            )
            def differentiable_loss(*args, **kwargs):
                # Strip the cache-key kwargs before they reach the model.
                key = (kwargs.pop("_model_id"), kwargs.pop("_guide_id"))
                elbo = weakself()
                traced_model, traced_guide = elbo._jit_models_and_guides[key]
                with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
                    return super(Jit_ELBO, elbo).differentiable_loss(
                        traced_model, traced_guide, *args, **kwargs
                    )

            self._differentiable_loss = differentiable_loss
        # Register before calling, so a first-time trace for this pair can
        # look it up.
        self._jit_models_and_guides[(model_id, guide_id)] = (model, guide)
        return self._differentiable_loss(*args, **kwargs)