Optimal Experiment Design
The pyro.contrib.oed module provides tools to create optimal experiment designs for Pyro models. In particular, it provides estimators for the expected information gain (EIG) criterion.
To estimate the EIG for a particular design, use:
from pyro.contrib.oed.eig import vnmc_eig

def model(design):
    ...

# Select an appropriate EIG estimator, such as
eig = vnmc_eig(model, design, ...)
EIG can then be maximised using existing optimisers in pyro.optim.
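As a minimal, self-contained illustration of the maximisation step, the sketch below scores a grid of candidate designs and keeps the best one. It does not call Pyro. It assumes a hypothetical one-dimensional model, \(\theta \sim N(0, 1)\) and \(y | \theta, d \sim N(d\theta, 1 + d^4)\), whose EIG has the closed form \(\frac{1}{2}\log(1 + d^2/(1 + d^4))\). In practice the closed form would be replaced by one of the estimators documented below, and the grid search by a gradient-based optimiser such as those in pyro.optim. The names toy_eig and best_design are illustrative only and are not part of pyro.contrib.oed.

```python
import math

def toy_eig(d):
    """Closed-form EIG for the hypothetical model
    theta ~ N(0, 1), y | theta, d ~ N(d * theta, 1 + d**4).

    For a linear-Gaussian model the EIG is 0.5 * log(1 + signal_var / noise_var),
    where signal_var = d**2 * prior_var and prior_var = 1 here.
    """
    noise_var = 1.0 + d ** 4
    return 0.5 * math.log(1.0 + d ** 2 / noise_var)

def best_design(candidates, eig_fn):
    """Evaluate eig_fn at every candidate design and return the maximiser and all scores."""
    scores = {d: eig_fn(d) for d in candidates}
    return max(scores, key=scores.get), scores

candidates = [i / 4 for i in range(9)]  # 0.0, 0.25, ..., 2.0
d_star, scores = best_design(candidates, toy_eig)
print(d_star, round(scores[d_star], 4))  # 1.0 0.2027
```

Because the noise variance in this toy model grows like \(d^4\), the EIG peaks at an interior design (\(d = 1\)) rather than at the edge of the grid.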
Expected Information Gain

laplace_eig(model, design, observation_labels, target_labels, guide, loss, optim, num_steps, final_num_samples, y_dist=None, eig=True, **prior_entropy_kwargs)
Estimates the expected information gain (EIG) by making repeated Laplace approximations to the posterior.
Parameters: model (function) – Pyro stochastic function taking design as its only argument.
 design (torch.Tensor) – Tensor of possible designs.
 observation_labels (list) – labels of sample sites to be regarded as observables.
 target_labels (list) – labels of sample sites to be regarded as latent variables of interest, i.e. the sites that we wish to gain information about.
 guide (function) – Pyro stochastic function corresponding to model.
 loss – a Pyro loss such as pyro.infer.Trace_ELBO().differentiable_loss.
 optim – optimizer for the loss
 num_steps (int) – Number of gradient steps to take per sampled pseudo-observation.
 final_num_samples (int) – Number of y samples (pseudo-observations) to take.
 y_dist – Distribution to sample y from; if None, the Bayesian marginal distribution is used.
 eig (bool) – Whether to compute the EIG or the average posterior entropy (APE). The EIG is given by EIG = prior entropy - APE. If True, the prior entropy will be estimated analytically, or by Monte Carlo as appropriate for the model. If False, the APE is returned.
 prior_entropy_kwargs (dict) – parameters for estimating the prior entropy: num_prior_samples indicating the number of samples for a MC estimate of prior entropy, and mean_field indicating if an analytic form for a mean-field prior should be tried.
Returns: EIG estimate
Return type: torch.Tensor
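The sketch below mirrors the procedure described for laplace_eig without calling Pyro. For each sampled pseudo-observation \(y\) it locates the posterior mode by gradient steps, measures the curvature there, and uses the entropy of the resulting Gaussian (Laplace) approximation. The EIG is then the prior entropy minus the average of these entropies. It assumes a hypothetical toy model, \(\theta \sim N(0, 1)\) and \(y | \theta, d \sim N(d\theta, 1)\), and uses finite differences in place of automatic differentiation; log_joint and laplace_eig_sketch are illustrative names. Because this model is linear-Gaussian, the Laplace approximation is exact and the result should match \(\frac{1}{2}\log(1 + d^2)\).

```python
import math
import random

def log_joint(theta, y, d):
    # log p(theta) + log p(y | theta, d) up to an additive constant, for the
    # hypothetical toy model theta ~ N(0, 1), y | theta, d ~ N(d * theta, 1)
    return -0.5 * theta ** 2 - 0.5 * (y - d * theta) ** 2

def laplace_eig_sketch(d, num_y=200, num_steps=100, lr=0.1, h=1e-3, seed=0):
    rng = random.Random(seed)
    prior_entropy = 0.5 * math.log(2 * math.pi * math.e)  # entropy of the N(0, 1) prior
    total_entropy = 0.0
    for _ in range(num_y):
        # Pseudo-observation y ~ p(y | d), drawn by ancestral sampling
        y = rng.gauss(d * rng.gauss(0.0, 1.0), 1.0)
        # Gradient ascent (finite-difference gradient) towards the posterior mode
        theta = 0.0
        for _ in range(num_steps):
            grad = (log_joint(theta + h, y, d) - log_joint(theta - h, y, d)) / (2 * h)
            theta += lr * grad
        # Curvature of -log p(theta, y | d) at the mode = precision of the Laplace Gaussian
        precision = -(log_joint(theta + h, y, d) - 2 * log_joint(theta, y, d)
                      + log_joint(theta - h, y, d)) / h ** 2
        total_entropy += 0.5 * math.log(2 * math.pi * math.e / precision)
    ape = total_entropy / num_y  # average posterior entropy
    return prior_entropy - ape

print(round(laplace_eig_sketch(2.0), 4))  # 0.8047
```

For non-Gaussian models the curvature, and hence the approximate posterior entropy, would vary with \(y\), which is why the approximation is repeated for every pseudo-observation.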

vi_eig(model, design, observation_labels, target_labels, vi_parameters, is_parameters, y_dist=None, eig=True, **prior_entropy_kwargs)
Estimates the expected information gain (EIG) using variational inference (VI).
The APE is defined as
\(APE(d)=E_{Y\sim p(y|d)}\left[H[p(\theta|Y, d)]\right]\), where \(H[p(x)]\) is the differential entropy. The APE is related to the expected information gain (EIG) by the equation
\(EIG(d)=H[p(\theta)]-APE(d)\); in particular, minimising the APE is equivalent to maximising the EIG.
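As a worked illustration of this relation (a standard textbook example, not taken from the module), take a linear-Gaussian model with \(\theta \sim N(0, \tau^2)\) and \(y | \theta, d \sim N(d\theta, \sigma^2)\). The posterior is Gaussian with a variance that does not depend on \(y\), so the expectation in the APE is trivial:

```latex
\begin{aligned}
H[p(\theta)] &= \tfrac{1}{2}\log\left(2\pi e\,\tau^2\right), \\
\operatorname{Var}(\theta \mid y, d) &= \left(\frac{1}{\tau^2} + \frac{d^2}{\sigma^2}\right)^{-1}, \\
APE(d) &= \tfrac{1}{2}\log\left(2\pi e \left(\frac{1}{\tau^2} + \frac{d^2}{\sigma^2}\right)^{-1}\right), \\
EIG(d) &= H[p(\theta)] - APE(d) = \tfrac{1}{2}\log\left(1 + \frac{d^2\tau^2}{\sigma^2}\right).
\end{aligned}
```

Here the EIG increases without bound in \(|d|\), so maximising it is only meaningful over a bounded design space.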
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 vi_parameters (dict) – Variational inference parameters which should include: optim: an instance of pyro.Optim, guide: a guide function compatible with model, num_steps: the number of VI steps to make, and loss: the loss function to use for VI.
 is_parameters (dict) – Importance sampling parameters for the marginal distribution of \(Y\). May include num_samples: the number of samples to draw from the marginal.
 y_dist (pyro.distributions.Distribution) – (optional) the distribution assumed for the response variable \(Y\)
 eig (bool) – Whether to compute the EIG or the average posterior entropy (APE). The EIG is given by EIG = prior entropy - APE. If True, the prior entropy will be estimated analytically, or by Monte Carlo as appropriate for the model. If False, the APE is returned.
 prior_entropy_kwargs (dict) – parameters for estimating the prior entropy: num_prior_samples indicating the number of samples for a MC estimate of prior entropy, and mean_field indicating if an analytic form for a mean-field prior should be tried.
Returns: EIG estimate
Return type: torch.Tensor
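A minimal sketch of the VI approach, under stated assumptions: a hypothetical model \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\), and a Gaussian guide \(q(\theta) = N(m, s^2)\) fitted separately for each sampled \(y\) by gradient ascent on the ELBO, whose expectations are available in closed form for this model. The entropy of each fitted guide stands in for the posterior entropy; averaging gives the APE, and subtracting it from the prior entropy gives the EIG. This is not the vi_eig implementation, which optimises a user-supplied guide with the loss and optimiser passed in vi_parameters; fit_gaussian_guide and vi_eig_sketch are illustrative names.

```python
import math
import random

def fit_gaussian_guide(y, d, num_steps=500, lr=0.1):
    """Maximise the ELBO over q(theta) = N(m, s^2) for one observation y.

    Hypothetical toy model: theta ~ N(0, 1), y | theta, d ~ N(d * theta, 1).
    Up to constants the ELBO is
        -0.5 * (m^2 + s^2) - 0.5 * ((y - d*m)^2 + d^2 * s^2) + log s,
    so its gradients are available in closed form.
    """
    m, log_s = 0.0, 0.0
    for _ in range(num_steps):
        s2 = math.exp(2 * log_s)
        grad_m = -m + d * (y - d * m)
        grad_log_s = 1.0 - s2 * (1.0 + d ** 2)  # chain rule through s = exp(log_s)
        m += lr * grad_m
        log_s += lr * grad_log_s
    return m, log_s

def vi_eig_sketch(d, num_y=100, seed=0):
    rng = random.Random(seed)
    prior_entropy = 0.5 * math.log(2 * math.pi * math.e)
    total = 0.0
    for _ in range(num_y):
        y = rng.gauss(d * rng.gauss(0.0, 1.0), 1.0)  # y ~ p(y | d)
        _, log_s = fit_gaussian_guide(y, d)
        total += 0.5 * math.log(2 * math.pi * math.e) + log_s  # entropy of N(m, s^2)
    return prior_entropy - total / num_y

print(round(vi_eig_sketch(2.0), 4))  # 0.8047
```

The guide family contains the exact posterior here, so the result matches \(\frac{1}{2}\log(1 + d^2)\); with a more restrictive guide, the fitted entropies would be biased and so would the EIG estimate.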

nmc_eig(model, design, observation_labels, target_labels=None, N=100, M=10, M_prime=None, independent_priors=False)
Nested Monte Carlo estimate of the expected information gain (EIG). When there are no random effects, the estimate is
\[\frac{1}{N}\sum_{n=1}^N \log p(y_n | \theta_n, d) - \frac{1}{N}\sum_{n=1}^N \log \left(\frac{1}{M}\sum_{m=1}^M p(y_n | \theta_m, d)\right)\]
The estimate is, in the presence of random effects,
\[\frac{1}{N}\sum_{n=1}^N \log \left(\frac{1}{M'}\sum_{m=1}^{M'} p(y_n | \theta_n, \widetilde{\theta}_{nm}, d)\right) - \frac{1}{N}\sum_{n=1}^N \log \left(\frac{1}{M}\sum_{m=1}^{M} p(y_n | \theta_m, \widetilde{\theta}_{m}, d)\right)\]
The latter form is used when M_prime is not None.
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 N (int) – Number of outer expectation samples.
 M (int) – Number of inner expectation samples for \(p(y|d)\).
 M_prime (int) – Number of samples for \(p(y|\theta, d)\) if required.
 independent_priors (bool) – Only used when M_prime is not None. Indicates whether the prior distributions for the target variables and the nuisance variables are independent. In this case, it is not necessary to sample the targets conditional on the nuisance variables.
Returns: EIG estimate
Return type: torch.Tensor
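The no-random-effects estimator above can be written out directly. This sketch assumes a hypothetical model, \(\theta \sim N(0, 1)\) and \(y | \theta, d \sim N(d\theta, 1)\), for which the exact EIG is \(\frac{1}{2}\log(1 + d^2)\). It draws \(N\) outer pairs \((\theta_n, y_n)\) and, for each, \(M\) fresh prior samples \(\theta_m\) to estimate \(\log p(y_n | d)\) with a numerically stable log-mean-exp. It uses only the standard library and is independent of the Pyro implementation; all function names are illustrative.

```python
import math
import random

def log_lik(y, theta, d):
    # log N(y; d * theta, 1) for the hypothetical toy model
    return -0.5 * (y - d * theta) ** 2 - 0.5 * math.log(2 * math.pi)

def log_mean_exp(values):
    # log((1/M) * sum(exp(v))) computed without overflow
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values) / len(values))

def nmc_eig_sketch(d, N=2000, M=100, seed=0):
    rng = random.Random(seed)
    total = 0.0
    for _ in range(N):
        theta_n = rng.gauss(0.0, 1.0)        # theta_n ~ p(theta)
        y_n = rng.gauss(d * theta_n, 1.0)    # y_n ~ p(y | theta_n, d)
        # Inner estimate of log p(y_n | d) from M fresh prior samples theta_m
        inner = [log_lik(y_n, rng.gauss(0.0, 1.0), d) for _ in range(M)]
        total += log_lik(y_n, theta_n, d) - log_mean_exp(inner)
    return total / N

estimate = nmc_eig_sketch(1.0)
print(estimate, 0.5 * math.log(2.0))  # exact EIG at d = 1 is 0.5 * log(2), about 0.3466
```

For finite \(M\) the inner log-mean-exp underestimates \(\log p(y_n | d)\) on average (Jensen's inequality), so the NMC estimate is biased upwards; the bias shrinks as \(M\) grows.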

donsker_varadhan_eig(model, design, observation_labels, target_labels, num_samples, num_steps, T, optim, return_history=False, final_design=None, final_num_samples=None)
Donsker-Varadhan estimate of the expected information gain (EIG).
The Donsker-Varadhan representation of EIG is
\[\sup_T E_{p(y, \theta | d)}[T(y, \theta)] - \log E_{p(y|d)p(\theta)}[\exp(T(\bar{y}, \bar{\theta}))]\]
where \(T\) is any (measurable) function.
This method optimises the loss function over a pre-specified class of functions T.
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 num_samples (int) – Number of samples per iteration.
 num_steps (int) – Number of optimisation steps.
 T (function or torch.nn.Module) – optimisable function T for use in the Donsker-Varadhan loss function.
 optim (pyro.optim.Optim) – Optimiser to use.
 return_history (bool) – If True, also returns a tensor giving the loss function at each step of the optimisation.
 final_design (torch.Tensor) – The final design tensor to evaluate at. If None, uses design.
 final_num_samples (int) – The number of samples to use at the final evaluation. If None, uses num_samples.
Returns: EIG estimate, optionally includes full optimisation history
Return type: torch.Tensor or tuple
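To make the Donsker-Varadhan objective concrete, the sketch below evaluates it by Monte Carlo for a fixed critic \(T\), rather than optimising \(T\) as donsker_varadhan_eig does. It assumes a hypothetical model \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\); samples from the product \(p(y|d)p(\theta)\) are formed by pairing each \(y\) with an independent prior draw \(\bar{\theta}\). For this model the optimal critic is known, \(T^*(y, \theta) = \log p(y | \theta, d) - \log p(y | d)\), and the DV objective is bounded above by the EIG for every critic, with equality at \(T^*\) (up to an additive constant). All names in the code are illustrative.

```python
import math
import random

def normal_logpdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def dv_objective(T, d, num_samples=20000, seed=0):
    """Monte Carlo estimate of E_joint[T] - log E_product[exp(T)]."""
    rng = random.Random(seed)
    joint, product = [], []
    for _ in range(num_samples):
        theta = rng.gauss(0.0, 1.0)
        y = rng.gauss(d * theta, 1.0)
        joint.append(T(y, theta, d))          # (y, theta) ~ p(y, theta | d)
        theta_bar = rng.gauss(0.0, 1.0)       # drawn independently of y
        product.append(T(y, theta_bar, d))    # (y, theta_bar) ~ p(y | d) p(theta)
    top = max(product)
    log_mean_exp = top + math.log(sum(math.exp(v - top) for v in product) / num_samples)
    return sum(joint) / num_samples - log_mean_exp

def optimal_critic(y, theta, d):
    # T*(y, theta) = log p(y | theta, d) - log p(y | d), with p(y | d) = N(0, 1 + d^2)
    return normal_logpdf(y, d * theta, 1.0) - normal_logpdf(y, 0.0, 1.0 + d ** 2)

def scaled_critic(y, theta, d):
    # A deliberately suboptimal critic for comparison
    return 0.5 * optimal_critic(y, theta, d)

dv_opt = dv_objective(optimal_critic, 1.0)
dv_half = dv_objective(scaled_critic, 1.0)
print(dv_opt, dv_half, 0.5 * math.log(2.0))
```

In donsker_varadhan_eig the critic is a trainable function or torch.nn.Module and the same objective is maximised with optim; the closed-form \(T^*\) is only available for toy models like this one.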

posterior_eig(model, design, observation_labels, target_labels, num_samples, num_steps, guide, optim, return_history=False, final_design=None, final_num_samples=None, eig=True, prior_entropy_kwargs={}, *args, **kwargs)
Posterior estimate of the expected information gain (EIG), computed from the average posterior entropy (APE) using EIG = prior entropy - APE. See [1] for full details.
The posterior representation of the APE is
\(APE(d)=\inf_{q}E_{p(y, \theta | d)}[-\log q(\theta | y, d)]\), where \(q\) is any distribution on \(\theta\) (conditional on \(y\) and \(d\)).
This method optimises the loss over the guide family given by guide, which represents \(q\).
[1] Foster, Adam, et al. “Variational Bayesian Optimal Experimental Design.” arXiv preprint arXiv:1903.05480 (2019).
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 num_samples (int) – Number of samples per iteration.
 num_steps (int) – Number of optimisation steps.
 guide (function) – guide family for use in the (implicit) posterior estimation. The parameters of guide are optimised to maximise the posterior objective.
 optim (pyro.optim.Optim) – Optimiser to use.
 return_history (bool) – If True, also returns a tensor giving the loss function at each step of the optimisation.
 final_design (torch.Tensor) – The final design tensor to evaluate at. If None, uses design.
 final_num_samples (int) – The number of samples to use at the final evaluation. If None, uses num_samples.
 eig (bool) – Whether to compute the EIG or the average posterior entropy (APE). The EIG is given by EIG = prior entropy - APE. If True, the prior entropy will be estimated analytically, or by Monte Carlo as appropriate for the model. If False, the APE is returned.
 prior_entropy_kwargs (dict) – parameters for estimating the prior entropy: num_prior_samples indicating the number of samples for a MC estimate of prior entropy, and mean_field indicating if an analytic form for a mean-field prior should be tried.
Returns: EIG estimate, optionally includes full optimisation history
Return type: torch.Tensor or tuple
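A small sketch of the posterior estimator, with a swapped-in fitting step: it assumes a hypothetical model \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\) and a linear-Gaussian guide \(q(\theta | y) = N(wy, s^2)\). Instead of stochastic gradient steps, the guide parameters are set to the values that maximise the sample average of \(\log q(\theta | y)\), which for this family is least-squares regression of \(\theta\) on \(y\). The fitted guide is then scored on fresh samples (in the spirit of final_num_samples) to estimate the APE, and the EIG follows from EIG = prior entropy - APE. The function names are illustrative.

```python
import math
import random

def sample_joint(d, n, rng):
    # (theta, y) ~ p(theta) p(y | theta, d) for the hypothetical toy model
    pairs = []
    for _ in range(n):
        theta = rng.gauss(0.0, 1.0)
        pairs.append((theta, rng.gauss(d * theta, 1.0)))
    return pairs

def posterior_eig_sketch(d, num_samples=5000, final_num_samples=5000, seed=0):
    rng = random.Random(seed)
    # Fit q(theta | y) = N(w * y, s2): the maximum-likelihood w is the
    # least-squares slope (no intercept, since everything has mean zero)
    # and s2 is the mean squared residual.
    train = sample_joint(d, num_samples, rng)
    w = sum(t * y for t, y in train) / sum(y * y for _, y in train)
    s2 = sum((t - w * y) ** 2 for t, y in train) / num_samples
    # Score the fitted guide on fresh samples: APE is estimated by the mean of -log q(theta | y)
    test = sample_joint(d, final_num_samples, rng)
    ape = sum(0.5 * math.log(2 * math.pi * s2) + 0.5 * (t - w * y) ** 2 / s2
              for t, y in test) / final_num_samples
    prior_entropy = 0.5 * math.log(2 * math.pi * math.e)
    return prior_entropy - ape

est = posterior_eig_sketch(1.0)
print(est, 0.5 * math.log(2.0))
```

Because the guide is scored on samples it was not fitted to, \(E[-\log q(\theta | y)]\) is at least the APE, so the result is, in expectation, a lower bound on the EIG.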

marginal_eig(model, design, observation_labels, target_labels, num_samples, num_steps, guide, optim, return_history=False, final_design=None, final_num_samples=None)
Estimates the EIG by estimating the marginal entropy of \(p(y|d)\). See [1] for full details.
The marginal representation of EIG is
\(\inf_{q}E_{p(y, \theta | d)}\left[\log \frac{p(y | \theta, d)}{q(y | d)} \right]\), where \(q\) is any distribution on \(y\).
Warning: this method does not estimate the correct quantity in the presence of random effects.
[1] Foster, Adam, et al. “Variational Bayesian Optimal Experimental Design.” arXiv preprint arXiv:1903.05480 (2019).
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 num_samples (int) – Number of samples per iteration.
 num_steps (int) – Number of optimisation steps.
 guide (function) – guide family for use in the marginal estimation. The parameters of guide are optimised to maximise the log-likelihood objective.
 optim (pyro.optim.Optim) – Optimiser to use.
 return_history (bool) – If True, also returns a tensor giving the loss function at each step of the optimisation.
 final_design (torch.Tensor) – The final design tensor to evaluate at. If None, uses design.
 final_num_samples (int) – The number of samples to use at the final evaluation. If None, uses num_samples.
Returns: EIG estimate, optionally includes full optimisation history
Return type: torch.Tensor or tuple
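The marginal estimator can be sketched in the same way, again without Pyro and with the gradient-based fit swapped for a closed-form one. Assuming a hypothetical model \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\), a Gaussian \(q(y | d) = N(\mu, v)\) is fitted to samples of \(y\) by maximum likelihood, and the EIG is estimated on fresh samples as the average of \(\log p(y | \theta, d) - \log q(y | d)\). This toy model has no random effects, so the warning about random effects does not apply. The function names are illustrative.

```python
import math
import random

def normal_logpdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def marginal_eig_sketch(d, num_samples=5000, final_num_samples=5000, seed=0):
    rng = random.Random(seed)

    def sample_joint(n):
        # (theta, y) ~ p(theta) p(y | theta, d) for the hypothetical toy model
        out = []
        for _ in range(n):
            theta = rng.gauss(0.0, 1.0)
            out.append((theta, rng.gauss(d * theta, 1.0)))
        return out

    # Maximum-likelihood fit of q(y | d) = N(mu, var) to samples of y
    ys = [y for _, y in sample_joint(num_samples)]
    mu = sum(ys) / num_samples
    var = sum((y - mu) ** 2 for y in ys) / num_samples
    # Evaluate the marginal objective on fresh joint samples
    total = 0.0
    for theta, y in sample_joint(final_num_samples):
        total += normal_logpdf(y, d * theta, 1.0) - normal_logpdf(y, mu, var)
    return total / final_num_samples

est = marginal_eig_sketch(1.0)
print(est, 0.5 * math.log(2.0))
```

Since \(q\) is scored on samples it was not fitted to, the estimate is, in expectation, an upper bound on the EIG, consistent with the infimum in the marginal representation.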

lfire_eig(model, design, observation_labels, target_labels, num_y_samples, num_theta_samples, num_steps, classifier, optim, return_history=False, final_design=None, final_num_samples=None)
Estimates the EIG using the method of Likelihood-Free Inference by Ratio Estimation (LFIRE) as in [1]. LFIRE is run separately for several samples of \(\theta\).
[1] Kleinegesse, Steven, and Michael Gutmann. “Efficient Bayesian Experimental Design for Implicit Models.” arXiv preprint arXiv:1810.09912 (2018).
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 num_y_samples (int) – Number of samples to take in \(y\) for each \(\theta\).
 num_steps (int) – Number of optimisation steps.
 classifier (function) – a PyTorch or Pyro classifier used to distinguish between samples of \(y\) under \(p(y|d)\) and samples under \(p(y|\theta,d)\) for some \(\theta\).
 optim (pyro.optim.Optim) – Optimiser to use.
 return_history (bool) – If True, also returns a tensor giving the loss function at each step of the optimisation.
 final_design (torch.Tensor) – The final design tensor to evaluate at. If None, uses design.
 final_num_samples (int) – The number of samples to use at the final evaluation. If None, uses num_samples.
 num_theta_samples (int) – Number of initial samples in \(\theta\) to take. The likelihood ratio is estimated by LFIRE for each sample.
Returns: EIG estimate, optionally includes full optimisation history
Return type: torch.Tensor or tuple
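The LFIRE idea can be sketched in standard-library Python under explicit assumptions. The model is a hypothetical \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\), and a logistic-regression classifier with features \((1, y, y^2)\), fitted by Newton's method, stands in for the user-supplied classifier and optimiser. For each sampled \(\theta\), the classifier separates \(y \sim p(y|\theta,d)\) (label 1) from \(y \sim p(y|d)\) (label 0). With balanced classes its log-odds estimates \(\log p(y|\theta,d) - \log p(y|d)\), which is exactly quadratic in \(y\) for this model. Averaging that log ratio over \(y \sim p(y|\theta,d)\) and then over \(\theta\) estimates the EIG. The helpers solve, fit_logistic and lfire_eig_sketch are illustrative.

```python
import math
import random

def solve(a, b):
    """Solve the small linear system a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [bi] for row, bi in zip(a, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x

def fit_logistic(ys, labels, iters=10):
    """Newton-Raphson logistic regression on features (1, y, y^2)."""
    w = [0.0, 0.0, 0.0]
    for _ in range(iters):
        grad = [0.0] * 3
        hess = [[1e-8 if i == j else 0.0 for j in range(3)] for i in range(3)]  # tiny ridge
        for y, t in zip(ys, labels):
            phi = (1.0, y, y * y)
            p = 1.0 / (1.0 + math.exp(-(w[0] + w[1] * y + w[2] * y * y)))
            for i in range(3):
                grad[i] += (t - p) * phi[i]
                for j in range(3):
                    hess[i][j] += p * (1.0 - p) * phi[i] * phi[j]
        step = solve(hess, grad)
        w = [wi + si for wi, si in zip(w, step)]
    return w

def lfire_eig_sketch(d, num_theta_samples=50, num_y_samples=200, seed=0):
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_theta_samples):
        theta = rng.gauss(0.0, 1.0)
        y_cond = [rng.gauss(d * theta, 1.0) for _ in range(num_y_samples)]  # p(y | theta, d)
        y_marg = [rng.gauss(d * rng.gauss(0.0, 1.0), 1.0)                   # p(y | d)
                  for _ in range(num_y_samples)]
        w = fit_logistic(y_cond + y_marg, [1] * num_y_samples + [0] * num_y_samples)
        # Fitted log-odds, used as the estimate of log p(y | theta, d) - log p(y | d)
        total += sum(w[0] + w[1] * y + w[2] * y * y for y in y_cond) / num_y_samples
    return total / num_theta_samples

estimate = lfire_eig_sketch(1.0)
print(estimate, 0.5 * math.log(2.0))
```

The outer average over \(\theta\) dominates the Monte Carlo error here, so increasing num_theta_samples tightens the estimate more than increasing num_y_samples.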

vnmc_eig(model, design, observation_labels, target_labels, num_samples, num_steps, guide, optim, return_history=False, final_design=None, final_num_samples=None)
Estimates the EIG using Variational Nested Monte Carlo (VNMC). The VNMC estimate [1] is
\[\frac{1}{N}\sum_{n=1}^N \left[ \log p(y_n | \theta_n, d) - \log \left(\frac{1}{M}\sum_{m=1}^M \frac{p(\theta_{mn})p(y_n | \theta_{mn}, d)}{q(\theta_{mn} | y_n)} \right) \right]\]
where \(q(\theta | y)\) is the learned variational posterior approximation and \(\theta_n, y_n \sim p(\theta, y | d)\), \(\theta_{mn} \sim q(\theta|y=y_n)\).
As \(N \to \infty\) this is an upper bound on EIG. We minimise this upper bound by stochastic gradient descent.
[1] Foster, Adam, et al. “Variational Bayesian Optimal Experimental Design.” arXiv preprint arXiv:1903.05480 (2019).
Parameters: model (function) – A Pyro model accepting design as its only argument.
 design (torch.Tensor) – Tensor representation of design
 observation_labels (list) – A subset of the sample sites present in model. These sites are regarded as future observations and other sites are regarded as latent variables over which a posterior is to be inferred.
 target_labels (list) – A subset of the sample sites over which the posterior entropy is to be measured.
 num_samples (tuple) – Number of (\(N, M\)) samples per iteration.
 num_steps (int) – Number of optimisation steps.
 guide (function) – guide family for use in the posterior estimation. The parameters of guide are optimised to minimise the VNMC upper bound.
 optim (pyro.optim.Optim) – Optimiser to use.
 return_history (bool) – If True, also returns a tensor giving the loss function at each step of the optimisation.
 final_design (torch.Tensor) – The final design tensor to evaluate at. If None, uses design.
 final_num_samples (tuple) – The number of (\(N, M\)) samples to use at the final evaluation. If None, uses num_samples.
Returns: EIG estimate, optionally includes full optimisation history
Return type: torch.Tensor or tuple
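The VNMC estimate can be evaluated for a fixed guide with a few lines of standard-library Python; this sketch skips the stochastic-gradient training of the guide that vnmc_eig performs. It assumes a hypothetical model \(\theta \sim N(0, 1)\), \(y | \theta, d \sim N(d\theta, 1)\) and a guide \(q(\theta | y) = N(wy, s^2)\). With the exact posterior (\(w = d/(1+d^2)\), \(s^2 = 1/(1+d^2)\)) every importance weight equals \(p(y_n | d)\), so the estimate becomes an unbiased estimate of the EIG. For finite \(M\), any other guide inflates the expected value, which is why minimising the bound over the guide improves the estimate. The function names are illustrative.

```python
import math
import random

def normal_logpdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def vnmc_estimate(d, w, s2, N=2000, M=10, seed=0):
    """VNMC estimate for the guide q(theta | y) = N(w * y, s2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(N):
        theta_n = rng.gauss(0.0, 1.0)
        y_n = rng.gauss(d * theta_n, 1.0)
        log_weights = []
        for _ in range(M):
            theta_mn = rng.gauss(w * y_n, math.sqrt(s2))  # theta_mn ~ q(theta | y_n)
            log_weights.append(normal_logpdf(theta_mn, 0.0, 1.0)          # log p(theta_mn)
                               + normal_logpdf(y_n, d * theta_mn, 1.0)    # + log p(y_n | theta_mn, d)
                               - normal_logpdf(theta_mn, w * y_n, s2))    # - log q(theta_mn | y_n)
        top = max(log_weights)
        log_mean_w = top + math.log(sum(math.exp(v - top) for v in log_weights) / M)
        total += normal_logpdf(y_n, d * theta_n, 1.0) - log_mean_w
    return total / N

d = 1.0
exact = vnmc_estimate(d, w=d / (1 + d ** 2), s2=1 / (1 + d ** 2))
crude = vnmc_estimate(d, w=0.0, s2=1.0)  # guide = prior, i.e. plain NMC with M inner samples
print(exact, crude, 0.5 * math.log(1 + d ** 2))
```

With the guide set to the prior, the estimator reduces to nested Monte Carlo with \(M\) inner samples, so the gap between crude and exact reflects the upward bias that a better guide removes.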