SVI

class SVI(model, guide, optim, loss, loss_and_grads=None, num_samples=0, num_steps=0, **kwargs)

Bases: pyro.infer.abstract_infer.TracePosterior

Parameters
  • model – the model (callable containing Pyro primitives)

  • guide – the guide (callable containing Pyro primitives)

  • optim (PyroOptim) – a wrapper for a PyTorch optimizer

  • loss (pyro.infer.elbo.ELBO) – an instance of a subclass of ELBO. The most commonly used built-in losses are Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO. See the ELBO docs to learn how to implement a custom loss.

  • num_samples – (DEPRECATED) the number of samples for Monte Carlo posterior approximation

  • num_steps – (DEPRECATED) the number of optimization steps to take in run()

A unified interface for stochastic variational inference in Pyro. The most commonly used loss is loss=Trace_ELBO(). See the tutorial SVI Part I for a discussion.

evaluate_loss(*args, **kwargs)
Returns

estimate of the loss

Return type

float

Evaluate the loss function. Any args or kwargs are passed to the model and guide.

run(*args, **kwargs)

Warning

This method is deprecated, and will be removed in a future release. For inference, use step() directly, and for predictions, use the Predictive class.

step(*args, **kwargs)
Returns

estimate of the loss

Return type

float

Take a gradient step on the loss function (and any auxiliary loss functions generated under the hood by loss_and_grads). Any args or kwargs are passed to the model and guide.

ELBO

class ELBOModule(model: torch.nn.modules.module.Module, guide: torch.nn.modules.module.Module, elbo: pyro.infer.elbo.ELBO)

Bases: torch.nn.modules.module.Module

forward(*args, **kwargs)
training: bool
class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: object

ELBO is the top-level interface for stochastic variational inference via optimization of the evidence lower bound.

Most users will not interact with this base class ELBO directly; instead they will create instances of derived classes: Trace_ELBO, TraceGraph_ELBO, or TraceEnum_ELBO.

Note

Derived classes now provide a more idiomatic PyTorch interface via __call__() for (model, guide) pairs that are Modules. This makes it easier to integrate Pyro’s variational inference tooling with standard PyTorch interfaces such as Optimizers. It also opens up the large ecosystem of libraries that work with these interfaces, such as PyTorch Lightning and the PyTorch JIT:

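import torch
import pyro
import pyro.distributions as dist
import pyro.infer.autoguide
from pyro.nn import PyroModule
pyro.settings.set(module_local_params=True)  # recommended in the note below

# Toy model, included only so that the snippet runs end to end
class Model(PyroModule):
    def forward(self, data):
        loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = torch.randn(100) + 3.0
model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)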

elbo_ = pyro.infer.Trace_ELBO(num_particles=10)

# Fix the model/guide pair
elbo = elbo_(model, guide)

# perform any data-dependent initialization
elbo(data)

optim = torch.optim.Adam(elbo.parameters(), lr=0.001)

for _ in range(100):
    optim.zero_grad()
    loss = elbo(data)
    loss.backward()
    optim.step()

Note that Pyro’s global parameter store may cause this new interface to behave unexpectedly relative to standard PyTorch when working with PyroModules.

Users are therefore strongly encouraged to use this interface in conjunction with pyro.settings.set(module_local_params=True) which will override the default implicit sharing of parameters across PyroModule instances.

Parameters
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.

  • max_plate_nesting (int) – Optional bound on the maximum number of nested pyro.plate() contexts. This is only required when enumerating over sample sites in parallel, e.g. if a site sets infer={"enumerate": "parallel"}. If omitted, ELBO may guess a valid value by running the (model, guide) pair once. However, this guess may be incorrect if the model or guide structure is dynamic.

  • vectorize_particles (bool) – Whether to vectorize the ELBO computation over num_particles. Defaults to False. This requires static structure in model and guide.

  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that pyro.infer.traceenum_elbo.TraceEnum_ELBO is used iff there are enumerated sample sites.

  • ignore_jit_warnings (bool) – Flag to ignore warnings from the JIT tracer. When this is True, all torch.jit.TracerWarning will be ignored. Defaults to False.

  • jit_options (dict) – Optional dict of options to pass to torch.jit.trace(), e.g. {"check_trace": True}.

  • retain_graph (bool) – Whether to retain autograd graph during an SVI step. Defaults to None (False).

  • tail_adaptive_beta (float) – Exponent beta with -1.0 <= beta < 0.0 for use with TraceTailAdaptive_ELBO.

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide. The gradient estimator includes partial Rao-Blackwellization for reducing the variance of the estimator when non-reparameterizable random variables are present. The Rao-Blackwellization is partial in that it only uses conditional independence information that is marked by plate contexts. For more fine-grained Rao-Blackwellization, see TraceGraph_ELBO.

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

loss(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

differentiable_loss(model, guide, *args, **kwargs)

Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters.

loss_and_grads(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles samples are used to form the estimators.

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

Like Trace_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.

  • Models must not depend on any global data (except the param store).

  • All model inputs that are tensors must be passed in via *args.

  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.

loss_and_surrogate_loss(model, guide, *args, **kwargs)
differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)
class TrackNonReparam

Bases: pyro.poutine.messenger.Messenger

Track non-reparameterizable sample sites.

References:

  1. Nonstandard Interpretations of Probabilistic Programs for Efficient Inference, David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

Example:

>>> import torch
>>> import pyro
>>> import pyro.distributions as dist
>>> from pyro.infer.tracegraph_elbo import TrackNonReparam
>>> from pyro.ops.provenance import get_provenance
>>> from pyro.poutine import trace

>>> def model():
...     probs_a = torch.tensor([0.3, 0.7])
...     probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
...     probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
...     a = pyro.sample("a", dist.Categorical(probs_a))
...     b = pyro.sample("b", dist.Categorical(probs_b[a]))
...     pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

>>> with TrackNonReparam():
...     model_tr = trace(model).get_trace()
>>> model_tr.compute_log_prob()

>>> print(get_provenance(model_tr.nodes["a"]["log_prob"]))  
frozenset({'a'})
>>> print(get_provenance(model_tr.nodes["b"]["log_prob"]))  
frozenset({'b', 'a'})
>>> print(get_provenance(model_tr.nodes["c"]["log_prob"]))  
frozenset({'b', 'a'})

class TraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1], specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide, as well as baselines for non-reparameterizable random variables. Fine-grained conditional dependency information, as recorded in the Trace, is used to reduce the variance of the gradient estimator. In particular, provenance tracking [3] is used to find the cost terms that depend on each non-reparameterizable sample site.

References

[1] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[2] Neural Variational Inference and Learning in Belief Networks, Andriy Mnih, Karol Gregor

[3] Nonstandard Interpretations of Probabilistic Programs for Efficient Inference, David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

loss(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles samples are used to form the estimators. If baselines are present, a baseline loss is also constructed and differentiated.

class JitTraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.tracegraph_elbo.TraceGraph_ELBO

Like TraceGraph_ELBO but uses torch.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.

  • Models must not depend on any global data (except the param store).

  • All model inputs that are tensors must be passed in via *args.

  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.

loss_and_grads(model, guide, *args, **kwargs)
class BackwardSampleMessenger(enum_trace, guide_trace)

Bases: pyro.poutine.messenger.Messenger

Implements forward filtering / backward sampling for sampling from the joint posterior distribution.

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI that supports:

  • exhaustive enumeration over discrete sample sites, and

  • local parallel sampling over any sample site in the guide.

To enumerate over a sample site in the guide, mark the site with either infer={'enumerate': 'sequential'} or infer={'enumerate': 'parallel'}. To configure all guide sites at once, use config_enumerate(). To enumerate over a sample site in the model, mark the site infer={'enumerate': 'parallel'} and ensure the site does not appear in the guide.

This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.

loss(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Estimates the ELBO using num_particles many samples (particles).

differentiable_loss(model, guide, *args, **kwargs)
Returns

a differentiable estimate of the ELBO loss (the negative ELBO)

Return type

torch.Tensor

Raises

ValueError – if the ELBO is not differentiable (e.g. is identically zero)

Estimates a differentiable ELBO using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).

loss_and_grads(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Estimates the ELBO using num_particles many samples (particles). Performs backward on the ELBO of each particle.

compute_marginals(model, guide, *args, **kwargs)

Computes marginal distributions at each model-enumerated sample site.

Returns

a dict mapping site name to marginal Distribution object

Return type

OrderedDict

sample_posterior(model, guide, *args, **kwargs)

Sample from the joint posterior distribution of all model-enumerated sites given all observations.

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

Like TraceEnum_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.

  • Models must not depend on any global data (except the param store).

  • All model inputs that are tensors must be passed in via *args.

  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.

differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)
class TraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in Pyro that uses analytic KL divergences when those are available.

In contrast to, e.g., TraceGraph_ELBO and Trace_ELBO, this estimator places restrictions on the dependency structure of the model and guide. In particular, it assumes that the guide has a mean-field structure, i.e. that it factorizes across the different latent variables present in the guide. It also assumes that all of the latent variables in the guide are reparameterized. This latter condition is satisfied for, e.g., the Normal distribution but is not satisfied for, e.g., the Categorical distribution.

Warning

This estimator may give incorrect results if the mean-field condition is not satisfied.

Note for advanced users:

The mean-field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.

loss(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

class JitTraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO

Like TraceMeanField_ELBO but uses pyro.ops.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.

  • Models must not depend on any global data (except the param store).

  • All model inputs that are tensors must be passed in via *args.

  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.

differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)
class TraceTailAdaptive_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

Interface for stochastic variational inference with an adaptive f-divergence, as described in reference [1]. Users should specify num_particles > 1 and vectorize_particles=True. The argument tail_adaptive_beta can be specified to modify how the adaptive f-divergence is constructed; see reference [1] for details.

Note that this interface does not support computing the variational objective itself; it only supports computing gradients of the variational objective. Consequently, one might want to use another loss (e.g. RenyiELBO) to monitor convergence.

Note that this interface only supports models in which all the latent variables are fully reparameterized. It also does not support data subsampling.

References

[1] “Variational Inference with Tail-adaptive f-Divergence”, Dilin Wang, Hao Liu, Qiang Liu, NeurIPS 2018 https://papers.nips.cc/paper/7816-variational-inference-with-tail-adaptive-f-divergence

loss(model, guide, *args, **kwargs)

It is not necessary to estimate the tail-adaptive f-divergence itself in order to compute the corresponding gradients. Consequently, the loss method is left unimplemented.

class RenyiELBO(alpha=0, num_particles=2, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True)

Bases: pyro.infer.elbo.ELBO

An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1].

In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset, \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].

Note

Setting \(\alpha < 1\) gives a better bound than the usual ELBO. For \(\alpha = 1\), it is better to use the Trace_ELBO class, because it yields lower-variance gradient estimates.

Parameters
  • alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.

  • num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.

  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.

  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.

References:

[1] Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner

[2] Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov

loss(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)
Returns

an estimate of the ELBO loss (the negative ELBO)

Return type

float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles samples are used to form the estimators.

class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace-based implementation of Tensor Monte Carlo [1] by way of Tensor Variable Elimination [2] that supports:

  • local parallel sampling over any sample site in the model or guide

  • exhaustive enumeration over any sample site in the model or guide

To take multiple samples, mark the site with infer={'enumerate': 'parallel', 'num_samples': N}. To configure all sites in a model or guide at once, use config_enumerate(). To enumerate or sample a sample site in the model, mark the site and ensure the site does not appear in the guide.

This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.

References

[1] Tensor Monte Carlo: Particle Methods for the GPU Era, Laurence Aitchison (2018)

[2] Tensor Variable Elimination for Plated Factor Graphs, Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander Rush, Noah Goodman (2019)

differentiable_loss(model, guide, *args, **kwargs)
Returns

a differentiable estimate of the negative marginal log-likelihood

Return type

torch.Tensor

Raises

ValueError – if the ELBO is not differentiable (e.g. is identically zero)

Computes a differentiable TMC estimate using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).

loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)

Importance

class Importance(model, guide=None, num_samples=None)

Bases: pyro.infer.abstract_infer.TracePosterior

Parameters
  • model – probabilistic model defined as a function

  • guide – guide used for sampling defined as a function

  • num_samples – number of samples to draw from the guide (default 10)

This method performs posterior inference by importance sampling using the guide as the proposal distribution. If no guide is provided, it defaults to proposing from the model’s prior.

get_ESS()

Compute (Importance Sampling) Effective Sample Size (ESS).

get_log_normalizer()

Estimator of the log normalizing constant of the target distribution, computed as the log of the mean of the unnormalized importance weights.

get_normalized_weights(log_scale=False)

Compute the normalized importance weights.

psis_diagnostic(model, guide, *args, **kwargs)

Computes the Pareto tail index k for a model/guide pair using the technique described in [1], which builds on previous work in [2]. If \(0 < k < 0.5\) the guide is a good approximation to the model posterior, in the sense described in [1]. If \(0.5 \le k \le 0.7\), the guide provides a suboptimal approximation to the posterior, but may still be useful in practice. If \(k > 0.7\) the guide program provides a poor approximation to the full posterior, and caution should be used when using the guide. Note, however, that a guide may be a poor fit to the full posterior while still yielding reasonable model predictions. If \(k < 0.0\) the importance weights corresponding to the model and guide appear to be bounded from above; this would be a bizarre outcome for a guide trained via ELBO maximization. Please see [1] for a more complete discussion of how the tail index k should be interpreted.

Please be advised that a large number of samples may be required for an accurate estimate of k.

Note that we assume that the model and guide are both vectorized and have static structure. As is canonical in Pyro, the args and kwargs are passed to the model and guide.

References

[1] ‘Yes, but Did It Work?: Evaluating Variational Inference.’ Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman

[2] ‘Pareto Smoothed Importance Sampling.’ Aki Vehtari, Andrew Gelman, Jonah Gabry

Parameters
  • model (callable) – the model program.

  • guide (callable) – the guide program.

  • num_particles (int) – the total number of times we run the model and guide in order to compute the diagnostic. Defaults to 1000.

  • max_simultaneous_particles – the maximum number of simultaneous samples drawn from the model and guide. Defaults to num_particles. num_particles must be divisible by max_simultaneous_particles.

  • max_plate_nesting (int) – optional bound on max number of nested pyro.plate() contexts in the model/guide. Defaults to 7.

Returns

the PSIS diagnostic k

Return type

float

vectorized_importance_weights(model, guide, *args, **kwargs)
Parameters
  • model – probabilistic model defined as a function

  • guide – guide used for sampling defined as a function

  • num_samples – number of samples to draw from the guide (default 1)

  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.

  • normalized (bool) – set to True to return self-normalized importance weights

Returns

returns a (num_samples,)-shaped tensor of importance weights and the model and guide traces that produced them

Vectorized computation of importance weights for models with static structure:

log_weights, model_trace, guide_trace = \
    vectorized_importance_weights(model, guide, *args,
                                  num_samples=1000,
                                  max_plate_nesting=4,
                                  normalized=False)

Reweighted Wake-Sleep

class ReweightedWakeSleep(num_particles=2, insomnia=1.0, model_has_params=True, num_sleep_particles=None, vectorize_particles=True, max_plate_nesting=inf, strict_enumeration_warning=True)

Bases: pyro.infer.elbo.ELBO

An implementation of Reweighted Wake Sleep following reference [1].

Note

Sampling and log_prob evaluation asymptotic complexity:

  1. Using wake-theta and/or wake-phi

    O(num_particles) samples from guide, O(num_particles) log_prob evaluations of model and guide

  2. Using sleep-phi

    O(num_sleep_particles) samples from model, O(num_sleep_particles) log_prob evaluations of guide

If 1) and 2) are combined:

O(num_particles) samples from the guide, O(num_sleep_particles) from the model, O(num_particles + num_sleep_particles) log_prob evaluations of the guide, and O(num_particles) evaluations of the model

Note

This is particularly useful for models with stochastic branching, as described in [2].

Note

This returns two losses, one each for (a) the model parameters (theta), computed using the IWAE objective, and (b) the guide parameters (phi), computed using (a combination of) the CSIS objective and a self-normalized importance-sampled version of the CSIS objective.

Note

In order to enable computing the sleep-phi terms, the guide program must have its observations explicitly passed in through the keyword argument observations. Where the value of the observations is unknown during definition, such as for amortized variational inference, it may be given a default argument as observations=None, and the correct value supplied during learning through svi.step(observations=…).

Warning

Mini-batch training is not supported yet.

Parameters
  • num_particles (int) – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.

  • insomnia – The scaling between the wake-phi and sleep-phi terms. Default is 1.0 [wake-phi]

  • model_has_params (bool) – Indicate if model has learnable params. Useful in avoiding extra computation when running in pure sleep mode [csis]. Default is True.

  • num_sleep_particles (int) – The number of particles used to form the sleep-phi estimator. Matches num_particles by default.

  • vectorize_particles (bool) – Whether the traces should be vectorized across num_particles. Default is True.

  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.

  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.

References:

[1] Reweighted Wake-Sleep, Jörg Bornschein, Yoshua Bengio

[2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow, Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood

loss(model, guide, *args, **kwargs)
Returns

the model loss and the guide loss

Return type

float, float

Computes the re-weighted wake-sleep estimators for the model (wake-theta) and the guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi).

loss_and_grads(model, guide, *args, **kwargs)
Returns

the model loss and the guide loss

Return type

float, float

Computes the RWS estimators for the model (wake-theta) and the guide (wake-phi). Performs backward as appropriate on both, using num_particles samples (particles).

Sequential Monte Carlo

exception SMCFailed

Bases: ValueError

Exception raised when SMCFilter fails to find any hypothesis with nonzero probability.

class SMCFilter(model, guide, num_particles, max_plate_nesting, *, ess_threshold=0.5)

Bases: object

SMCFilter is the top-level interface for filtering via sequential Monte Carlo.

The model and guide should be objects with two methods, .init(state, ...) and .step(state, ...), intended to be called first with init(), then with step() repeatedly. These two methods should have the same signatures as SMCFilter's init() and step(), but with an extra first argument state that should be used to store all tensors that depend on sampled variables. The state will be a dict-like object, SMCState, with arbitrary keys and torch.Tensor values. Models can read and write state, but guides can only read from it.

Inference complexity is O(len(state) * num_time_steps), so to avoid quadratic complexity in Markov models, ensure that state has fixed size.

Parameters
  • model (object) – probabilistic model with init and step methods

  • guide (object) – guide used for sampling, with init and step methods

  • num_particles (int) – The number of particles used to form the distribution.

  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.

  • ess_threshold (float) – Effective sample size threshold for deciding when to importance resample: resampling occurs when ess < ess_threshold * num_particles.

get_empirical()
Returns

a marginal distribution over all state tensors.

Return type

a dictionary with keys which are latent variables and values which are Empirical objects.

init(*args, **kwargs)

Perform any initialization for sequential importance resampling. Any args or kwargs are passed to the model and guide.

step(*args, **kwargs)

Take a filtering step using sequential importance resampling, updating the particle weights and values and resampling if desired. Any args or kwargs are passed to the model and guide.

class SMCState(num_particles)

Bases: dict

Dictionary-like object to hold a vectorized collection of tensors to represent all state during inference with SMCFilter. During inference, the SMCFilter resamples these tensors.

Keys may have arbitrary hashable type. Values must be torch.Tensor objects.

Parameters

num_particles (int) – the number of particles.

Stein Methods

class IMQSteinKernel(alpha=0.5, beta=-0.5, bandwidth_factor=None)

Bases: pyro.infer.svgd.SteinKernel

An IMQ (inverse multi-quadratic) kernel for use in the SVGD inference algorithm [1]. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [2]. The kernel takes the form

\(K(x, y) = (\alpha + ||x-y||^2/h)^{\beta}\)

where \(\alpha\) and \(\beta\) are user-specified parameters and \(h\) is the bandwidth.
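The kernel formula can be evaluated directly for a batch of particles. This NumPy sketch takes the bandwidth h as a fixed argument instead of choosing it from the particles, so it only illustrates the formula; imq_kernel_matrix is a hypothetical helper, not Pyro's implementation.

```python
import numpy as np


def imq_kernel_matrix(particles, alpha=0.5, beta=-0.5, h=1.0):
    """K[n, m] = (alpha + ||x_n - x_m||^2 / h) ** beta for particles of shape (N, D)."""
    particles = np.asarray(particles, dtype=float)
    diff = particles[:, None, :] - particles[None, :, :]   # (N, N, D)
    sq_dist = np.sum(diff ** 2, axis=-1)                    # (N, N)
    return (alpha + sq_dist / h) ** beta


K = imq_kernel_matrix([[0.0], [1.0]])
```

On the diagonal the kernel equals alpha ** beta (here sqrt(2)), and with beta < 0 it decays polynomially, not exponentially, as particles move apart.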

Parameters
  • alpha (float) – Kernel hyperparameter, defaults to 0.5.

  • beta (float) – Kernel hyperparameter, defaults to -0.5.

  • bandwidth_factor (float) – Optional factor by which to scale the bandwidth, defaults to 1.0.

Variables

bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.

References

[1] “Stein Points,” Wilson Ye Chen, Lester Mackey, Jackson Gorham, Francois-Xavier Briol, Chris J. Oates.

[2] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang.

property bandwidth_factor
log_kernel_and_grad(particles)[source]

See pyro.infer.svgd.SteinKernel.log_kernel_and_grad()

class RBFSteinKernel(bandwidth_factor=None)[source]

Bases: pyro.infer.svgd.SteinKernel

An RBF kernel for use in the SVGD inference algorithm. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [1].
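One common bandwidth heuristic is the median heuristic from the SVGD paper: set h to the median squared pairwise distance between the N particles divided by log N, and use K(x, y) = exp(-||x - y||^2 / h). The NumPy sketch below follows that recipe and multiplies h by bandwidth_factor; the helper names are hypothetical, and RBFSteinKernel's exact heuristic may differ.

```python
import numpy as np


def pairwise_sq_dists(particles):
    """Squared Euclidean distances between all pairs of rows, shape (N, N)."""
    diff = particles[:, None, :] - particles[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def median_heuristic_bandwidth(particles, bandwidth_factor=1.0):
    """h = bandwidth_factor * median squared pairwise distance / log(N); needs N >= 2."""
    particles = np.asarray(particles, dtype=float)
    n = particles.shape[0]
    sq = pairwise_sq_dists(particles)
    distinct = sq[np.triu_indices(n, k=1)]  # each unordered pair counted once
    return bandwidth_factor * np.median(distinct) / np.log(n)


def rbf_kernel_matrix(particles, bandwidth_factor=1.0):
    """RBF kernel matrix with a data-dependent bandwidth; returns (K, h)."""
    particles = np.asarray(particles, dtype=float)
    h = median_heuristic_bandwidth(particles, bandwidth_factor)
    return np.exp(-pairwise_sq_dists(particles) / h), h


K, h = rbf_kernel_matrix([[0.0], [1.0], [3.0]])
```

Here the squared pairwise distances are 1, 9 and 4, so h = 4 / log 3; a larger bandwidth_factor widens the kernel and strengthens the interaction between distant particles.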

Parameters

bandwidth_factor (float) – Optional factor by which to scale the bandwidth, defaults to 1.0.

Variables

bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.

References

[1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang.

property bandwidth_factor
log_kernel_and_grad(particles)[source]

See pyro.infer.svgd.SteinKernel.log_kernel_and_grad()

class SVGD(model, kernel, optim, num_particles, max_plate_nesting, mode='univariate')[source]

Bases: object

A basic implementation of Stein Variational Gradient Descent as described in reference [1].

Parameters
  • model – The model (callable containing Pyro primitives). Model must be fully vectorized and may only contain continuous latent variables.

  • kernel – an SVGD-compatible kernel such as RBFSteinKernel.

  • optim (pyro.optim.PyroOptim) – A wrapper for a PyTorch optimizer.

  • num_particles (int) – The number of particles used in SVGD.

  • max_plate_nesting (int) – The max number of nested pyro.plate() contexts in the model.

  • mode (str) – Whether to use a Kernelized Stein Discrepancy that makes use of multivariate test functions (as in [1]) or univariate test functions (as in [2]). Defaults to univariate.

Example usage:

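import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVGD, RBFSteinKernel
from pyro.optim import Adam

def model(data):
    # Hypothetical toy model: one continuous latent variable, no plates.
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

kernel = RBFSteinKernel()
adam = Adam({"lr": 0.1})
svgd = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

data = torch.tensor(3.)
for step in range(500):
    svgd.step(data)

final_particles = svgd.get_named_particles()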
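For intuition about the update that step() performs, here is a minimal NumPy sketch of the basic SVGD rule from [1] (the multivariate form), applied to 1-D particles and a standard normal target. It assumes a fixed bandwidth h = 1 and plain gradient ascent with a constant step size; Pyro's SVGD instead takes the bandwidth from its kernel, applies the supplied PyroOptim, and also supports mode='univariate'. svgd_step is a hypothetical helper.

```python
import numpy as np


def svgd_step(x, grad_log_p, h=1.0, step_size=0.1):
    """One SVGD update for 1-D particles with k(a, b) = exp(-(a - b)^2 / h)."""
    diff = x[:, None] - x[None, :]          # diff[i, j] = x_i - x_j
    k = np.exp(-(diff ** 2) / h)            # symmetric kernel matrix
    # Driving term: sum_j k(x_j, x_i) * grad log p(x_j) pulls particles to high density.
    drive = k @ grad_log_p(x)
    # Repulsive term: sum_j d/dx_j k(x_j, x_i) = sum_j 2 (x_i - x_j) / h * k.
    repulse = np.sum(2.0 * diff / h * k, axis=1)
    phi = (drive + repulse) / x.shape[0]
    return x + step_size * phi


# Target: standard normal, so grad log p(x) = -x.
particles = np.linspace(2.0, 4.0, 50)
for _ in range(2000):
    particles = svgd_step(particles, lambda z: -z)
```

The driving term alone would collapse every particle onto the mode; the repulsive term keeps them spread out, so the final particles roughly cover the target rather than sitting at 0.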

References

[1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang.

[2] “Kernelized Complete Conditional Stein Discrepancy,” Raghav Singhal, Saad Lahlou, Rajesh Ranganath.

get_named_particles()[source]

Create a dictionary mapping name to vectorized value, of the form {name: tensor}. The leading dimension of each tensor corresponds to particles, i.e. this creates a struct of arrays.

step(*args, **kwargs)[source]

Computes the SVGD gradient, passing args and kwargs to the model, and takes a gradient step.

Returns

A dictionary of the form {name: float}, where each float is a mean squared gradient. This can be used to monitor the convergence of SVGD.

Return type

dict

class SteinKernel[source]

Bases: object

Abstract class for kernels used in the SVGD inference algorithm.

abstract log_kernel_and_grad(particles)[source]

Compute the component kernels and their gradients.

Parameters

particles – a tensor with shape (N, D)

Returns

A pair (log_kernel, kernel_grad) where log_kernel is an (N, N, D)-shaped tensor equal to the logarithm of the kernel and kernel_grad is an (N, N, D)-shaped tensor where the entry (n, m, d) represents the derivative of log_kernel w.r.t. x_{m,d}, where x_{m,d} is the d^th dimension of particle m.
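To make this shape contract concrete, the sketch below computes a per-dimension RBF log kernel and the gradient described above, taking the docstring's definition literally: log_kernel[n, m, d] = -(x[m, d] - x[n, d])^2 / h[d], and kernel_grad[n, m, d] is its derivative with respect to x[m, d]. It is a NumPy illustration with a caller-supplied bandwidth h, not the code of any Pyro kernel, whose sign and bandwidth conventions may differ; rbf_log_kernel_and_grad is a hypothetical name.

```python
import numpy as np


def rbf_log_kernel_and_grad(particles, h):
    """Per-dimension RBF log kernel and its gradient, both of shape (N, N, D).

    log_kernel[n, m, d]  = -(x[m, d] - x[n, d]) ** 2 / h[d]
    kernel_grad[n, m, d] = d log_kernel[n, m, d] / d x[m, d]
    """
    particles = np.asarray(particles, dtype=float)                         # (N, D)
    h = np.broadcast_to(np.asarray(h, dtype=float), particles.shape[-1:])  # (D,)
    delta = particles[None, :, :] - particles[:, None, :]  # delta[n, m] = x[m] - x[n]
    log_kernel = -(delta ** 2) / h
    kernel_grad = -2.0 * delta / h
    return log_kernel, kernel_grad


log_kernel, kernel_grad = rbf_log_kernel_and_grad([[0.0, 0.0], [1.0, 2.0]], h=1.0)
```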

vectorize(fn, num_particles, max_plate_nesting)[source]

Likelihood free methods

class EnergyDistance(beta=1.0, prior_scale=0.0, num_particles=2, max_plate_nesting=inf)[source]

Bases: object

Posterior predictive energy distance [1,2] with optional Bayesian regularization by the prior.

Let p(x,z)=p(z) p(x|z) be the model, q(z|x) be the guide. Then given data x and drawing an iid pair of samples \((Z,X)\) and \((Z',X')\) (where Z is latent and X is the posterior predictive),

\[\begin{split}& Z \sim q(z|x); \quad X \sim p(x|Z) \\ & Z' \sim q(z|x); \quad X' \sim p(x|Z') \\ & loss = \mathbb E_X \|X-x\|^\beta - \frac 1 2 \mathbb E_{X,X'}\|X-X'\|^\beta - \lambda \mathbb E_Z \log p(Z)\end{split}\]

Here \(\lambda\) is the prior_scale parameter. This is a likelihood-free inference algorithm and can be used for likelihoods without tractable density functions. The \(\beta\) energy distance is a robust loss function and is well defined for any distribution with a finite fractional moment \(\mathbb E[\|X\|^\beta]\).

This requires static model structure, a fully reparametrized guide, and reparametrized likelihood distributions in the model. Model latent distributions may be non-reparametrized.

References

[1] Gabor J. Szekely, Maria L. Rizzo (2003). Energy Statistics: A Class of Statistics Based on Distances.

[2] Tilmann Gneiting, Adrian E. Raftery (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

Parameters
  • beta (float) – Exponent \(\beta\) from [1,2]. The loss function is strictly proper for distributions with finite \(\beta\)-absolute moment \(E[\|X\|^\beta]\). Thus for heavy-tailed distributions beta should be small; for example, for Cauchy distributions \(\beta<1\) is strictly proper. Defaults to 1. Must be in the open interval (0,2).

  • prior_scale (float) – Nonnegative scale for prior regularization. Model parameters are trained only if this is positive. If zero (default), then model log densities will not be computed (guide log densities are never computed).

  • num_particles (int) – The number of particles/samples used to form the gradient estimators. Must be at least 2.

  • max_plate_nesting (int) – Optional bound on max number of nested pyro.plate() contexts. If omitted, this will guess a valid value by running the (model,guide) pair once.
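The terms of the loss can be estimated from posterior predictive samples. The NumPy sketch below averages over num_particles ≥ 2 draws, uses distinct pairs (i ≠ j) for the \(\mathbb E_{X,X'}\) term, and optionally subtracts prior_scale times an average log prior. It works on precomputed samples rather than a (model, guide) pair, so energy_distance_loss is a hypothetical helper, not EnergyDistance itself.

```python
import numpy as np


def energy_distance_loss(samples, x, beta=1.0, log_prior=None, prior_scale=0.0):
    """Estimate E||X - x||^beta - 1/2 E||X - X'||^beta - prior_scale * E log p(Z).

    samples: posterior predictive draws X_1..X_P (P >= 2), one row per particle.
    x: the observed data, flattened to a vector.
    """
    if not 0.0 < beta < 2.0:
        raise ValueError("beta must be in the open interval (0, 2)")
    samples = np.asarray(samples, dtype=float)
    samples = samples.reshape(samples.shape[0], -1)          # (P, D)
    x = np.asarray(x, dtype=float).reshape(-1)               # (D,)
    num_particles = samples.shape[0]
    if num_particles < 2:
        raise ValueError("need at least 2 particles")
    # E_X ||X - x||^beta
    term_data = np.mean(np.linalg.norm(samples - x, axis=-1) ** beta)
    # E_{X,X'} ||X - X'||^beta over distinct pairs; the zero diagonal adds nothing.
    pair = np.linalg.norm(samples[:, None, :] - samples[None, :, :], axis=-1) ** beta
    term_spread = pair.sum() / (num_particles * (num_particles - 1))
    loss = term_data - 0.5 * term_spread
    if prior_scale > 0 and log_prior is not None:
        loss -= prior_scale * np.mean(log_prior)
    return loss


spread = energy_distance_loss([[0.0], [2.0]], [1.0])      # predictive spread around x
collapsed = energy_distance_loss([[0.0], [0.0]], [1.0])   # predictive collapsed away from x
```

The -1/2 E||X - X'||^β term rewards spread, so the collapsed predictive scores worse (1.0) than the one that brackets the observation (0.0).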

__call__(model, guide, *args, **kwargs)[source]

Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters.

loss(*args, **kwargs)[source]

Not implemented. Added for compatibility with unit tests only.

Discrete Inference

infer_discrete(fn=None, first_available_dim=None, temperature=1, *, strict_enumeration_warning=True)[source]

A poutine that samples discrete sites marked with site["infer"]["enumerate"] = "parallel" from the posterior, conditioned on observations.

Example:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

@infer_discrete(first_available_dim=-1, temperature=0)
@config_enumerate
def viterbi_decoder(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[states[-1]])))
        pyro.sample("obs_{}".format(t),
                    dist.Normal(means[states[-1]], 1.),
                    obs=data[t])
    return states  # MAP (Viterbi) states, since temperature=0

Parameters
  • fn – a stochastic function (callable containing Pyro primitive calls)

  • first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by Pyro. This should be a negative integer.

  • temperature (int) – Either 1 (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).

  • strict_enumeration_warning (bool) – Whether to warn in case no enumerated sample sites are found. Defaults to True.
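With temperature=0, infer_discrete returns a MAP assignment of the enumerated states. For the toy HMM in viterbi_decoder that is the classic Viterbi recursion, which the NumPy sketch below writes out by hand. It illustrates what is computed, not how infer_discrete is implemented; viterbi is a hypothetical helper and, unlike viterbi_decoder, it omits the fixed initial state 0 from its output.

```python
import numpy as np


def viterbi(data, hidden_dim=10, init_state=0):
    """MAP state sequence for the toy HMM above (unit-variance Normal emissions)."""
    data = np.asarray(data, dtype=float)
    transition = 0.3 / hidden_dim + 0.7 * np.eye(hidden_dim)
    log_trans = np.log(transition)
    means = np.arange(float(hidden_dim))
    # Emission log-likelihoods up to an additive constant, shape (T, hidden_dim).
    log_emit = -0.5 * (data[:, None] - means[None, :]) ** 2
    num_steps = len(data)
    # delta[j]: best log score of any path that ends in state j at the current step.
    delta = log_trans[init_state] + log_emit[0]
    backptr = np.zeros((num_steps, hidden_dim), dtype=int)
    for t in range(1, num_steps):
        scores = delta[:, None] + log_trans     # scores[i, j]: best path ending in i, then i -> j
        backptr[t] = np.argmax(scores, axis=0)
        delta = scores.max(axis=0) + log_emit[t]
    # Trace the best final state back to the first step.
    path = [int(np.argmax(delta))]
    for t in range(num_steps - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1]


states = viterbi([0.0, 0.0, 0.0, 2.0, 2.0, 2.0], hidden_dim=3)
```

The sticky transition matrix (0.8 on the diagonal for hidden_dim=3) means the decoder only switches state once the evidence outweighs the transition penalty.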

class TraceEnumSample_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

This extends TraceEnum_ELBO to make it cheaper to sample from discrete latent states during SVI.

The following are equivalent but the first is cheaper, sharing work between the computations of loss and z:

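from pyro import poutine
from pyro.infer import TraceEnum_ELBO, TraceEnumSample_ELBO, infer_discrete

# model, guide, args and kwargs are user-defined; ELBO.loss takes the
# model and guide as its first two arguments.

# Version 1.
elbo = TraceEnumSample_ELBO(max_plate_nesting=1)
loss = elbo.loss(model, guide, *args, **kwargs)
z = elbo.sample_saved()

# Version 2.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(model, guide, *args, **kwargs)
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
z = infer_discrete(poutine.replay(model, guide_trace),
                   first_available_dim=-2)(*args, **kwargs)

sample_saved()[source]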

Generate latent samples while reusing work from SVI.step().

Prediction utilities

class Predictive(model, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)[source]

Bases: torch.nn.modules.module.Module

EXPERIMENTAL class used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples. If a guide is provided, then posterior samples from all the latent sites are also returned.

Warning

The interface for the Predictive class is experimental, and might change in the future.

Parameters
  • model – Python callable containing Pyro primitives.

  • posterior_samples (dict) – dictionary of samples from the posterior.

  • guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.

  • num_samples (int) – number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty, in which case, the leading dimension size of samples in posterior_samples is used.

  • return_sites (list, tuple, or set) – sites to return; by default only sample sites not present in posterior_samples are returned.

  • parallel (bool) – predict in parallel by wrapping the existing model in an outermost plate messenger. Note that this requires that the model has all batch dims correctly annotated via plate. Default is False.
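Conceptually, the predictive distribution replays the model once per posterior sample, so the leading dimension of posterior_samples fixes how many draws are returned. The NumPy sketch below does this by hand for a hypothetical model obs ~ Normal(loc, 1) with num_obs observations; predictive_normal is illustrative only and not part of Predictive's API.

```python
import numpy as np


def predictive_normal(posterior_samples, rng, num_obs):
    """Hypothetical helper: simulate obs ~ Normal(loc, 1) once per posterior sample.

    Only the new site "obs" is returned, mirroring Predictive's default of
    returning sites that are not already in posterior_samples.
    """
    loc = np.asarray(posterior_samples["loc"], dtype=float)   # (num_samples,)
    num_samples = loc.shape[0]   # the leading dim decides how many draws are made
    obs = rng.normal(loc[:, None], 1.0, size=(num_samples, num_obs))
    return {"obs": obs}


posterior_samples = {"loc": np.array([0.0, 5.0, -5.0])}
draws = predictive_normal(posterior_samples, np.random.default_rng(0), num_obs=4)
```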

call(*args, **kwargs)[source]

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once PyTorch JIT tracer starts accepting dict as valid return types. See issue.

forward(*args, **kwargs)[source]

Returns dict of samples from the predictive distribution. By default, only sample sites not contained in posterior_samples are returned. This can be modified by changing the return_sites keyword argument of this Predictive instance.

Note

This method is used internally by Module. Users should instead use __call__() as in Predictive(model)(*args, **kwargs).

Parameters
  • args – model arguments.

  • kwargs – model keyword arguments.

get_samples(*args, **kwargs)[source]
get_vectorized_trace(*args, **kwargs)[source]

Returns a single vectorized trace from the predictive distribution. Note that this requires that the model has all batch dims correctly annotated via plate.

Parameters
  • args – model arguments.

  • kwargs – model keyword arguments.

training: bool
class EmpiricalMarginal(trace_posterior, sites=None, validate_args=None)[source]

Bases: pyro.distributions.distribution.Distribution, Callable

Marginal distribution over a single site (or multiple, provided they have the same shape) from the TracePosterior’s model.

Note

If multiple sites are specified, they must have the same tensor shape. Samples from each site will be stacked and stored within a single tensor. See Empirical. To hold the marginal distribution of sites having different shapes, use Marginals instead.

Parameters
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.

  • sites (list) – optional list of sites for which we need to generate the marginal distribution.
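An empirical marginal can be thought of as a set of samples with log weights, in the spirit of Pyro's Empirical distribution. The NumPy sketch below stacks two same-shaped sites along a new last axis (that axis is an assumption made for illustration) and computes weighted moments; empirical_moments is a hypothetical helper.

```python
import numpy as np


def empirical_moments(values, log_weights):
    """Weighted mean and variance along the sample axis (axis 0)."""
    values = np.asarray(values, dtype=float)
    log_weights = np.asarray(log_weights, dtype=float)
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    mean = np.tensordot(w, values, axes=1)
    var = np.tensordot(w, (values - mean) ** 2, axes=1)
    return mean, var


# Two scalar sites with the same shape, one value per trace, stacked together.
site_a = np.array([0.0, 10.0])
site_b = np.array([1.0, 3.0])
stacked = np.stack([site_a, site_b], axis=-1)   # shape (num_traces, 2)
mean, var = empirical_moments(stacked, np.log([3.0, 1.0]))
```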

class Marginals(trace_posterior, sites=None, validate_args=None)[source]

Bases: object

Holds the marginal distribution over one or more sites from the TracePosterior’s model. This is a convenience container class, which can be extended by TracePosterior subclasses, e.g. for implementing diagnostics.

Parameters
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.

  • sites (list) – optional list of sites for which we need to generate the marginal distribution.

property empirical

A dictionary mapping site names to their corresponding EmpiricalMarginal distributions.

Type

OrderedDict

support(flatten=False)[source]

Gets support of this marginal distribution.

Parameters

flatten (bool) – A flag to decide if we want to flatten batch_shape when the marginal distribution is collected from the posterior with num_chains > 1. Defaults to False.

Returns

a dict whose keys are site names and whose values are the sites’ supports.

Return type

OrderedDict
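The flatten flag merges the chain and sample dimensions. Assuming each site's support is laid out as (num_chains, num_samples, ...), which is an assumption for this sketch, flattening is a reshape that concatenates the chains in order; flatten_chains is a hypothetical helper, not part of Marginals.

```python
import numpy as np


def flatten_chains(support_by_site):
    """Merge the leading (num_chains, num_samples) dims of every site."""
    return {
        name: value.reshape((-1,) + value.shape[2:])
        for name, value in support_by_site.items()
    }


support = {"loc": np.arange(30.0).reshape(2, 5, 3)}  # 2 chains, 5 samples, event dim 3
flat = flatten_chains(support)
```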

class TracePosterior(num_chains=1)[source]

Bases: object

Abstract TracePosterior object from which posterior inference algorithms inherit. When run, collects a bag of execution traces from the approximate posterior. This is designed to be used by other utility classes like EmpiricalMarginal, that need access to the collected execution traces.

information_criterion(pointwise=False)[source]

Computes information criterion of the model. Currently, returns only “Widely Applicable/Watanabe-Akaike Information Criterion” (WAIC) and the corresponding effective number of parameters.

Reference:

[1] Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC, Aki Vehtari, Andrew Gelman, and Jonah Gabry

Parameters

pointwise (bool) – a flag to decide if we want to get a vectorized WAIC or not. When pointwise=False, returns the sum.

Returns

a dictionary containing values of WAIC and its effective number of parameters.

Return type

OrderedDict
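WAIC can be computed from a matrix of pointwise log-likelihoods, one row per posterior sample and one column per observation, following the definitions in [1]: lppd sums the log of the average likelihood per observation, and the effective number of parameters p_waic sums the per-observation variances of the log-likelihood. The NumPy sketch below reports WAIC on the deviance scale, -2 (lppd - p_waic); that scaling and the helper name waic are assumptions, and information_criterion may use a different convention.

```python
import numpy as np


def waic(log_lik, pointwise=False):
    """Hypothetical helper: WAIC from log-likelihoods of shape (num_samples, num_obs)."""
    log_lik = np.asarray(log_lik, dtype=float)
    num_samples = log_lik.shape[0]
    # Log of the posterior-mean likelihood for each observation (stable log-mean-exp).
    m = log_lik.max(axis=0)
    lpd = m + np.log(np.exp(log_lik - m).sum(axis=0)) - np.log(num_samples)
    # Effective number of parameters: posterior variance of each log-likelihood.
    p_waic = log_lik.var(axis=0, ddof=1)
    result = {"waic": -2.0 * (lpd - p_waic), "p_waic": p_waic}
    if not pointwise:
        result = {key: value.sum() for key, value in result.items()}
    return result


# Identical rows mean zero posterior variance, so p_waic is 0.
summary = waic([[-1.0, -2.0], [-1.0, -2.0]])
per_obs = waic([[-1.0, -2.0], [-1.0, -2.0]], pointwise=True)
```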

marginal(sites=None)[source]

Generates the marginal distribution of this posterior.

Parameters

sites (list) – optional list of sites for which we need to generate the marginal distribution.

Returns

A Marginals class instance.

Return type

Marginals

run(*args, **kwargs)[source]

Calls self._traces to populate execution traces from a stochastic Pyro model.

Parameters
  • args – optional args taken by self._traces.

  • kwargs – optional keyword args taken by self._traces.

class TracePredictive(model, posterior, num_samples, keep_sites=None)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

Warning

This class is deprecated and will be removed in a future release. Use the Predictive class instead.

Generates and holds traces from the posterior predictive distribution, given model execution traces from the approximate posterior. This is achieved by constraining latent sites to randomly sampled parameter values from the model execution traces and running the model forward to generate traces with new response (“_RETURN”) sites.

Parameters
  • model – arbitrary Python callable containing Pyro primitives.

  • posterior (TracePosterior) – trace posterior instance holding samples from the model’s approximate posterior.

  • num_samples (int) – number of samples to generate.

  • keep_sites – the sites which should be sampled from the posterior distribution (default: all).

marginal(sites=None)[source]

Gets marginal distribution for this predictive posterior distribution.