# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import queue
import warnings

import torch

import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import (
    get_importance_trace,
    iter_discrete_escape,
    iter_discrete_extend,
)
from pyro.infer.util import compute_site_dice_factor, is_validation_enabled, torch_item
from pyro.ops import packed
from pyro.ops.contract import einsum
from pyro.poutine.enum_messenger import EnumMessenger
from pyro.util import check_traceenum_requirements, warn_if_nan


def _compute_dice_factors(model_trace, guide_trace):
"""
    Compute per-site DiCE log-factors for non-reparameterized proposal sites.
    This logic is adapted from ``pyro.infer.util.Dice.__init__``.
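
    For a multiply-sampled site, the returned factor is (roughly)
    ``log q - detach(log q) - log(num_samples)``: its exponential averages the
    parallel samples in value while still propagating score-function
    (DiCE-style) gradients through ``log q``.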
"""
log_probs = []
for role, trace in zip(("model", "guide"), (model_trace, guide_trace)):
for name, site in trace.nodes.items():
if site["type"] != "sample" or site["is_observed"]:
continue
if role == "model" and name in guide_trace:
continue
log_prob, log_denom = compute_site_dice_factor(site)
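            # arithmetic on packed tensors drops their ``_pyro_dims``
            # metadata, so save and restore it around the subtraction below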
if not is_identically_zero(log_denom):
dims = log_prob._pyro_dims
log_prob = log_prob - log_denom
log_prob._pyro_dims = dims
if not is_identically_zero(log_prob):
log_probs.append(log_prob)
return log_probs


def _compute_tmc_factors(model_trace, guide_trace):
"""
    Compute per-site log-factors for all observed and unobserved variables.
    Log-factors are log(p / q) for unobserved sites and log(p) for observed sites.
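
    Guide sites contribute ``-log q``; every model sample site contributes
    ``+log p``; a site sampled in parallel from its prior (with no matching
    guide site) doubles as its own proposal and therefore also contributes
    ``-log q`` with ``q = p``.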
"""
log_factors = []
for name, site in guide_trace.nodes.items():
if site["type"] != "sample" or site["is_observed"]:
continue
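        # each unobserved guide site contributes -log q to the weight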
log_proposal = site["packed"]["log_prob"]
log_factors.append(packed.neg(log_proposal))
for name, site in model_trace.nodes.items():
if site["type"] != "sample":
continue
        if (
            name not in guide_trace
            and not site["is_observed"]
            and site["infer"].get("enumerate", None) == "parallel"
            and site["infer"].get("num_samples", -1) > 0
        ):
            # the site was sampled in parallel from its prior, which doubles
            # as the proposal, so subtract its log-density to form log(p / q)
            log_factors.append(packed.neg(site["packed"]["log_prob"]))
        log_factors.append(site["packed"]["log_prob"])
return log_factors


def _compute_tmc_estimate(model_trace, guide_trace):
"""
Use :func:`~pyro.ops.contract.einsum` to compute the Tensor Monte Carlo
estimate of the marginal likelihood given parallel-sampled traces.
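
    The factors are packed tensors whose ``_pyro_dims`` attribute names their
    dimensions; the contraction sums out enumeration/sample dimensions in log
    space, while dimensions named in ``plates`` are treated as independent
    batch (product) dimensions.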
"""
# factors
log_factors = _compute_tmc_factors(model_trace, guide_trace)
log_factors += _compute_dice_factors(model_trace, guide_trace)
if not log_factors:
return 0.0
# loss
eqn = ",".join([f._pyro_dims for f in log_factors]) + "->"
plates = "".join(
frozenset().union(
list(model_trace.plate_to_symbol.values()),
list(guide_trace.plate_to_symbol.values()),
)
)
(tmc,) = einsum(
eqn,
*log_factors,
plates=plates,
backend="pyro.ops.einsum.torch_log",
modulo_total=False
)
return tmc


class TraceTMC_ELBO(ELBO):
    """
    A trace-based implementation of Tensor Monte Carlo [1]
    by way of Tensor Variable Elimination [2] that supports:

    - local parallel sampling over any sample site in the model or guide
    - exhaustive enumeration over any sample site in the model or guide

    To take multiple samples, mark the site with
    ``infer={'enumerate': 'parallel', 'num_samples': N}``.
    To configure all sites in a model or guide at once,
    use :func:`~pyro.infer.enum.config_enumerate`.
    To enumerate or sample a sample site in the ``model``,
    mark the site and ensure the site does not appear in the ``guide``.

    This assumes restricted dependency structure on the model and guide:
    variables outside of an :class:`~pyro.plate` can never depend on
    variables inside that :class:`~pyro.plate`.

    References

    [1] `Tensor Monte Carlo: Particle Methods for the GPU Era`,
        Laurence Aitchison (2018)

    [2] `Tensor Variable Elimination for Plated Factor Graphs`,
        Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
        Neeraj Pradhan, Alexander Rush, Noah Goodman (2019)
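
    Example, a minimal usage sketch (the model, guide, and data below are
    hypothetical, not part of this module)::

        import torch
        import pyro
        import pyro.distributions as dist
        from pyro.infer import TraceTMC_ELBO

        def model(data):
            # draw 100 parallel samples of z from its prior
            z = pyro.sample(
                "z",
                dist.Normal(0.0, 1.0),
                infer={"enumerate": "parallel", "num_samples": 100},
            )
            pyro.sample("obs", dist.Normal(z, 1.0), obs=data)

        def guide(data):
            pass  # z is handled by TMC; a real guide could sample it instead

        elbo = TraceTMC_ELBO(max_plate_nesting=0)
        loss = elbo.differentiable_loss(model, guide, torch.tensor(0.5))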
"""

    def _get_trace(self, model, guide, args, kwargs):
"""
        Returns a single (model_trace, guide_trace) pair: one trace from the
        guide, and one from the model run against it.
"""
model_trace, guide_trace = get_importance_trace(
"flat", self.max_plate_nesting, model, guide, args, kwargs
)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)
has_enumerated_sites = any(
site["infer"].get("enumerate")
for trace in (guide_trace, model_trace)
for name, site in trace.nodes.items()
if site["type"] == "sample"
)
if self.strict_enumeration_warning and not has_enumerated_sites:
warnings.warn(
"Found no sample sites configured for enumeration. "
"If you want to enumerate sites, you need to @config_enumerate or set "
'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
"If you do not want to enumerate, consider using Trace_ELBO instead."
)
model_trace.compute_score_parts()
guide_trace.pack_tensors()
model_trace.pack_tensors(guide_trace.plate_to_symbol)
return model_trace, guide_trace

    def _get_traces(self, model, guide, args, kwargs):
"""
        Runs the guide and runs the model against the guide, packaging
        the resulting trace pairs as a generator.
"""
if self.max_plate_nesting == float("inf"):
self._guess_max_plate_nesting(model, guide, args, kwargs)
if self.vectorize_particles:
guide = self._vectorized_num_particles(guide)
model = self._vectorized_num_particles(model)
# Enable parallel enumeration over the vectorized guide and model.
# The model allocates enumeration dimensions after (to the left of) the guide,
# accomplished by preserving the _ENUM_ALLOCATOR state after the guide call.
guide_enum = EnumMessenger(first_available_dim=-1 - self.max_plate_nesting)
model_enum = EnumMessenger() # preserve _ENUM_ALLOCATOR state
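        # e.g. (hypothetical dims): with max_plate_nesting=2, plates occupy
        # dims -1 and -2, guide enum dims are allocated from -3 leftward, and
        # model enum dims continue to the left of the guide's allocations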
guide = guide_enum(guide)
model = model_enum(model)
q = queue.LifoQueue()
guide = poutine.queue(
guide, q, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend
)
for i in range(1 if self.vectorize_particles else self.num_particles):
q.put(poutine.Trace())
while not q.empty():
yield self._get_trace(model, guide, args, kwargs)

    def differentiable_loss(self, model, guide, *args, **kwargs):
"""
        :returns: a differentiable estimate of the negative marginal
            log-likelihood (a loss to be minimized)
        :rtype: torch.Tensor
        :raises ValueError: if the ELBO is not differentiable (e.g. is
            identically zero)

        Computes a differentiable TMC estimate using ``num_particles`` many samples
(particles). The result should be infinitely differentiable (as long
as underlying derivatives have been implemented).
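
        Example (a sketch, reusing the hypothetical ``model`` and ``guide``
        from the class docstring)::

            elbo = TraceTMC_ELBO(max_plate_nesting=0)
            loss = elbo.differentiable_loss(model, guide, torch.tensor(0.5))
            # ``loss`` is a torch.Tensor; minimize it with any torch optimizer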
"""
elbo = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = _compute_tmc_estimate(model_trace, guide_trace)
if is_identically_zero(elbo_particle):
continue
elbo = elbo + elbo_particle
elbo = elbo / self.num_particles
loss = -elbo
warn_if_nan(loss, "loss")
return loss

    def loss(self, model, guide, *args, **kwargs):
        """
        Evaluates :meth:`differentiable_loss` without tracking gradients and
        returns the result as a Python float.
        """
        with torch.no_grad():
loss = self.differentiable_loss(model, guide, *args, **kwargs)
if is_identically_zero(loss) or not loss.requires_grad:
return torch_item(loss)
return loss.item()

    def loss_and_grads(self, model, guide, *args, **kwargs):
        """
        Computes the TMC loss, backpropagates to accumulate gradients into
        any learnable parameters, and returns the loss as a Python float.
        """
        loss = self.differentiable_loss(model, guide, *args, **kwargs)
if is_identically_zero(loss) or not loss.requires_grad:
return torch_item(loss)
loss.backward()
return loss.item()