# Source code for pyro.infer.energy_distance

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import operator
from collections import OrderedDict
from functools import reduce

import torch

import pyro
import pyro.poutine as poutine
from pyro.distributions.util import scale_and_mask
from pyro.infer.elbo import ELBO
from pyro.infer.util import is_validation_enabled, validation_enabled
from pyro.poutine.util import prune_subsample_sites
from pyro.util import check_model_guide_match, check_site_shape, warn_if_nan

def _squared_error(x, y, scale, mask):
    # Per-particle squared error; keep the batch dim only when scale or mask
    # is batched, so they can be applied per datum before summing it out.
    diff = x - y
    if getattr(scale, "shape", ()) or getattr(mask, "shape", ()):
        error = torch.einsum("nbe,nbe->nb", diff, diff)
        return scale_and_mask(error, scale, mask).sum(-1)
    else:
        error = torch.einsum("nbe,nbe->n", diff, diff)
        return scale_and_mask(error, scale, mask)

class EnergyDistance:
    r"""
    Posterior predictive energy distance [1,2] with optional Bayesian
    regularization by the prior.

    Let `p(x,z)=p(z) p(x|z)` be the model, `q(z|x)` be the guide.  Then given
    data `x` and drawing an iid pair of samples :math:`(Z,X)` and
    :math:`(Z',X')` (where `Z` is latent and `X` is the posterior predictive),

    .. math ::

        & Z \sim q(z|x); \quad X \sim p(x|Z) \\
        & Z' \sim q(z|x); \quad X' \sim p(x|Z') \\
        & loss = \mathbb E_X \|X-x\|^\beta
               - \frac 1 2 \mathbb E_{X,X'}\|X-X'\|^\beta
               - \lambda \mathbb E_Z \log p(Z)

    This is a likelihood-free inference algorithm, and can be used for
    likelihoods without tractable density functions. The :math:`\beta` energy
    distance is a robust loss function, and is well defined for any
    distribution with finite fractional moment :math:`\mathbb E[\|X\|^\beta]`.

    This requires static model structure, a fully reparametrized guide, and
    reparametrized likelihood distributions in the model. Model latent
    distributions may be non-reparametrized.

    **References**

    [1] Gabor J. Szekely, Maria L. Rizzo (2003)
        Energy Statistics: A Class of Statistics Based on Distances.
    [2] Tilmann Gneiting, Adrian E. Raftery (2007)
        Strictly Proper Scoring Rules, Prediction, and Estimation.
        https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

    :param float beta: Exponent :math:`\beta` from [1,2]. The loss function is
        strictly proper for distributions with finite :math:`\beta`-absolute
        moment :math:`E[\|X\|^\beta]`. Thus for heavy tailed distributions
        ``beta`` should be small, e.g. for ``Cauchy`` distributions,
        :math:`\beta<1` is strictly proper. Defaults to 1. Must be in the open
        interval (0,2).
    :param float prior_scale: Nonnegative scale for prior regularization.
        Model parameters are trained only if this is positive.
        If zero (default), then model log densities will not be computed
        (guide log densities are never computed).
    :param int num_particles: The number of particles/samples used to form the
        gradient estimators. Must be at least 2.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. If omitted, this will guess a valid value
        by running the (model,guide) pair once.
    """

    def __init__(
        self, beta=1.0, prior_scale=0.0, num_particles=2, max_plate_nesting=float("inf")
    ):
        if not (isinstance(beta, (float, int)) and 0 < beta and beta < 2):
            raise ValueError("Expected beta in (0,2), actual {}".format(beta))
        if not (isinstance(prior_scale, (float, int)) and prior_scale >= 0):
            raise ValueError("Expected prior_scale >= 0, actual {}".format(prior_scale))
        if not (isinstance(num_particles, int) and num_particles >= 2):
            raise ValueError(
                "Expected num_particles >= 2, actual {}".format(num_particles)
            )
        self.beta = beta
        self.prior_scale = prior_scale
        self.num_particles = num_particles
        self.vectorize_particles = True
        self.max_plate_nesting = max_plate_nesting

    def _pow(self, x):
        if self.beta == 1:
            return x.sqrt()  # cheaper than .pow()
        return x.pow(self.beta / 2)

    def _get_traces(self, model, guide, args, kwargs):
        if self.max_plate_nesting == float("inf"):
            with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate(
            "num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting
        )

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get("was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if site["is_observed"]:
                        warn_if_nan(site["value"], site["name"])
                        if not getattr(site["fn"], "has_rsample", False):
                            raise ValueError(
                                "EnergyDistance requires reparametrized likelihoods"
                            )

        if self.prior_scale > 0:
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"]
            )
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace

    def __call__(self, model, guide, *args, **kwargs):
        """
        Computes the surrogate loss that can be differentiated with autograd
        to produce gradient estimates for the model and guide parameters.
        """
        guide_trace, model_trace = self._get_traces(model, guide, args, kwargs)

        # Extract observations and posterior predictive samples.
        data = OrderedDict()
        samples = OrderedDict()
        for name, site in model_trace.nodes.items():
            if site["type"] == "sample" and site["is_observed"]:
                data[name] = site["infer"]["obs"]
                samples[name] = site["value"]
        assert list(data.keys()) == list(samples.keys())
        if not data:
            raise ValueError("Found no observations")

        # Compute energy distance from mean average error and generalized entropy.
        squared_error = []  # E[ (X - x)^2 ]
        squared_entropy = []  # E[ (X - X')^2 ]
        prototype = next(iter(data.values()))
        pairs = (
            prototype.new_ones(self.num_particles, self.num_particles)
            .tril(-1)
            .nonzero(as_tuple=False)
        )
        for name, obs in data.items():
            sample = samples[name]
            scale = model_trace.nodes[name]["scale"]
            mask = model_trace.nodes[name].get("mask")

            # Flatten to subshapes of (num_particles, batch_size, event_size).
            event_dim = model_trace.nodes[name]["fn"].event_dim
            batch_shape = obs.shape[: obs.dim() - event_dim]
            event_shape = obs.shape[obs.dim() - event_dim :]
            if getattr(scale, "shape", ()):
                scale = scale.expand(batch_shape).reshape(-1)
            if getattr(mask, "shape", ()):
                mask = mask.expand(batch_shape).reshape(-1)
            obs = obs.reshape(batch_shape.numel(), event_shape.numel())
            sample = sample.reshape(
                self.num_particles, batch_shape.numel(), event_shape.numel()
            )

            squared_error.append(_squared_error(sample, obs, scale, mask))
            squared_entropy.append(
                _squared_error(*sample[pairs].unbind(1), scale, mask)
            )

        squared_error = reduce(operator.add, squared_error)
        squared_entropy = reduce(operator.add, squared_entropy)
        error = self._pow(squared_error).mean()  # E[ ||X-x||^beta ]
        entropy = self._pow(squared_entropy).mean()  # E[ ||X-X'||^beta ]
        energy = error - 0.5 * entropy

        # Compute prior.
        log_prior = 0
        if self.prior_scale > 0:
            for site in model_trace.nodes.values():
                if site["type"] == "sample" and not site["is_observed"]:
                    log_prior = log_prior + site["log_prob_sum"]

        # Compute final loss.
        loss = energy - self.prior_scale * log_prior
        warn_if_nan(loss, "loss")
        return loss

    def loss(self, *args, **kwargs):
        """
        Not implemented. Added for compatibility with unit tests only.
        """
        raise NotImplementedError("EnergyDistance implements only surrogate loss")
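
The listing above is easiest to read alongside a small end-to-end run. Below is a minimal usage sketch, not part of the module source: it fits a toy Gaussian location model with `SVI`, and the model, guide, parameter names (`q_loc`, `q_scale`), and data are illustrative assumptions rather than anything defined by `pyro.infer.energy_distance`.

# Minimal usage sketch (illustrative; not part of the library source).
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI
from pyro.infer.energy_distance import EnergyDistance
from pyro.optim import Adam


def model(data):
    # Reparametrized likelihood, as EnergyDistance requires.
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


def guide(data):
    # Fully reparametrized guide, as EnergyDistance requires.
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    q_scale = pyro.param("q_scale", torch.tensor(1.0),
                         constraint=dist.constraints.positive)
    pyro.sample("loc", dist.Normal(q_loc, q_scale))


data = 3.0 + torch.randn(100)
loss = EnergyDistance(beta=1.0, num_particles=8)  # prefer beta < 1 for heavy-tailed likelihoods
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=loss)
for step in range(1000):
    svi.step(data)

This works because `SVI` accepts any callable `loss(model, guide, *args, **kwargs)` that returns a differentiable surrogate loss, which is exactly the contract `EnergyDistance.__call__` implements; the separate `loss()` method is intentionally unimplemented.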