# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import queue
import warnings
import weakref
from collections import OrderedDict
import torch
from opt_einsum import shared_intermediates
import pyro
import pyro.distributions as dist
import pyro.ops.jit
import pyro.poutine as poutine
from pyro.distributions.util import is_identically_zero
from pyro.infer.elbo import ELBO
from pyro.infer.enum import (
get_importance_trace,
iter_discrete_escape,
iter_discrete_extend,
)
from pyro.infer.util import Dice, is_validation_enabled
from pyro.ops import packed
from pyro.ops.contract import contract_tensor_tree, contract_to_tensor
from pyro.ops.rings import SampleRing
from pyro.poutine.enum_messenger import EnumMessenger
from pyro.util import check_traceenum_requirements, ignore_jit_warnings, warn_if_nan


@ignore_jit_warnings()
def _get_common_scale(scales):
# Check that all enumerated sites share a common subsampling scale.
# Note that we use a cheap weak comparison by id rather than tensor value, because
# (1) it is expensive to compare tensors by value, and (2) tensors must agree not
# only in value but at all derivatives.
scales_set = set()
for scale in scales:
if isinstance(scale, torch.Tensor) and scale.dim():
raise ValueError("enumeration only supports scalar poutine.scale")
scales_set.add(float(scale))
if len(scales_set) != 1:
raise ValueError(
"Expected all enumerated sample sites to share a common poutine.scale, "
"but found {} different scales.".format(len(scales_set))
)
return scales[0]


def _check_model_guide_enumeration_constraint(model_enum_sites, guide_trace):
min_ordinal = frozenset.intersection(*model_enum_sites.keys())
for name, site in guide_trace.nodes.items():
if site["type"] == "sample" and site["infer"].get("_enumerate_dim") is not None:
for f in site["cond_indep_stack"]:
if (
f.vectorized
and guide_trace.plate_to_symbol[f.name] not in min_ordinal
):
raise ValueError(
"Expected model enumeration to be no more global than guide enumeration, "
"but found model enumeration sites upstream of guide site '{}' in plate('{}'). "
"Try converting some model enumeration sites to guide enumeration sites.".format(
name, f.name
)
)


def _check_tmc_elbo_constraint(model_trace, guide_trace):
num_samples = frozenset(
site["infer"].get("num_samples")
for site in guide_trace.nodes.values()
if site["type"] == "sample"
and site["infer"].get("enumerate") == "parallel"
and site["infer"].get("num_samples") is not None
)
if len(num_samples) > 1:
warnings.warn(
"\n".join(
[
"Using different numbers of Monte Carlo samples for different guide sites in TraceEnum_ELBO.",
"This may be biased if the guide is not factorized",
]
),
UserWarning,
)
for name, site in model_trace.nodes.items():
if (
site["type"] == "sample"
and site["infer"].get("enumerate", None) == "parallel"
and site["infer"].get("num_samples", None)
and name not in guide_trace
):
warnings.warn(
"\n".join(
[
"Site {} is multiply sampled in model,".format(site["name"]),
"expect incorrect gradient estimates from TraceEnum_ELBO.",
"Consider using exact enumeration or guide sampling if possible.",
]
),
RuntimeWarning,
)


def _find_ordinal(trace, site):
return frozenset(
trace.plate_to_symbol[f.name] for f in site["cond_indep_stack"] if f.vectorized
)


# TODO move this logic into a poutine
def _compute_model_factors(model_trace, guide_trace):
# y depends on x iff ordering[x] <= ordering[y]
# TODO refine this coarse dependency ordering using time.
ordering = {
name: _find_ordinal(trace, site)
for trace in (model_trace, guide_trace)
for name, site in trace.nodes.items()
if site["type"] == "sample"
}
# Collect model sites that may have been enumerated in the model.
cost_sites = OrderedDict()
enum_sites = OrderedDict()
enum_dims = set()
non_enum_dims = set().union(*ordering.values())
for name, site in model_trace.nodes.items():
if site["type"] == "sample":
if name in guide_trace.nodes:
cost_sites.setdefault(ordering[name], []).append(site)
non_enum_dims.update(
guide_trace.nodes[name]["packed"]["log_prob"]._pyro_dims
)
elif site["infer"].get("_enumerate_dim") is None:
cost_sites.setdefault(ordering[name], []).append(site)
else:
enum_sites.setdefault(ordering[name], []).append(site)
enum_dims.update(site["packed"]["log_prob"]._pyro_dims)
enum_dims -= non_enum_dims
log_factors = OrderedDict()
scale = 1
if not enum_sites:
marginal_costs = OrderedDict(
(t, [site["packed"]["log_prob"] for site in sites_t])
for t, sites_t in cost_sites.items()
)
return marginal_costs, log_factors, ordering, enum_dims, scale
_check_model_guide_enumeration_constraint(enum_sites, guide_trace)
# Marginalize out all variables that have been enumerated in the model.
marginal_costs = OrderedDict()
scales = []
for t, sites_t in cost_sites.items():
for site in sites_t:
if enum_dims.isdisjoint(site["packed"]["log_prob"]._pyro_dims):
# For sites that do not depend on an enumerated variable, proceed as usual.
marginal_costs.setdefault(t, []).append(site["packed"]["log_prob"])
else:
                # For sites that depend on an enumerated variable, we need to apply
                # the mask inside the log expectation and the scale outside of it.
if "masked_log_prob" not in site["packed"]:
site["packed"]["masked_log_prob"] = packed.scale_and_mask(
site["packed"]["unscaled_log_prob"], mask=site["packed"]["mask"]
)
cost = site["packed"]["masked_log_prob"]
log_factors.setdefault(t, []).append(cost)
scales.append(site["scale"])
for t, sites_t in enum_sites.items():
# TODO refine this coarse dependency ordering using time and tensor shapes.
for site in sites_t:
logprob = site["packed"]["unscaled_log_prob"]
log_factors.setdefault(t, []).append(logprob)
scales.append(site["scale"])
scale = _get_common_scale(scales)
return marginal_costs, log_factors, ordering, enum_dims, scale


def _compute_dice_elbo(model_trace, guide_trace):
# Accumulate marginal model costs.
marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
model_trace, guide_trace
)
if log_factors:
dim_to_size = {}
for terms in log_factors.values():
for term in terms:
dim_to_size.update(zip(term._pyro_dims, term.shape))
# Note that while most applications of tensor message passing use the
# contract_to_tensor() interface and can be easily refactored to use ubersum(),
# the application here relies on contract_tensor_tree() to extract the dependency
# structure of different log_prob terms, which is used by Dice to eliminate
# zero-expectation terms. One possible refactoring would be to replace
# contract_to_tensor() with a RaggedTensor -> Tensor contraction operation, but
# replace contract_tensor_tree() with a RaggedTensor -> RaggedTensor contraction
# that preserves some dependency structure.
with shared_intermediates() as cache:
ring = SampleRing(cache=cache, dim_to_size=dim_to_size)
log_factors = contract_tensor_tree(log_factors, sum_dims, ring=ring)
model_trace._sharing_cache = cache # For TraceEnumSample_ELBO.
for t, log_factors_t in log_factors.items():
marginal_costs_t = marginal_costs.setdefault(t, [])
for term in log_factors_t:
term = packed.scale_and_mask(term, scale=scale)
marginal_costs_t.append(term)
costs = marginal_costs
# Accumulate negative guide costs.
for name, site in guide_trace.nodes.items():
if site["type"] == "sample":
cost = packed.neg(site["packed"]["log_prob"])
costs.setdefault(ordering[name], []).append(cost)
return Dice(guide_trace, ordering).compute_expectation(costs)


def _make_dist(dist_, logits):
    # Reshape for Bernoulli vs Categorical, OneHotCategorical, etc.
if isinstance(dist_, dist.Bernoulli):
logits = logits[..., 1] - logits[..., 0]
return type(dist_)(logits=logits)


def _compute_marginals(model_trace, guide_trace):
args = _compute_model_factors(model_trace, guide_trace)
marginal_costs, log_factors, ordering, sum_dims, scale = args
marginal_dists = OrderedDict()
with shared_intermediates() as cache:
for name, site in model_trace.nodes.items():
if (
site["type"] != "sample"
or name in guide_trace.nodes
or site["infer"].get("_enumerate_dim") is None
):
continue
enum_dim = site["infer"]["_enumerate_dim"]
enum_symbol = site["infer"]["_enumerate_symbol"]
ordinal = _find_ordinal(model_trace, site)
logits = contract_to_tensor(
log_factors,
sum_dims,
target_ordinal=ordinal,
target_dims={enum_symbol},
cache=cache,
)
logits = packed.unpack(logits, model_trace.symbol_to_dim)
logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
while logits.shape[0] == 1:
logits = logits.squeeze(0)
marginal_dists[name] = _make_dist(site["fn"], logits)
return marginal_dists


class BackwardSampleMessenger(pyro.poutine.messenger.Messenger):
"""
Implements forward filtering / backward sampling for sampling
    from the joint posterior distribution.
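
    This messenger is used internally by :meth:`TraceEnum_ELBO.sample_posterior`.
    The sketch below mirrors that internal call site, where ``model_trace`` and
    ``guide_trace`` are the packed, enumerated traces produced by
    ``TraceEnum_ELBO._get_trace``::

        with BackwardSampleMessenger(model_trace, guide_trace):
            return poutine.replay(model, trace=guide_trace)(*args, **kwargs)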
"""

    def __init__(self, enum_trace, guide_trace):
self.enum_trace = enum_trace
args = _compute_model_factors(enum_trace, guide_trace)
self.log_factors = args[1]
self.sum_dims = args[3]

    def __enter__(self):
self.cache = {}
return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
assert not self.sum_dims, self.sum_dims
return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
enum_msg = self.enum_trace.nodes.get(msg["name"])
if enum_msg is None:
return
enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
if enum_symbol is None:
return
enum_dim = enum_msg["infer"]["_enumerate_dim"]
with shared_intermediates(self.cache):
ordinal = _find_ordinal(self.enum_trace, msg)
logits = contract_to_tensor(
self.log_factors,
self.sum_dims,
target_ordinal=ordinal,
target_dims={enum_symbol},
cache=self.cache,
)
logits = packed.unpack(logits, self.enum_trace.symbol_to_dim)
logits = logits.unsqueeze(-1).transpose(-1, enum_dim - 1)
while logits.shape[0] == 1:
logits = logits.squeeze(0)
msg["fn"] = _make_dist(msg["fn"], logits)

    def _pyro_post_sample(self, msg):
enum_msg = self.enum_trace.nodes.get(msg["name"])
if enum_msg is None:
return
enum_symbol = enum_msg["infer"].get("_enumerate_symbol")
if enum_symbol is None:
return
value = packed.pack(msg["value"].long(), enum_msg["infer"]["_dim_to_symbol"])
assert enum_symbol not in value._pyro_dims
for t, terms in self.log_factors.items():
for i, term in enumerate(terms):
if enum_symbol in term._pyro_dims:
terms[i] = packed.gather(term, value, enum_symbol)
self.sum_dims.remove(enum_symbol)


class TraceEnum_ELBO(ELBO):
"""
A trace implementation of ELBO-based SVI that supports
- exhaustive enumeration over discrete sample sites, and
- local parallel sampling over any sample site in the guide.

    To enumerate over a sample site in the ``guide``, mark the site with either
``infer={'enumerate': 'sequential'}`` or
``infer={'enumerate': 'parallel'}``. To configure all guide sites at once,
use :func:`~pyro.infer.enum.config_enumerate`. To enumerate over a sample
site in the ``model``, mark the site ``infer={'enumerate': 'parallel'}``
and ensure the site does not appear in the ``guide``.

    This assumes restricted dependency structure on the model and guide:
variables outside of an :class:`~pyro.plate` can never depend on
variables inside that :class:`~pyro.plate`.
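
    As a rough usage sketch (the model, guide, and data below are illustrative
    placeholders rather than part of this class), a discrete site can be
    enumerated in the model as follows, assuming the usual ``import pyro``,
    ``import torch``, and ``import pyro.distributions as dist``::

        def model(data):
            probs = pyro.param("probs", torch.tensor([0.3, 0.7]))
            locs = pyro.param("locs", torch.tensor([-1.0, 1.0]))
            with pyro.plate("data", len(data)):
                # "z" is enumerated in the model, so it must not appear in the guide.
                z = pyro.sample(
                    "z", dist.Categorical(probs), infer={"enumerate": "parallel"}
                )
                pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

        def guide(data):
            pass  # no non-enumerated latent sites in this toy model

        elbo = TraceEnum_ELBO(max_plate_nesting=1)
        loss = elbo.loss(model, guide, torch.randn(10))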
"""

    def _get_trace(self, model, guide, args, kwargs):
"""
        Returns a single trace from the guide, and a trace of the model run
        against it.
"""
model_trace, guide_trace = get_importance_trace(
"flat", self.max_plate_nesting, model, guide, args, kwargs
)
if is_validation_enabled():
check_traceenum_requirements(model_trace, guide_trace)
_check_tmc_elbo_constraint(model_trace, guide_trace)
has_enumerated_sites = any(
site["infer"].get("enumerate")
for trace in (guide_trace, model_trace)
for name, site in trace.nodes.items()
if site["type"] == "sample"
)
if self.strict_enumeration_warning and not has_enumerated_sites:
warnings.warn(
"TraceEnum_ELBO found no sample sites configured for enumeration. "
"If you want to enumerate sites, you need to @config_enumerate or set "
                    'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}. '
"If you do not want to enumerate, consider using Trace_ELBO instead."
)
guide_trace.pack_tensors()
model_trace.pack_tensors(guide_trace.plate_to_symbol)
return model_trace, guide_trace

    def _get_traces(self, model, guide, args, kwargs):
"""
        Runs the guide, then runs the model against the guide, yielding the
        resulting (model_trace, guide_trace) pairs as a generator.
"""
if isinstance(poutine.unwrap(guide), poutine.messenger.Messenger):
raise NotImplementedError("TraceEnum_ELBO does not support GuideMessenger")
if self.max_plate_nesting == float("inf"):
self._guess_max_plate_nesting(model, guide, args, kwargs)
if self.vectorize_particles:
guide = self._vectorized_num_particles(guide)
model = self._vectorized_num_particles(model)
# Enable parallel enumeration over the vectorized guide and model.
# The model allocates enumeration dimensions after (to the left of) the guide,
# accomplished by preserving the _ENUM_ALLOCATOR state after the guide call.
guide_enum = EnumMessenger(first_available_dim=-1 - self.max_plate_nesting)
model_enum = EnumMessenger() # preserve _ENUM_ALLOCATOR state
guide = guide_enum(guide)
model = model_enum(model)
q = queue.LifoQueue()
guide = poutine.queue(
guide, q, escape_fn=iter_discrete_escape, extend_fn=iter_discrete_extend
)
for i in range(1 if self.vectorize_particles else self.num_particles):
q.put(poutine.Trace())
while not q.empty():
yield self._get_trace(model, guide, args, kwargs)

    def loss(self, model, guide, *args, **kwargs):
"""
:returns: an estimate of the ELBO
:rtype: float
Estimates the ELBO using ``num_particles`` many samples (particles).
"""
elbo = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
if is_identically_zero(elbo_particle):
continue
elbo += elbo_particle.item() / self.num_particles
loss = -elbo
warn_if_nan(loss, "loss")
return loss

    def differentiable_loss(self, model, guide, *args, **kwargs):
"""
:returns: a differentiable estimate of the ELBO
:rtype: torch.Tensor
:raises ValueError: if the ELBO is not differentiable (e.g. is
identically zero)
Estimates a differentiable ELBO using ``num_particles`` many samples
(particles). The result should be infinitely differentiable (as long
as underlying derivatives have been implemented).
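
        A rough usage sketch (``model``, ``guide``, and ``data`` are hypothetical
        placeholders)::

            elbo = TraceEnum_ELBO(max_plate_nesting=1)
            loss = elbo.differentiable_loss(model, guide, data)
            loss.backward()  # gradients accumulate in parameters registered via pyro.param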
"""
elbo = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
if is_identically_zero(elbo_particle):
continue
elbo = elbo + elbo_particle
elbo = elbo / self.num_particles
if not torch.is_tensor(elbo) or not elbo.requires_grad:
            raise ValueError("ELBO cannot be differentiated: {}".format(elbo))
loss = -elbo
warn_if_nan(loss, "loss")
return loss

    def loss_and_grads(self, model, guide, *args, **kwargs):
"""
:returns: an estimate of the ELBO
:rtype: float
Estimates the ELBO using ``num_particles`` many samples (particles).
Performs backward on the ELBO of each particle.
"""
elbo = 0.0
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
if is_identically_zero(elbo_particle):
continue
elbo += elbo_particle.item() / self.num_particles
# collect parameters to train from model and guide
trainable_params = any(
site["type"] == "param"
for trace in (model_trace, guide_trace)
for site in trace.nodes.values()
)
if trainable_params and elbo_particle.requires_grad:
loss_particle = -elbo_particle
(loss_particle / self.num_particles).backward(retain_graph=True)
loss = -elbo
warn_if_nan(loss, "loss")
return loss

    def compute_marginals(self, model, guide, *args, **kwargs):
"""
Computes marginal distributions at each model-enumerated sample site.
:returns: a dict mapping site name to marginal ``Distribution`` object
:rtype: OrderedDict
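
        A rough usage sketch (``model``, ``guide``, ``data``, and the site name
        ``"z"`` are hypothetical placeholders)::

            elbo = TraceEnum_ELBO(max_plate_nesting=1)
            marginals = elbo.compute_marginals(model, guide, data)
            z_marginal = marginals["z"]  # marginal Distribution over site "z"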
"""
if self.num_particles != 1:
raise NotImplementedError(
"TraceEnum_ELBO.compute_marginals() is not "
"compatible with multiple particles."
)
model_trace, guide_trace = next(self._get_traces(model, guide, args, kwargs))
for site in guide_trace.nodes.values():
if site["type"] == "sample":
if "_enumerate_dim" in site["infer"] or "_enum_total" in site["infer"]:
raise NotImplementedError(
"TraceEnum_ELBO.compute_marginals() is not "
"compatible with guide enumeration."
)
return _compute_marginals(model_trace, guide_trace)

    def sample_posterior(self, model, guide, *args, **kwargs):
"""
        Sample from the joint posterior distribution of all model-enumerated sites, given all observations.
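
        The return value is whatever ``model`` returns, with model-enumerated sites
        resampled from their joint posterior. A rough usage sketch (``model``,
        ``guide``, and ``data`` are hypothetical placeholders)::

            elbo = TraceEnum_ELBO(max_plate_nesting=1)
            posterior_sample = elbo.sample_posterior(model, guide, data)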
"""
if self.num_particles != 1:
raise NotImplementedError(
"TraceEnum_ELBO.sample_posterior() is not "
"compatible with multiple particles."
)
with poutine.block(), warnings.catch_warnings():
warnings.filterwarnings("ignore", "Found vars in model but not guide")
model_trace, guide_trace = next(
self._get_traces(model, guide, args, kwargs)
)
for name, site in guide_trace.nodes.items():
if site["type"] == "sample":
if "_enumerate_dim" in site["infer"] or "_enum_total" in site["infer"]:
raise NotImplementedError(
"TraceEnum_ELBO.sample_posterior() is not "
"compatible with guide enumeration."
)
# TODO replace BackwardSample with torch_sample backend to ubersum
with BackwardSampleMessenger(model_trace, guide_trace):
return poutine.replay(model, trace=guide_trace)(*args, **kwargs)


class JitTraceEnum_ELBO(TraceEnum_ELBO):
"""
Like :class:`TraceEnum_ELBO` but uses :func:`pyro.ops.jit.compile` to
compile :meth:`loss_and_grads`.

    This works only for a limited set of models:
- Models must have static structure.
- Models must not depend on any global data (except the param store).
- All model inputs that are tensors must be passed in via ``*args``.
- All model inputs that are *not* tensors must be passed in via
``**kwargs``, and compilation will be triggered once per unique
``**kwargs``.
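
    A rough usage sketch (``model``, ``guide``, and ``data`` are hypothetical
    placeholders that satisfy the restrictions above)::

        elbo = JitTraceEnum_ELBO(max_plate_nesting=1, ignore_jit_warnings=True)
        svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), elbo)
        svi.step(data)  # the first step triggers jit compilation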
"""

    def differentiable_loss(self, model, guide, *args, **kwargs):
kwargs["_model_id"] = id(model)
kwargs["_guide_id"] = id(guide)
if getattr(self, "_differentiable_loss", None) is None:
# build a closure for differentiable_loss
weakself = weakref.ref(self)
@pyro.ops.jit.trace(
ignore_warnings=self.ignore_jit_warnings, jit_options=self.jit_options
)
def differentiable_loss(*args, **kwargs):
kwargs.pop("_model_id")
kwargs.pop("_guide_id")
self = weakself()
elbo = 0.0
for model_trace, guide_trace in self._get_traces(
model, guide, args, kwargs
):
elbo = elbo + _compute_dice_elbo(model_trace, guide_trace)
return elbo * (-1.0 / self.num_particles)
self._differentiable_loss = differentiable_loss
return self._differentiable_loss(*args, **kwargs)

    def loss_and_grads(self, model, guide, *args, **kwargs):
differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
differentiable_loss.backward() # this line triggers jit compilation
loss = differentiable_loss.item()
warn_if_nan(loss, "loss")
return loss