# Source code for pyro.contrib.oed.eig

from __future__ import absolute_import, division, print_function

import math
import torch

import pyro
from pyro import poutine
from pyro.contrib.autoguide import mean_field_guide_entropy
from pyro.contrib.oed.search import Search
from pyro.contrib.util import lexpand
from pyro.infer import EmpiricalMarginal, Importance, SVI
from pyro.util import torch_isnan, torch_isinf

def vi_ape(model, design, observation_labels, target_labels,
           vi_parameters, is_parameters, y_dist=None):
"""Estimates the average posterior entropy (APE) loss function using
variational inference (VI).

The APE loss function estimated by this method is defined as

        :math:`APE(d)=E_{Y\\sim p(y|\\theta, d)}[H(p(\\theta|Y, d))]`

    where :math:`H[p(x)]` is the `differential entropy
    <https://en.wikipedia.org/wiki/Differential_entropy>`_.
    The APE is related to expected information gain (EIG) by the equation

        :math:`EIG(d)=H[p(\\theta)]-APE(d)`

    in particular, minimising the APE is equivalent to maximising EIG.

:param function model: A pyro model accepting design as only argument.
:param torch.Tensor design: Tensor representation of design
:param list observation_labels: A subset of the sample sites
present in model. These sites are regarded as future observations
and other sites are regarded as latent variables over which a
posterior is to be inferred.
:param list target_labels: A subset of the sample sites over which the posterior
entropy is to be measured.
    :param dict vi_parameters: Variational inference parameters which should include:
        `optim`: an instance of :class:`pyro.Optim`, `guide`: a guide function
        compatible with `model`, `num_steps`: the number of VI steps to make,
        and `loss`: the loss function to use for VI
    :param dict is_parameters: Importance sampling parameters for the
        marginal distribution of :math:`Y`. May include `num_samples`: the number
        of samples to draw from the marginal.
    :param pyro.distributions.Distribution y_dist: (optional) the distribution
        assumed for the response variable :math:`Y`
:return: Loss function estimate
:rtype: torch.Tensor
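
    Example (a minimal sketch: ``model``, ``guide`` and ``design`` are
    hypothetical user-supplied objects, with ``model`` having sample sites
    named "y" and "theta" and ``guide`` being a mean-field guide over the
    target sites)::

        >>> vi_parameters = {"guide": guide,
        ...                  "optim": pyro.optim.Adam({"lr": 0.01}),
        ...                  "loss": pyro.infer.Trace_ELBO(),
        ...                  "num_steps": 1000}
        >>> is_parameters = {"num_samples": 10}
        >>> ape = vi_ape(model, design, "y", "theta",
        ...              vi_parameters, is_parameters)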

"""

if isinstance(observation_labels, str):
observation_labels = [observation_labels]
if target_labels is not None and isinstance(target_labels, str):
target_labels = [target_labels]

def posterior_entropy(y_dist, design):
# Important that y_dist is sampled *within* the function
y = pyro.sample("conditioning_y", y_dist)
y_dict = {label: y[i, ...] for i, label in enumerate(observation_labels)}
conditioned_model = pyro.condition(model, data=y_dict)
SVI(conditioned_model, **vi_parameters).run(design)
# Recover the entropy
return mean_field_guide_entropy(vi_parameters["guide"], [design], whitelist=target_labels)

if y_dist is None:
y_dist = EmpiricalMarginal(Importance(model, **is_parameters).run(design),
sites=observation_labels)

# Calculate the expected posterior entropy under this distn of y
loss_dist = EmpiricalMarginal(Search(posterior_entropy).run(y_dist, design))
loss = loss_dist.mean

return loss

def naive_rainforth_eig(model, design, observation_labels, target_labels=None,
                        N=100, M=10, M_prime=None):
"""
Naive Rainforth (i.e. Nested Monte Carlo) estimate of the expected information
gain (EIG). The estimate is

.. math::

\\frac{1}{N}\\sum_{n=1}^N \\log p(y_n | \\theta_n, d) -
\\log \\left(\\frac{1}{M}\\sum_{m=1}^M p(y_n | \\theta_m, d)\\right)

    Monte Carlo estimation is attempted for the :math:`\\log p(y | \\theta, d)` term if
    the parameter `M_prime` is passed. Otherwise, it is assumed that :math:`\\log p(y | \\theta, d)`
    can safely be read from the model itself.

:param function model: A pyro model accepting design as only argument.
:param torch.Tensor design: Tensor representation of design
:param list observation_labels: A subset of the sample sites
present in model. These sites are regarded as future observations
and other sites are regarded as latent variables over which a
posterior is to be inferred.
:param list target_labels: A subset of the sample sites over which the posterior
entropy is to be measured.
:param int N: Number of outer expectation samples.
:param int M: Number of inner expectation samples for p(y|d).
:param int M_prime: Number of samples for p(y | theta, d) if required.
:return: EIG estimate
:rtype: torch.Tensor
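
    Example (a minimal sketch: ``model`` and ``design`` are hypothetical, with
    ``model`` having sample sites named "y" and "theta")::

        >>> eig = naive_rainforth_eig(model, design, "y", "theta",
        ...                           N=1000, M=100)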
"""

if isinstance(observation_labels, str):
observation_labels = [observation_labels]
if isinstance(target_labels, str):
target_labels = [target_labels]

# Take N samples of the model
expanded_design = lexpand(design, N)
trace = poutine.trace(model).get_trace(expanded_design)
trace.compute_log_prob()

if M_prime is not None:
y_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in observation_labels}
theta_dict = {l: lexpand(trace.nodes[l]["value"], M_prime) for l in target_labels}
theta_dict.update(y_dict)
        # Resample M_prime copies of any remaining latent variables and
        # compute conditional probabilities
conditional_model = pyro.condition(model, data=theta_dict)
# Not acceptable to use (M_prime, 1) here - other variables may occur after
# theta, so need to be sampled conditional upon it
reexpanded_design = lexpand(design, M_prime, N)
retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
retrace.compute_log_prob()
conditional_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
- math.log(M_prime)
else:
        # This assumes the observations are conditionally independent given theta
        # and that there are no other latent variables besides theta
conditional_lp = sum(trace.nodes[l]["log_prob"] for l in observation_labels)

y_dict = {l: lexpand(trace.nodes[l]["value"], M) for l in observation_labels}
# Resample M values of theta and compute conditional probabilities
conditional_model = pyro.condition(model, data=y_dict)
# Using (M, 1) instead of (M, N) - acceptable to re-use thetas between ys because
# theta comes before y in graphical model
reexpanded_design = lexpand(design, M, 1)
retrace = poutine.trace(conditional_model).get_trace(reexpanded_design)
retrace.compute_log_prob()
marginal_lp = sum(retrace.nodes[l]["log_prob"] for l in observation_labels).logsumexp(0) \
- math.log(M)

return (conditional_lp - marginal_lp).sum(0)/N

def donsker_varadhan_eig(model, design, observation_labels, target_labels,
                         num_samples, num_steps, T, optim, return_history=False,
                         final_design=None, final_num_samples=None):
"""
Donsker-Varadhan estimate of the expected information gain (EIG).

The Donsker-Varadhan representation of EIG is

.. math::

\\sup_T E_{p(y, \\theta | d)}[T(y, \\theta)] - \\log E_{p(y|d)p(\\theta)}[\\exp(T(\\bar{y}, \\bar{\\theta}))]

    where :math:`T` is any (measurable) function.

    This method optimises the loss function over a pre-specified class of
    functions `T`.

:param function model: A pyro model accepting design as only argument.
:param torch.Tensor design: Tensor representation of design
:param list observation_labels: A subset of the sample sites
present in model. These sites are regarded as future observations
and other sites are regarded as latent variables over which a
posterior is to be inferred.
:param list target_labels: A subset of the sample sites over which the posterior
entropy is to be measured.
:param int num_samples: Number of samples per iteration.
:param int num_steps: Number of optimisation steps.
    :param function or torch.nn.Module T: optimisable function `T` for use in
        the Donsker-Varadhan loss
:param pyro.optim.Optim optim: Optimiser to use.
:param bool return_history: If True, also returns a tensor giving the loss function
at each step of the optimisation.
:param torch.Tensor final_design: The final design tensor to evaluate at. If None, uses
design.
    :param int final_num_samples: The number of samples to use at the final evaluation. If None,
        uses `num_samples`.
    :return: EIG estimate, optionally along with the full optimisation history
:rtype: torch.Tensor or tuple
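
    Example (a minimal sketch: ``model`` and ``design`` are hypothetical, and
    the critic must accept ``(design, trace, observation_labels,
    target_labels)``, matching how ``T`` is called in the loss below; the
    scalar sites and network shape are illustrative assumptions)::

        >>> class Critic(torch.nn.Module):
        ...     def __init__(self):
        ...         super(Critic, self).__init__()
        ...         self.linear = torch.nn.Linear(2, 1)
        ...     def forward(self, design, trace, observation_labels, target_labels):
        ...         y = trace.nodes[observation_labels[0]]["value"]
        ...         theta = trace.nodes[target_labels[0]]["value"]
        ...         return self.linear(torch.stack([y, theta], dim=-1)).squeeze(-1)
        >>> eig = donsker_varadhan_eig(model, design, "y", "theta",
        ...                            num_samples=100, num_steps=500, T=Critic(),
        ...                            optim=pyro.optim.Adam({"lr": 0.001}))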
"""
if isinstance(observation_labels, str):
observation_labels = [observation_labels]
if isinstance(target_labels, str):
target_labels = [target_labels]
loss = donsker_varadhan_loss(model, T, observation_labels, target_labels)
return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history,
final_design, final_num_samples)

def barber_agakov_ape(model, design, observation_labels, target_labels,
                      num_samples, num_steps, guide, optim, return_history=False,
                      final_design=None, final_num_samples=None):
"""
Barber-Agakov estimate of average posterior entropy (APE).

    The Barber-Agakov representation of APE is

        :math:`\\sup_{q}E_{p(y, \\theta | d)}[\\log q(\\theta | y, d)]`

    where :math:`q` is any distribution on :math:`\\theta`.

    This method optimises the loss over a given guide family `guide`
    representing :math:`q`.

:param function model: A pyro model accepting design as only argument.
:param torch.Tensor design: Tensor representation of design
:param list observation_labels: A subset of the sample sites
present in model. These sites are regarded as future observations
and other sites are regarded as latent variables over which a
posterior is to be inferred.
:param list target_labels: A subset of the sample sites over which the posterior
entropy is to be measured.
:param int num_samples: Number of samples per iteration.
:param int num_steps: Number of optimisation steps.
:param function guide: guide family for use in the (implicit) posterior estimation.
The parameters of guide are optimised to maximise the Barber-Agakov
objective.
:param pyro.optim.Optim optim: Optimiser to use.
:param bool return_history: If True, also returns a tensor giving the loss function
at each step of the optimisation.
:param torch.Tensor final_design: The final design tensor to evaluate at. If None, uses
design.
    :param int final_num_samples: The number of samples to use at the final evaluation. If None,
        uses `num_samples`.
    :return: APE estimate, optionally along with the full optimisation history
:rtype: torch.Tensor or tuple
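
    Example (a minimal sketch: ``model`` and ``design`` are hypothetical, and
    the guide must accept ``(y_dict, design, observation_labels,
    target_labels)`` and sample the target sites, matching how it is used in
    the loss below; a practical guide would typically amortise over
    ``y_dict``)::

        >>> def guide(y_dict, design, observation_labels, target_labels):
        ...     loc = pyro.param("q_loc", torch.tensor(0.))
        ...     scale = pyro.param("q_scale", torch.tensor(1.),
        ...                        constraint=torch.distributions.constraints.positive)
        ...     pyro.sample("theta", pyro.distributions.Normal(loc, scale))
        >>> ape = barber_agakov_ape(model, design, "y", "theta",
        ...                         num_samples=100, num_steps=500, guide=guide,
        ...                         optim=pyro.optim.Adam({"lr": 0.001}))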
"""
if isinstance(observation_labels, str):
observation_labels = [observation_labels]
if isinstance(target_labels, str):
target_labels = [target_labels]
loss = barber_agakov_loss(model, guide, observation_labels, target_labels)
return opt_eig_ape_loss(design, loss, num_samples, num_steps, optim, return_history,
final_design, final_num_samples)

def opt_eig_ape_loss(design, loss_fn, num_samples, num_steps, optim, return_history=False,
                     final_design=None, final_num_samples=None):

    if final_design is None:
        final_design = design
    if final_num_samples is None:
        final_num_samples = num_samples

    params = None
    history = []
    for step in range(num_steps):
        # Zero out gradients from the previous step before re-evaluating
        if params is not None:
            pyro.infer.util.zero_grads(params)
        agg_loss, loss = loss_fn(design, num_samples)
        agg_loss.backward()
        if return_history:
            history.append(loss)
        params = [value.unconstrained()
                  for value in pyro.get_param_store().values()]
        optim(params)
    _, loss = loss_fn(final_design, final_num_samples)
    if return_history:
        return torch.stack(history), loss
    else:
        return loss

def donsker_varadhan_loss(model, T, observation_labels, target_labels):

ewma_log = EwmaLog(alpha=0.90)

    try:
        # Register the parameters of T with Pyro if it is an nn.Module;
        # a plain function raises an AssertionError and is left unregistered
        pyro.module("T", T)
    except AssertionError:
        pass

def loss_fn(design, num_particles):

expanded_design = lexpand(design, num_particles)

# Unshuffled data
unshuffled_trace = poutine.trace(model).get_trace(expanded_design)
y_dict = {l: unshuffled_trace.nodes[l]["value"] for l in observation_labels}

# Shuffled data
# Not actually shuffling, resimulate for safety
conditional_model = pyro.condition(model, data=y_dict)
shuffled_trace = poutine.trace(conditional_model).get_trace(expanded_design)

T_joint = T(expanded_design, unshuffled_trace, observation_labels,
target_labels)
T_independent = T(expanded_design, shuffled_trace, observation_labels,
target_labels)

joint_expectation = T_joint.sum(0)/num_particles

        # Stabilised evaluation of log E[exp(T)]: subtract the max s before
        # exponentiating, and take the log through an exponentially weighted
        # moving average for a lower-variance gradient
        A = T_independent - math.log(num_particles)
        s, _ = torch.max(A, dim=0)
        independent_expectation = s + ewma_log((A - s).exp().sum(dim=0), s)

loss = joint_expectation - independent_expectation
# Switch sign, sum over batch dimensions for scalar loss
agg_loss = -loss.sum()
return agg_loss, loss

return loss_fn

def barber_agakov_loss(model, guide, observation_labels, target_labels):

def loss_fn(design, num_particles):

expanded_design = lexpand(design, num_particles)

# Sample from p(y, theta | d)
trace = poutine.trace(model).get_trace(expanded_design)
y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}

# Run through q(theta | y, d)
conditional_guide = pyro.condition(guide, data=theta_dict)
cond_trace = poutine.trace(conditional_guide).get_trace(
y_dict, expanded_design, observation_labels, target_labels)
cond_trace.compute_log_prob()

loss = -sum(cond_trace.nodes[l]["log_prob"] for l in target_labels).sum(0)/num_particles
agg_loss = loss.sum()
return agg_loss, loss

return loss_fn

class _EwmaLogFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, ewma):
        ctx.save_for_backward(ewma)
        return input.log()

    @staticmethod
    def backward(ctx, grad_output):
        ewma, = ctx.saved_tensors
        return grad_output / ewma, None

_ewma_log_fn = _EwmaLogFn.apply

class EwmaLog(object):
"""Logarithm function with exponentially weighted moving average

For input inputs this function return :code:inputs.log(). However, it
computes the gradient as

:math:\\frac{\\sum_{t=0}^{T-1} \\alpha^t}{\\sum_{t=0}^{T-1} \\alpha^t x_{T-t}}

where :math:x_t are historical input values passed to this function,
:math:x_T being the most recently seen value.

This gradient may help with numerical stability when the sequence of
inputs to the function form a convergent sequence.
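
    A sketch of the intended use, mirroring the stabilised logsumexp in
    ``donsker_varadhan_loss`` above (``A`` stands in for a hypothetical
    tensor of log-weights)::

        >>> ewma_log = EwmaLog(alpha=0.90)
        >>> s, _ = torch.max(A, dim=0)
        >>> log_mean_exp = s + ewma_log((A - s).exp().sum(dim=0), s)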
"""

def __init__(self, alpha):
self.alpha = alpha
self.ewma = 0.
self.n = 0
self.s = 0.

    def __call__(self, inputs, s, dim=0, keepdim=False):
        """Updates the moving average, and returns :code:`inputs.log()`.
        """
self.n += 1
if torch_isnan(self.ewma) or torch_isinf(self.ewma):
ewma = inputs
else:
ewma = inputs * (1. - self.alpha) / (1 - self.alpha**self.n) \
+ torch.exp(self.s - s) * self.ewma \
* (self.alpha - self.alpha**self.n) / (1 - self.alpha**self.n)
self.ewma = ewma.detach()
self.s = s.detach()
return _ewma_log_fn(inputs, ewma)