SVI

class SVI(model, guide, optim, loss, loss_and_grads=None, num_samples=0, num_steps=0, **kwargs)

Bases: pyro.infer.abstract_infer.TracePosterior

Parameters:
  • model – the model (callable containing Pyro primitives)
  • guide – the guide (callable containing Pyro primitives)
  • optim (PyroOptim) – a wrapper for a PyTorch optimizer
  • loss (pyro.infer.elbo.ELBO) – an instance of a subclass of ELBO. Pyro provides three built-in losses: Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO. See the ELBO docs to learn how to implement a custom loss.
  • num_samples – (DEPRECATED) the number of samples for Monte Carlo posterior approximation
  • num_steps – (DEPRECATED) the number of optimization steps to take in run()

A unified interface for stochastic variational inference in Pyro. The most commonly used loss is loss=Trace_ELBO(). See the tutorial SVI Part I for a discussion.

evaluate_loss(*args, **kwargs)
Returns: estimate of the loss
Return type: float

Evaluate the loss function. Any args or kwargs are passed to the model and guide.

run(*args, **kwargs)

Warning

This method is deprecated, and will be removed in a future release. For inference, use step() directly, and for predictions, use the Predictive class.

step(*args, **kwargs)
Returns: estimate of the loss
Return type: float

Take a gradient step on the loss function (and any auxiliary loss functions generated under the hood by loss_and_grads). Any args or kwargs are passed to the model and guide.
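To make the step() loop concrete, here is a minimal, dependency-free sketch of stochastic variational inference on a toy conjugate model. It assumes z ~ Normal(0, 1), x ~ Normal(z, 1) with observed x = 2.0, and a guide Normal(loc, exp(log_scale)). It does not call Pyro. The helper names (elbo_grad, svi_loop) are hypothetical, the reparameterized ELBO gradients are derived by hand, and plain gradient ascent stands in for a PyroOptim optimizer. Each loop iteration plays the role of one svi.step() call.

```python
import math
import random


def elbo_grad(x, loc, log_scale, num_particles, rng):
    """Toy sketch (not Pyro's implementation): reparameterized Monte Carlo
    gradient of the ELBO for z ~ Normal(0, 1), x ~ Normal(z, 1) and
    guide q(z) = Normal(loc, exp(log_scale))."""
    scale = math.exp(log_scale)
    g_loc = 0.0
    g_log_scale = 0.0
    for _ in range(num_particles):
        eps = rng.gauss(0.0, 1.0)
        z = loc + scale * eps          # reparameterized sample from the guide
        dlogp_dz = x - 2.0 * z         # d/dz [log p(z) + log p(x | z)]
        g_loc += dlogp_dz
        # chain rule through z = loc + exp(log_scale) * eps, plus the gradient
        # of -log q(z) w.r.t. log_scale, which is exactly 1 for fixed eps
        g_log_scale += dlogp_dz * eps * scale + 1.0
    return g_loc / num_particles, g_log_scale / num_particles


def svi_loop(x, num_steps=2000, lr=0.05, num_particles=10, seed=0):
    """Hypothetical driver: each iteration mimics one svi.step() call."""
    rng = random.Random(seed)
    loc, log_scale = 0.0, 0.0
    for _ in range(num_steps):
        g_loc, g_log_scale = elbo_grad(x, loc, log_scale, num_particles, rng)
        # ascent on the ELBO is the same as descent on loss = -ELBO
        loc += lr * g_loc
        log_scale += lr * g_log_scale
    return loc, math.exp(log_scale)


loc, scale = svi_loop(x=2.0)
print(loc, scale)
```

The model is conjugate, so the exact posterior Normal(1.0, sqrt(0.5)) is known. That makes the toy a convenient check that the noisy gradient steps settle near the right answer.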

ELBO

class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: object

ELBO is the top-level interface for stochastic variational inference via optimization of the evidence lower bound.

Most users will not interact with this base class ELBO directly; instead they will create instances of derived classes: Trace_ELBO, TraceGraph_ELBO, or TraceEnum_ELBO.

Parameters:
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
  • max_plate_nesting (int) – Optional bound on the maximum number of nested pyro.plate() contexts. This is only required when enumerating over sample sites in parallel, e.g. if a site sets infer={"enumerate": "parallel"}. If omitted, ELBO may guess a valid value by running the (model, guide) pair once; however, this guess may be incorrect if the model or guide structure is dynamic.
  • vectorize_particles (bool) – Whether to vectorize the ELBO computation over num_particles. Defaults to False. This requires static structure in model and guide.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that pyro.infer.traceenum_elbo.TraceEnum_ELBO is used iff there are enumerated sample sites.
  • ignore_jit_warnings (bool) – Flag to ignore warnings from the JIT tracer. When this is True, all torch.jit.TracerWarning will be ignored. Defaults to False.
  • jit_options (dict) – Optional dict of options to pass to torch.jit.trace(), e.g. {"check_trace": True}.
  • retain_graph (bool) – Whether to retain autograd graph during an SVI step. Defaults to None (False).
  • tail_adaptive_beta (float) – Exponent beta with -1.0 <= beta < 0.0 for use with TraceTailAdaptive_ELBO.

References

[1] Automated Variational Inference in Probabilistic Programming,
David Wingate, Theo Weber
[2] Black Box Variational Inference,
Rajesh Ranganath, Sean Gerrish, David M. Blei

class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide. The gradient estimator includes partial Rao-Blackwellization for reducing the variance of the estimator when non-reparameterizable random variables are present. The Rao-Blackwellization is partial in that it only uses conditional independence information that is marked by plate contexts. For more fine-grained Rao-Blackwellization, see TraceGraph_ELBO.
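When a guide site is not reparameterizable (for example a Bernoulli), its gradient contribution has to come from the score-function (REINFORCE) identity ∇θ E_q[f(z)] = E_q[f(z) ∇θ log q(z)]. Below is a dependency-free sketch of that estimator family. It assumes a Bernoulli(theta) site and a hypothetical downstream cost f. It is not Trace_ELBO's surrogate-loss code, and it does no Rao-Blackwellization.

```python
import random


def score_function_grad(f, theta, num_particles, rng):
    """Estimate d/dtheta E_{z ~ Bernoulli(theta)}[f(z)] with the score-function identity."""
    total = 0.0
    for _ in range(num_particles):
        z = 1 if rng.random() < theta else 0
        # d/dtheta log q(z) for q = Bernoulli(theta)
        score = 1.0 / theta if z == 1 else -1.0 / (1.0 - theta)
        total += f(z) * score
    return total / num_particles


f = lambda z: 3.0 * z + 1.0   # hypothetical downstream cost
# E[f] = theta * f(1) + (1 - theta) * f(0), so the exact gradient is f(1) - f(0)
exact = f(1) - f(0)
estimate = score_function_grad(f, theta=0.3, num_particles=20000, rng=random.Random(0))
print(exact, estimate)
```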

References

[1] Automated Variational Inference in Probabilistic Programming,
David Wingate, Theo Weber
[2] Black Box Variational Inference,
Rajesh Ranganath, Sean Gerrish, David M. Blei

loss(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
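Here is a plain-Python sketch of estimating the ELBO from num_particles samples. It assumes a toy model z ~ Normal(0, 1), x ~ Normal(z, 1) and a Normal(loc, scale) guide. The helpers elbo_estimate and normal_log_prob are hypothetical and are not Pyro functions. When the guide equals the exact posterior, every particle's log p(x, z) - log q(z) equals log p(x), so the estimate is exact for any num_particles. That property makes a handy sanity check.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def elbo_estimate(x, loc, scale, num_particles, rng):
    """Monte Carlo ELBO for z ~ Normal(0, 1), x ~ Normal(z, 1), guide q(z) = Normal(loc, scale)."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(loc, scale)   # one particle drawn from the guide
        log_p = normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)
        log_q = normal_log_prob(z, loc, scale)
        total += log_p - log_q
    return total / num_particles


x = 2.0
rng = random.Random(0)
# exact posterior for x = 2 is Normal(1, sqrt(0.5)); then ELBO = log p(x) = log Normal(x; 0, sqrt(2))
exact = normal_log_prob(x, 0.0, math.sqrt(2.0))
estimate = elbo_estimate(x, 1.0, math.sqrt(0.5), num_particles=1000, rng=rng)
# a mismatched guide gives a strictly smaller ELBO (the gap is KL(q || posterior))
worse = elbo_estimate(x, 0.0, 1.0, num_particles=1000, rng=rng)
print(estimate, exact, worse)
```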

differentiable_loss(model, guide, *args, **kwargs)

Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters.

loss_and_grads(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles samples.

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

Like Trace_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
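One way to picture the last restriction is a cache keyed on the non-tensor keyword arguments. The first call with a given **kwargs combination triggers compilation, and later calls with the same combination reuse the compiled version. The sketch below is a toy stand-in built on a hypothetical compile_fn. It is not pyro.ops.jit's implementation.

```python
class CompileCache:
    """Toy stand-in: build a specialized callable once per unique set of keyword arguments."""

    def __init__(self, compile_fn):
        self.compile_fn = compile_fn  # hypothetical "compiler": kwargs -> callable
        self.compiled = {}
        self.num_compilations = 0

    def __call__(self, *args, **kwargs):
        # non-tensor kwargs form the cache key, so they must be hashable
        key = tuple(sorted(kwargs.items()))
        if key not in self.compiled:
            self.num_compilations += 1
            self.compiled[key] = self.compile_fn(**kwargs)
        # positional (tensor-like) inputs are passed to the cached callable on every call
        return self.compiled[key](*args)


def compile_fn(scale=1.0):
    """Hypothetical compiler that specializes on the non-tensor argument `scale`."""
    return lambda data: [scale * d for d in data]


fn = CompileCache(compile_fn)
fn([1.0, 2.0], scale=2.0)   # first call with scale=2.0: compiles
fn([3.0, 4.0], scale=2.0)   # same kwargs: reuses the cached callable
fn([1.0, 2.0], scale=3.0)   # new kwargs: compiles again
print(fn.num_compilations)
```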

loss_and_surrogate_loss(model, guide, *args, **kwargs)
differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)

class TraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1], specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide as well as baselines for non-reparameterizable random variables. Where possible, conditional dependency information as recorded in the Trace is used to reduce the variance of the gradient estimator. In particular, two kinds of conditional dependency information are used to reduce variance:

  • the sequential order of samples (z is sampled after y => y does not depend on z)
  • plate generators
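Baselines reduce the variance of score-function gradients. Subtracting a constant b from the downstream cost leaves the expected gradient unchanged, because E_q[∇ log q(z)] = 0, but it can shrink the variance considerably. The sketch below assumes a Bernoulli(theta) site, a hypothetical cost f with a large offset, and a fixed constant baseline. TraceGraph_ELBO's own baselines (for example learned or decaying-average ones) are more elaborate.

```python
import random
import statistics


def grad_samples(f, theta, baseline, n, rng):
    """Per-sample score-function terms (f(z) - baseline) * d/dtheta log q(z)."""
    out = []
    for _ in range(n):
        z = 1 if rng.random() < theta else 0
        score = 1.0 / theta if z == 1 else -1.0 / (1.0 - theta)
        out.append((f(z) - baseline) * score)
    return out


f = lambda z: 3.0 * z + 10.0   # hypothetical cost with a large constant offset
theta = 0.3
plain = grad_samples(f, theta, baseline=0.0, n=20000, rng=random.Random(0))
# a baseline near E[f(z)] = 10.9 removes most of the offset; the mean stays at f(1) - f(0) = 3
with_baseline = grad_samples(f, theta, baseline=10.9, n=20000, rng=random.Random(0))
print(statistics.mean(plain), statistics.variance(plain))
print(statistics.mean(with_baseline), statistics.variance(with_baseline))
```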

References

[1] Gradient Estimation Using Stochastic Computation Graphs,
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel
[2] Neural Variational Inference and Learning in Belief Networks,
Andriy Mnih, Karol Gregor

loss(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles samples. If baselines are present, a baseline loss is also constructed and differentiated.

class JitTraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.tracegraph_elbo.TraceGraph_ELBO

Like TraceGraph_ELBO but uses torch.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
loss_and_grads(model, guide, *args, **kwargs)

class BackwardSampleMessenger(enum_trace, guide_trace)

Bases: pyro.poutine.messenger.Messenger

Implements forward filtering / backward sampling for sampling from the joint posterior distribution.
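Forward filtering / backward sampling can be illustrated on a discrete hidden Markov model. A forward pass computes filtered state probabilities, and states are then sampled backwards from the last time step. The sketch below assumes a hypothetical two-state HMM given as explicit init, trans and emit tables. BackwardSampleMessenger itself works on enumerated Pyro traces, not on such tables. A brute-force enumeration of all paths is included to check the sampler.

```python
import itertools
import random
from collections import Counter


def ffbs(init, trans, emit, obs, rng):
    """Draw one joint posterior sample of hidden states given observations."""
    K = len(init)
    # forward filtering: alphas[t][k] is proportional to p(z_t = k | obs[:t + 1])
    alphas = []
    prev = None
    for t, y in enumerate(obs):
        if t == 0:
            pred = list(init)
        else:
            pred = [sum(prev[j] * trans[j][k] for j in range(K)) for k in range(K)]
        alpha = [pred[k] * emit[k][y] for k in range(K)]
        norm = sum(alpha)
        alpha = [a / norm for a in alpha]
        alphas.append(alpha)
        prev = alpha
    # backward sampling: z_T ~ alpha_T, then z_t ~ alpha_t(j) * trans[j][z_{t+1}]
    states = [0] * len(obs)
    states[-1] = rng.choices(range(K), weights=alphas[-1])[0]
    for t in range(len(obs) - 2, -1, -1):
        w = [alphas[t][j] * trans[j][states[t + 1]] for j in range(K)]
        states[t] = rng.choices(range(K), weights=w)[0]
    return states


init = [0.6, 0.4]
trans = [[0.9, 0.1], [0.2, 0.8]]   # trans[i][j] = p(z_t = j | z_{t-1} = i)
emit = [[0.8, 0.2], [0.3, 0.7]]    # emit[k][y] = p(y | z = k)
obs = [0, 1, 1]

rng = random.Random(0)
num_draws = 20000
counts = Counter(tuple(ffbs(init, trans, emit, obs, rng)) for _ in range(num_draws))


def joint(path):
    """Unnormalized posterior p(path, obs), used only as a brute-force check."""
    p = init[path[0]] * emit[path[0]][obs[0]]
    for t in range(1, len(obs)):
        p *= trans[path[t - 1]][path[t]] * emit[path[t]][obs[t]]
    return p


paths = list(itertools.product(range(2), repeat=len(obs)))
evidence = sum(joint(p) for p in paths)
max_error = max(abs(counts[p] / num_draws - joint(p) / evidence) for p in paths)
print(max_error)
```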

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI that supports:

  • exhaustive enumeration over discrete sample sites, and
  • local parallel sampling over any sample site in the guide.

To enumerate over a sample site in the guide, mark the site with either infer={'enumerate': 'sequential'} or infer={'enumerate': 'parallel'}. To configure all guide sites at once, use config_enumerate(). To enumerate over a sample site in the model, mark the site infer={'enumerate': 'parallel'} and ensure the site does not appear in the guide.

This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.
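Exhaustively enumerating a discrete model site amounts to summing it out. Consider a hypothetical two-component Gaussian mixture with assignment k ~ Categorical(weights) and x ~ Normal(locs[k], 1), with datapoints conditionally independent as inside a plate. The marginal log-likelihood is then a per-datum log-sum-exp over components. The plain-Python sketch below illustrates the idea only. It is not TraceEnum_ELBO's tensor-based implementation.

```python
import math


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def enumerated_log_likelihood(data, weights, locs):
    """Sum out each datum's discrete assignment k instead of sampling it."""
    total = 0.0
    for x in data:   # conditionally independent datapoints, as inside a plate
        total += logsumexp([math.log(w) + normal_log_prob(x, mu, 1.0)
                            for w, mu in zip(weights, locs)])
    return total


data = [-1.2, 0.3, 2.5, 3.1]
weights, locs = [0.4, 0.6], [-1.0, 3.0]
print(enumerated_log_likelihood(data, weights, locs))
```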

loss(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Estimates the ELBO using num_particles many samples (particles).

differentiable_loss(model, guide, *args, **kwargs)
Returns: a differentiable estimate of the ELBO
Return type: torch.Tensor
Raises: ValueError – if the ELBO is not differentiable (e.g. is identically zero)

Estimates a differentiable ELBO using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).

loss_and_grads(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Estimates the ELBO using num_particles many samples (particles). Performs backward on the ELBO of each particle.

compute_marginals(model, guide, *args, **kwargs)

Computes marginal distributions at each model-enumerated sample site.

Returns: a dict mapping site name to marginal Distribution object
Return type: OrderedDict

sample_posterior(model, guide, *args, **kwargs)

Sample from the joint posterior distribution of all model-enumerated sites given all observations.
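For a single observation, the marginal of an enumerated assignment is the normalized vector of per-component joint probabilities, and posterior samples can be drawn from it. The sketch below assumes a hypothetical mixture k ~ Categorical(weights), x ~ Normal(locs[k], 1). Pyro's compute_marginals returns Distribution objects rather than plain lists. The chosen x = 1.0 lies midway between the two locations, so the posterior should equal the prior weights.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def assignment_posterior(x, weights, locs):
    """p(k | x) for k ~ Categorical(weights), x ~ Normal(locs[k], 1)."""
    logits = [math.log(w) + normal_log_prob(x, mu, 1.0) for w, mu in zip(weights, locs)]
    m = max(logits)
    unnorm = [math.exp(v - m) for v in logits]
    total = sum(unnorm)
    return [u / total for u in unnorm]


weights, locs = [0.4, 0.6], [-1.0, 3.0]
probs = assignment_posterior(1.0, weights, locs)
rng = random.Random(0)
# posterior samples of the assignment, analogous in spirit to sample_posterior
draws = rng.choices(range(len(probs)), weights=probs, k=10000)
print(probs, draws.count(1) / len(draws))
```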

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

Like TraceEnum_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)

class TraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in Pyro that uses analytic KL divergences when those are available.

In contrast to, e.g., TraceGraph_ELBO and Trace_ELBO this estimator places restrictions on the dependency structure of the model and guide. In particular it assumes that the guide has a mean-field structure, i.e. that it factorizes across the different latent variables present in the guide. It also assumes that all of the latent variables in the guide are reparameterized. This latter condition is satisfied for, e.g., the Normal distribution but is not satisfied for, e.g., the Categorical distribution.

Warning

This estimator may give incorrect results if the mean-field condition is not satisfied.

Note for advanced users:

The mean field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.
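With an analytic KL divergence, the ELBO is written as E_q[log p(x | z)] - KL(q || p). Only the first term is estimated by Monte Carlo, so the KL term contributes no sampling noise. The sketch below assumes a hypothetical scalar model z ~ Normal(0, 1), x ~ Normal(z, 1) and a guide Normal(loc, scale). For this Gaussian toy the likelihood expectation also has a closed form, which gives an exact value to compare against. The closed-form KL is cross-checked against a Monte Carlo estimate.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def kl_normal_std_normal(loc, scale):
    """Closed-form KL(Normal(loc, scale) || Normal(0, 1))."""
    return -math.log(scale) + 0.5 * (scale ** 2 + loc ** 2) - 0.5


def elbo_analytic_kl(x, loc, scale, num_particles, rng):
    """ELBO = E_q[log p(x | z)] - KL(q || p), with Monte Carlo for the first term only."""
    expected_ll = sum(normal_log_prob(x, rng.gauss(loc, scale), 1.0)
                      for _ in range(num_particles)) / num_particles
    return expected_ll - kl_normal_std_normal(loc, scale)


x, loc, scale = 2.0, 0.5, 0.8
rng = random.Random(0)
estimate = elbo_analytic_kl(x, loc, scale, num_particles=20000, rng=rng)
# E_q[log N(x; z, 1)] = -0.5 * log(2 pi) - ((x - loc)^2 + scale^2) / 2 in closed form
exact = (-0.5 * math.log(2 * math.pi) - 0.5 * ((x - loc) ** 2 + scale ** 2)
         - kl_normal_std_normal(loc, scale))
# Monte Carlo estimate of E_q[log q(z) - log p(z)] to cross-check the closed-form KL
zs = [rng.gauss(loc, scale) for _ in range(20000)]
kl_mc = sum(normal_log_prob(z, loc, scale) - normal_log_prob(z, 0.0, 1.0) for z in zs) / len(zs)
print(estimate, exact, kl_mc)
```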

loss(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

class JitTraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO

Like TraceMeanField_ELBO but uses pyro.ops.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
differentiable_loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)

class TraceTailAdaptive_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.trace_elbo.Trace_ELBO

Interface for Stochastic Variational Inference with an adaptive f-divergence as described in ref. [1]. Users should specify num_particles > 1 and vectorize_particles=True. The argument tail_adaptive_beta can be specified to modify how the adaptive f-divergence is constructed. See the reference for details.

Note that this interface does not support computing the variational objective itself; rather, it only supports computing gradients of the variational objective. Consequently, one might want to use another SVI interface (e.g. RenyiELBO) in order to monitor convergence.

Note that this interface only supports models in which all the latent variables are fully reparameterized. It also does not support data subsampling.
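The sketch below follows the description in reference [1], not Pyro's source. In place of ordinary normalized importance weights, the tail-adaptive gradient weights particle i by γ_i ∝ F̄(w_i)^β, where β = tail_adaptive_beta < 0 and F̄(w_i) is the empirical fraction of particles whose importance weight is at least w_i. Only ranks matter, so the weights can be computed directly from log importance weights. The inputs are hypothetical.

```python
def tail_adaptive_weights(log_w, beta=-1.0):
    """Normalized tail-adaptive weights from log importance weights (one per particle)."""
    n = len(log_w)
    # empirical tail probability: fraction of particles whose weight is >= this one
    tail = [sum(1 for other in log_w if other >= lw) / n for lw in log_w]
    gamma = [t ** beta for t in tail]   # beta < 0 up-weights the largest importance weights
    total = sum(gamma)
    return [g / total for g in gamma]


log_w = [-3.0, -1.0, 0.5, -2.0]   # hypothetical log p(x, z_i) - log q(z_i)
print(tail_adaptive_weights(log_w))
```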

References

[1] “Variational Inference with Tail-adaptive f-Divergence”, Dilin Wang, Hao Liu, Qiang Liu, NeurIPS 2018
https://papers.nips.cc/paper/7816-variational-inference-with-tail-adaptive-f-divergence

loss(model, guide, *args, **kwargs)

It is not necessary to estimate the tail-adaptive f-divergence itself in order to compute the corresponding gradients. Consequently the loss method is left unimplemented.

class RenyiELBO(alpha=0, num_particles=2, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True)

Bases: pyro.infer.elbo.ELBO

An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1].

In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].

Note

Setting \(\alpha < 1\) gives a better bound than the usual ELBO. For \(\alpha = 1\), it is better to use the Trace_ELBO class, because it helps reduce the variance of gradient estimates.
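Let log w_k = log p(x, z_k) - log q(z_k) be the log importance weights of K guide samples. The Monte Carlo Rényi bound of reference [1] is \(\frac{1}{1-\alpha}\log\frac{1}{K}\sum_k w_k^{1-\alpha}\). Setting α = 0 gives the IWAE bound, and as α → 1 it approaches the average of log w_k, which is the usual ELBO estimate. The sketch below uses hypothetical log weights. For α < 1 the power-mean inequality orders the three quantities, which the test checks.

```python
import math


def renyi_bound(log_w, alpha):
    """Monte Carlo Renyi alpha-bound from log importance weights (alpha != 1)."""
    scaled = [(1.0 - alpha) * lw for lw in log_w]
    m = max(scaled)
    # numerically stable log of the mean of w_k ** (1 - alpha)
    log_mean = m + math.log(sum(math.exp(s - m) for s in scaled) / len(scaled))
    return log_mean / (1.0 - alpha)


log_w = [-2.3, -1.1, -4.0, -0.7]   # hypothetical log p(x, z_k) - log q(z_k)
elbo = sum(log_w) / len(log_w)     # the usual ELBO estimate from the same samples
print(renyi_bound(log_w, 0.0), renyi_bound(log_w, 0.5), elbo)
```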

Parameters:
  • alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
  • num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.

References:

[1] Renyi Divergence Variational Inference,
Yingzhen Li, Richard E. Turner
[2] Importance Weighted Autoencoders,
Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
loss(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)
Returns: an estimate of the ELBO
Return type: float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles samples.

class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.elbo.ELBO

A trace-based implementation of Tensor Monte Carlo [1] by way of Tensor Variable Elimination [2] that supports:

  • local parallel sampling over any sample site in the model or guide
  • exhaustive enumeration over any sample site in the model or guide

To take multiple samples, mark the site with infer={'enumerate': 'parallel', 'num_samples': N}. To configure all sites in a model or guide at once, use config_enumerate(). To enumerate or sample a sample site in the model, mark the site and ensure the site does not appear in the guide.

This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.
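The core Tensor Monte Carlo idea is to draw K samples for each latent variable independently and average the importance weight over all K^n combinations of those samples. The sketch below assumes a hypothetical chain z1 ~ Normal(0, 1), z2 ~ Normal(z1, 1), x ~ Normal(z2, 1), with proposals q1 = Normal(0, 1) and q2 = Normal(0, sqrt(2)). It computes the K² sum naively by summing out z1 for each z2 sample. Tensor Variable Elimination performs the same kind of sum as a tensor contraction. The exact answer log Normal(x; 0, sqrt(3)) is available for comparison.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def tmc_log_marginal(x, K, rng):
    """log of (1 / K^2) * sum over all (z1_i, z2_j) pairs of p(z1, z2, x) / (q1(z1) q2(z2))."""
    z1s = [rng.gauss(0.0, 1.0) for _ in range(K)]             # proposal q1 = prior p(z1)
    z2s = [rng.gauss(0.0, math.sqrt(2.0)) for _ in range(K)]   # proposal q2 = marginal of z2
    total = 0.0
    for z2 in z2s:
        # factors depending only on z2: p(x | z2) / q2(z2)
        outer = math.exp(normal_log_prob(x, z2, 1.0)
                         - normal_log_prob(z2, 0.0, math.sqrt(2.0)))
        # p(z1) / q1(z1) = 1 here, so summing out z1 leaves only p(z2 | z1)
        inner = sum(math.exp(normal_log_prob(z2, z1, 1.0)) for z1 in z1s)
        total += outer * inner
    return math.log(total / (K * K))


x = 1.0
estimate = tmc_log_marginal(x, K=300, rng=random.Random(0))
exact = normal_log_prob(x, 0.0, math.sqrt(3.0))
print(estimate, exact)
```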

References

[1] Tensor Monte Carlo: Particle Methods for the GPU Era,
Laurence Aitchison (2018)
[2] Tensor Variable Elimination for Plated Factor Graphs,
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander Rush, Noah Goodman (2019)
differentiable_loss(model, guide, *args, **kwargs)
Returns: a differentiable estimate of the marginal log-likelihood
Return type: torch.Tensor
Raises: ValueError – if the ELBO is not differentiable (e.g. is identically zero)

Computes a differentiable TMC estimate using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).

loss(model, guide, *args, **kwargs)
loss_and_grads(model, guide, *args, **kwargs)

Importance

class Importance(model, guide=None, num_samples=None)

Bases: pyro.infer.abstract_infer.TracePosterior

Parameters:
  • model – probabilistic model defined as a function
  • guide – guide used for sampling defined as a function
  • num_samples – number of samples to draw from the guide (default 10)

This method performs posterior inference by importance sampling using the guide as the proposal distribution. If no guide is provided, it defaults to proposing from the model’s prior.

get_ESS()

Compute (Importance Sampling) Effective Sample Size (ESS).

get_log_normalizer()

Estimates the (log) normalizing constant of the target distribution, i.e. the log of the mean of the unnormalized weights.

get_normalized_weights(log_scale=False)

Compute the normalized importance weights.
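All three quantities can be computed from log importance weights log w_i = log p(x, z_i) - log q(z_i). The log normalizer is the log of the mean unnormalized weight. The normalized weights divide each weight by the total. A common ESS estimate is 1 / Σ_i w̄_i². The sketch below is plain Python and is not Importance's implementation. It assumes a toy model z ~ Normal(0, 1), x ~ Normal(z, 1) with observed x = 2.0 and proposals from the prior, which is what Importance does when no guide is given.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


rng = random.Random(0)
x, n = 2.0, 5000
# propose from the prior z ~ Normal(0, 1); the prior terms of p and q cancel,
# leaving log w_i = log p(x | z_i)
zs = [rng.gauss(0.0, 1.0) for _ in range(n)]
log_w = [normal_log_prob(x, z, 1.0) for z in zs]

lse = logsumexp(log_w)
log_normalizer = lse - math.log(n)                 # log of the mean unnormalized weight
normalized = [math.exp(lw - lse) for lw in log_w]  # self-normalized weights, summing to 1
ess = 1.0 / sum(w * w for w in normalized)         # effective sample size
posterior_mean = sum(w * z for w, z in zip(normalized, zs))
print(log_normalizer, ess, posterior_mean)
```

The exact evidence for this toy is log Normal(2; 0, sqrt(2)) and the exact posterior mean is 1.0, so both estimates can be checked directly.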

psis_diagnostic(model, guide, *args, **kwargs)

Computes the Pareto tail index k for a model/guide pair using the technique described in [1], which builds on previous work in [2]. If \(0 < k < 0.5\) the guide is a good approximation to the model posterior, in the sense described in [1]. If \(0.5 \le k \le 0.7\), the guide provides a suboptimal approximation to the posterior, but may still be useful in practice. If \(k > 0.7\) the guide program provides a poor approximation to the full posterior, and caution should be used when using the guide. Note, however, that a guide may be a poor fit to the full posterior while still yielding reasonable model predictions. If \(k < 0.0\) the importance weights corresponding to the model and guide appear to be bounded from above; this would be a bizarre outcome for a guide trained via ELBO maximization. Please see [1] for a more complete discussion of how the tail index k should be interpreted.

Please be advised that a large number of samples may be required for an accurate estimate of k.

Note that we assume that the model and guide are both vectorized and have static structure. As is canonical in Pyro, the args and kwargs are passed to the model and guide.

References

[1] ‘Yes, but Did It Work?: Evaluating Variational Inference.’ Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman
[2] ‘Pareto Smoothed Importance Sampling.’ Aki Vehtari, Andrew Gelman, Jonah Gabry

Parameters:
  • model (callable) – the model program.
  • guide (callable) – the guide program.
  • num_particles (int) – the total number of times we run the model and guide in order to compute the diagnostic. Defaults to 1000.
  • max_simultaneous_particles – the maximum number of simultaneous samples drawn from the model and guide. Defaults to num_particles. num_particles must be divisible by max_simultaneous_particles.
  • max_plate_nesting (int) – optional bound on max number of nested pyro.plate() contexts in the model/guide. Defaults to 7.
Returns: the PSIS diagnostic k
Return type: float

vectorized_importance_weights(model, guide, *args, **kwargs)
Parameters:
  • model – probabilistic model defined as a function
  • guide – guide used for sampling defined as a function
  • num_samples – number of samples to draw from the guide (default 1)
  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.
  • normalized (bool) – set to True to return self-normalized importance weights
Returns:

returns a (num_samples,)-shaped tensor of log importance weights and the model and guide traces that produced them

Vectorized computation of importance weights for models with static structure:

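from pyro.infer.importance import vectorized_importance_weights

# model, guide and args are placeholders for a static-structure model/guide pair and its inputs
log_weights, model_trace, guide_trace = \
    vectorized_importance_weights(model, guide, *args,
                                  num_samples=1000,
                                  max_plate_nesting=4,
                                  normalized=False)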

Reweighted Wake-Sleep

class ReweightedWakeSleep(num_particles=2, insomnia=1.0, model_has_params=True, num_sleep_particles=None, vectorize_particles=True, max_plate_nesting=inf, strict_enumeration_warning=True)

Bases: pyro.infer.elbo.ELBO

An implementation of Reweighted Wake Sleep following reference [1].

Note

Sampling and log_prob evaluation asymptotic complexity:

  1. Using wake-theta and/or wake-phi:
    O(num_particles) samples from the guide, O(num_particles) log_prob evaluations of the model and guide.
  2. Using sleep-phi:
    O(num_sleep_particles) samples from the model, O(num_sleep_particles) log_prob evaluations of the guide.

If 1) and 2) are combined: O(num_particles) samples from the guide, O(num_sleep_particles) samples from the model, O(num_particles + num_sleep_particles) log_prob evaluations of the guide, and O(num_particles) log_prob evaluations of the model.

Note

This is particularly useful for models with stochastic branching, as described in [2].

Note

This returns two losses, one each for (a) the model parameters (theta), computed using the IWAE objective, and (b) the guide parameters (phi), computed using (a combination of) the CSIS objective and a self-normalized importance-sampled version of the CSIS objective.

Note

In order to enable computing the sleep-phi terms, the guide program must have its observations explicitly passed in through the keyword argument observations. Where the value of the observations is unknown during definition, such as for amortized variational inference, it may be given a default argument as observations=None, and the correct value supplied during learning through svi.step(observations=…).

Warning

Mini-batch training is not supported yet.

Parameters:
  • num_particles (int) – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
  • insomnia – The scaling between the wake-phi and sleep-phi terms. Default is 1.0 [wake-phi]
  • model_has_params (bool) – Indicate if model has learnable params. Useful in avoiding extra computation when running in pure sleep mode [csis]. Default is True.
  • num_sleep_particles (int) – The number of particles used to form the sleep-phi estimator. Matches num_particles by default.
  • vectorize_particles (bool) – Whether the traces should be vectorised across num_particles. Default is True.
  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.

References:

[1] Reweighted Wake-Sleep,
Jörg Bornschein, Yoshua Bengio
[2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow,
Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
loss(model, guide, *args, **kwargs)
Returns: model loss and guide loss
Return type: float, float

Computes the re-weighted wake-sleep estimators for the model (wake-theta) and the guide (insomnia * wake-phi + (1 - insomnia) * sleep-phi).
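The composition of the two losses can be sketched from per-particle log densities. Wake-theta is the negative IWAE bound. Wake-phi weights log q(z_k) by self-normalized importance weights. Sleep-phi is the negative mean of log q(z | x) over samples drawn from the model. The guide loss mixes the two phi terms via insomnia. The function and its inputs below are hypothetical, and the code is plain arithmetic with no gradients. In an autograd implementation the self-normalized weights would be treated as constants.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def rws_losses(log_p, log_q, sleep_log_q, insomnia=1.0):
    """Model and guide losses for reweighted wake-sleep (hypothetical helper).

    log_p[k], log_q[k]: log p(x, z_k) and log q(z_k | x) for guide samples z_k (wake phase)
    sleep_log_q[m]: log q(z_m | x_m) for (z_m, x_m) sampled from the model (sleep phase)
    """
    log_w = [lp - lq for lp, lq in zip(log_p, log_q)]
    lse = logsumexp(log_w)
    # wake-theta: negative IWAE bound
    model_loss = -(lse - math.log(len(log_w)))
    # wake-phi: self-normalized weights times log q (weights would be detached under autograd)
    w_bar = [math.exp(lw - lse) for lw in log_w]
    wake_phi = -sum(w * lq for w, lq in zip(w_bar, log_q))
    # sleep-phi: fit the guide to latents sampled jointly with data from the model
    sleep_phi = -sum(sleep_log_q) / len(sleep_log_q)
    guide_loss = insomnia * wake_phi + (1.0 - insomnia) * sleep_phi
    return model_loss, guide_loss


print(rws_losses([-1.0, -2.0], [-0.5, -1.5], [-2.0, -4.0], insomnia=0.5))
```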

loss_and_grads(model, guide, *args, **kwargs)
Returns: model loss and guide loss
Return type: float, float

Computes the RWS estimators for the model (wake-theta) and the guide (wake-phi). Performs backward as appropriate on both, using num_particles many samples/particles.

Sequential Monte Carlo

class SMCFilter(model, guide, num_particles, max_plate_nesting)

Bases: object

SMCFilter is the top-level interface for filtering via sequential Monte Carlo.

The model and guide should be objects with two methods, .init(state, ...) and .step(state, ...), intended to be called first with init(), then with step() repeatedly. These two methods should have the same signatures as SMCFilter's own init() and step(), but with an extra first argument state that should be used to store all tensors that depend on sampled variables. The state will be a dict-like object, SMCState, with arbitrary keys and torch.Tensor values. Models can read and write state, but guides can only read from it.

Inference complexity is O(len(state) * num_time_steps), so to avoid quadratic complexity in Markov models, ensure that state has fixed size.

Parameters:
  • model (object) – probabilistic model with init and step methods
  • guide (object) – guide used for sampling, with init and step methods
  • num_particles (int) – The number of particles used to form the distribution.
  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.
get_empirical()
Returns: a marginal distribution over all state tensors.
Return type: a dictionary with keys which are latent variables and values which are Empirical objects.

init(*args, **kwargs)

Perform any initialization for sequential importance resampling. Any args or kwargs are passed to the model and guide.

step(*args, **kwargs)

Take a filtering step using sequential importance resampling, updating the particle weights and values while resampling if desired. Any args or kwargs are passed to the model and guide.
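Sequential importance resampling can be illustrated with a hypothetical one-dimensional random walk z_t = z_{t-1} + Normal(0, 1), observed through y_t = z_t + Normal(0, 1). The sketch uses the transition itself as the proposal (a "bootstrap" filter). That is an assumption here, since SMCFilter proposes from the user's guide. Each step moves the particles, reweights them by the observation likelihood, and resamples. The model is linear-Gaussian, so a Kalman filter gives the exact filtered means for comparison.

```python
import math
import random


def bootstrap_filter(observations, num_particles, rng):
    """Filtered means of z_t from a bootstrap particle filter (hypothetical toy model)."""
    particles = [rng.gauss(0.0, 1.0) for _ in range(num_particles)]   # z_0 ~ Normal(0, 1)
    means = []
    for y in observations:
        # propose from the transition, then weight by the observation likelihood
        particles = [z + rng.gauss(0.0, 1.0) for z in particles]
        log_w = [-0.5 * (y - z) ** 2 for z in particles]
        m = max(log_w)
        weights = [math.exp(lw - m) for lw in log_w]
        total = sum(weights)
        means.append(sum(w * z for w, z in zip(weights, particles)) / total)
        # multinomial resampling keeps the particle count fixed
        particles = rng.choices(particles, weights=weights, k=num_particles)
    return means


def kalman_filter(observations):
    """Exact filtered means for the same linear-Gaussian model."""
    mean, var = 0.0, 1.0
    means = []
    for y in observations:
        var += 1.0                     # predict through the unit-variance transition
        gain = var / (var + 1.0)       # update with unit observation noise
        mean += gain * (y - mean)
        var *= 1.0 - gain
        means.append(mean)
    return means


obs = [0.5, 1.0, 0.3, 1.5, 2.0]
approx = bootstrap_filter(obs, num_particles=2000, rng=random.Random(0))
exact = kalman_filter(obs)
print(approx)
print(exact)
```

The particle list here has a fixed size at every step, which mirrors the advice above to keep the state a fixed size.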

class SMCState(num_particles)

Bases: dict

Dictionary-like object to hold a vectorized collection of tensors to represent all state during inference with SMCFilter. During inference, the SMCFilter resamples these tensors.

Keys may have arbitrary hashable type. Values must be torch.Tensor instances.

Parameters: num_particles (int) – the number of particles.

Stein Methods

class IMQSteinKernel(alpha=0.5, beta=-0.5, bandwidth_factor=None)

Bases: pyro.infer.svgd.SteinKernel

An IMQ (inverse multi-quadratic) kernel for use in the SVGD inference algorithm [1]. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [2]. The kernel takes the form

\(K(x, y) = (\alpha + ||x-y||^2/h)^{\beta}\)

where \(\alpha\) and \(\beta\) are user-specified parameters and \(h\) is the bandwidth.
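Differentiating K(x, y) = (α + ||x - y||²/h)^β with respect to x gives β (α + ||x - y||²/h)^(β-1) · 2(x - y)/h. The dependency-free sketch below computes the kernel and this gradient for list-valued points and checks the gradient by finite differences. The bandwidth h is passed in as a fixed argument, which is an assumption, since the class itself chooses h from the particles with a heuristic. The function name is hypothetical.

```python
def imq_kernel_and_grad(x, y, h, alpha=0.5, beta=-0.5):
    """Return K(x, y) = (alpha + ||x - y||^2 / h) ** beta and its gradient w.r.t. x.

    x and y are equal-length lists of floats; h is a fixed, caller-supplied bandwidth."""
    sq_dist = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    base = alpha + sq_dist / h
    value = base ** beta
    # chain rule: dK/dx_d = beta * base**(beta - 1) * 2 * (x_d - y_d) / h
    grad = [beta * base ** (beta - 1) * 2.0 * (xi - yi) / h for xi, yi in zip(x, y)]
    return value, grad


value, grad = imq_kernel_and_grad([0.3, -1.2], [1.0, 0.5], h=2.0)
print(value, grad)
```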

Parameters:
  • alpha (float) – Kernel hyperparameter, defaults to 0.5.
  • beta (float) – Kernel hyperparameter, defaults to -0.5.
  • bandwidth_factor (float) – Optional factor by which to scale the bandwidth, defaults to 1.0.
Variables:

bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.

References

[1] “Stein Points,” Wilson Ye Chen, Lester Mackey, Jackson Gorham, Francois-Xavier Briol, Chris. J. Oates.
[2] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang

bandwidth_factor
log_kernel_and_grad(particles)

See pyro.infer.svgd.SteinKernel.log_kernel_and_grad()

class RBFSteinKernel(bandwidth_factor=None)

Bases: pyro.infer.svgd.SteinKernel

An RBF kernel for use in the SVGD inference algorithm. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [1].

Parameters: bandwidth_factor (float) – Optional factor by which to scale the bandwidth, defaults to 1.0.
Variables: bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.

References

[1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,”
Qiang Liu, Dilin Wang
bandwidth_factor
log_kernel_and_grad(particles)

See pyro.infer.svgd.SteinKernel.log_kernel_and_grad()

class SVGD(model, kernel, optim, num_particles, max_plate_nesting, mode='univariate')

Bases: object

A basic implementation of Stein Variational Gradient Descent as described in reference [1].

Parameters:
  • model – The model (callable containing Pyro primitives). Model must be fully vectorized and may only contain continuous latent variables.
  • kernel – an SVGD-compatible kernel like RBFSteinKernel.
  • optim (pyro.optim.PyroOptim) – A wrapper for a PyTorch optimizer.
  • num_particles (int) – The number of particles used in SVGD.
  • max_plate_nesting (int) – The max number of nested pyro.plate() contexts in the model.
  • mode (str) – Whether to use a Kernelized Stein Discrepancy that makes use of multivariate test functions (as in [1]) or univariate test functions (as in [2]). Defaults to univariate.

Example usage:

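from pyro.infer import SVGD, RBFSteinKernel
from pyro.optim import Adam

# model, model_arg1 and model_arg2 are placeholders for a user-defined,
# fully vectorized Pyro model and its arguments
kernel = RBFSteinKernel()
adam = Adam({"lr": 0.1})
svgd = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

for step in range(500):
    svgd.step(model_arg1, model_arg2)

# dict mapping each latent site name to a tensor whose leading dimension indexes particles
final_particles = svgd.get_named_particles()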

References

[1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang
[2] “Kernelized Complete Conditional Stein Discrepancy,” Raghav Singhal, Saad Lahlou, Rajesh Ranganath

get_named_particles()

Create a dictionary mapping name to vectorized value, of the form {name: tensor}. The leading dimension of each tensor corresponds to particles, i.e. this creates a struct of arrays.

step(*args, **kwargs)

Computes the SVGD gradient, passing args and kwargs to the model, and takes a gradient step.

Returns: a dictionary of the form {name: float}, where each float is a mean squared gradient. This can be used to monitor the convergence of SVGD.
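
The update that step() applies can be sketched for one-dimensional particles in plain Python. Following [1], each particle moves along \(\phi(x_i) = \frac{1}{N}\sum_j [k(x_j, x_i)\nabla \log p(x_j) + \nabla_{x_j} k(x_j, x_i)]\): an attractive term toward high density plus a repulsive term that keeps particles apart. This is an illustration of the algorithm, not Pyro's implementation. It assumes plain gradient ascent instead of a PyroOptim, an RBF kernel with an assumed median-heuristic bandwidth, and a standard-normal target chosen for the example; svgd_step and median_bandwidth are hypothetical helpers.

```python
import math
import statistics


def median_bandwidth(xs):
    # Assumed median heuristic from [1]: median squared pairwise distance / log(N).
    sq_dists = [(a - b) ** 2 for i, a in enumerate(xs) for b in xs[i + 1:]]
    return statistics.median(sq_dists) / math.log(len(xs))


def svgd_step(xs, grad_log_p, step_size):
    """One SVGD update of 1-D particles; returns (new particles, mean squared update)."""
    n = len(xs)
    h = median_bandwidth(xs)
    phi = []
    for xi in xs:
        total = 0.0
        for xj in xs:
            k = math.exp(-((xj - xi) ** 2) / h)
            # Attractive term k * grad log p(xj) pulls xi toward high density;
            # repulsive term d k(xj, xi) / d xj = 2 (xi - xj) / h * k pushes xi away from xj.
            total += k * grad_log_p(xj) + 2.0 * (xi - xj) / h * k
        phi.append(total / n)
    new_xs = [x + step_size * p for x, p in zip(xs, phi)]
    return new_xs, sum(p * p for p in phi) / n


# Target: standard normal, so grad log p(x) = -x. Start all particles near 2.
particles = [2.0 + 0.1 * i for i in range(10)]
for _ in range(500):
    particles, mean_sq_update = svgd_step(particles, lambda x: -x, step_size=0.1)
print(statistics.mean(particles), statistics.pstdev(particles), mean_sq_update)
```

The returned mean squared update plays a role similar to the per-site mean squared gradients returned by step(), and it should shrink as the particles settle around the target.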

class SteinKernel

Bases: object

Abstract class for kernels used in the SVGD inference algorithm.

log_kernel_and_grad(particles)

Compute the component kernels and their gradients.

Parameters: particles – a tensor with shape (N, D)
Returns: A pair (log_kernel, kernel_grad) where log_kernel is a (N, N, D)-shaped tensor equal to the logarithm of the kernel and kernel_grad is a (N, N, D)-shaped tensor where the entry (n, m, d) represents the derivative of log_kernel w.r.t. x_{m,d}, where x_{m,d} is the d-th dimension of particle m.

vectorize(fn, num_particles, max_plate_nesting)

Likelihood-free methods

class EnergyDistance(beta=1.0, prior_scale=0.0, num_particles=2, max_plate_nesting=inf)

Bases: object

Posterior predictive energy distance [1,2] with optional Bayesian regularization by the prior.

Let \(p(x,z)=p(z)\,p(x|z)\) be the model and \(q(z|x)\) the guide. Then, given data \(x\) and drawing an iid pair of samples \((Z,X)\) and \((Z',X')\) (where \(Z\) is latent and \(X\) is the posterior predictive),

\[\begin{split}& Z \sim q(z|x); \quad X \sim p(x|Z) \\ & Z' \sim q(z|x); \quad X' \sim p(x|Z') \\ & loss = \mathbb E_X \|X-x\|^\beta - \frac 1 2 \mathbb E_{X,X'}\|X-X'\|^\beta - \lambda \mathbb E_Z \log p(Z)\end{split}\]

This is a likelihood-free inference algorithm, and can be used for likelihoods without tractable density functions. The \(\beta\) energy distance is a robust loss function, and is well defined for any distribution with finite fractional moment \(\mathbb E[\|X\|^\beta]\).

This requires static model structure, a fully reparametrized guide, and reparametrized likelihood distributions in the model. Model latent distributions may be non-reparametrized.
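
The loss above can be estimated from a handful of posterior predictive draws. This minimal sketch assumes scalar observations, so that \(\|\cdot\|\) is an absolute value. It averages \(|X_i - x|^\beta\) over the draws and \(|X_i - X_j|^\beta\) over all ordered pairs \(i \ne j\). It omits the prior term (as with the default prior_scale=0), and energy_distance_loss is a hypothetical helper, not Pyro's implementation.

```python
def energy_distance_loss(pred_samples, x, beta=1.0):
    """Monte Carlo estimate of E|X - x|^beta - 0.5 * E|X - X'|^beta for scalars.

    pred_samples are iid posterior predictive draws X_1..X_n (n >= 2); the
    second expectation is averaged over all ordered pairs i != j.
    """
    if not 0.0 < beta < 2.0:
        raise ValueError("beta must lie in the open interval (0, 2)")
    n = len(pred_samples)
    if n < 2:
        raise ValueError("need at least 2 predictive samples")
    # Fit term: how far the predictive draws land from the observation.
    fit = sum(abs(xi - x) ** beta for xi in pred_samples) / n
    # Spread term: how far the draws land from each other.
    spread = sum(abs(xi - xj) ** beta
                 for i, xi in enumerate(pred_samples)
                 for j, xj in enumerate(pred_samples) if i != j) / (n * (n - 1))
    return fit - 0.5 * spread


print(energy_distance_loss([0.0, 2.0], x=1.0))  # 1.0 - 0.5 * 2.0 = 0.0
print(energy_distance_loss([3.0, 3.0], x=1.0))  # 2.0 - 0.5 * 0.0 = 2.0
```

The two calls show the trade-off: draws that straddle the observation pay for their distance from x but earn it back through their spread, while draws that collapse far from x are penalized.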

References

[1] Gabor J. Szekely, Maria L. Rizzo (2003). Energy Statistics: A Class of Statistics Based on Distances.
[2] Tilmann Gneiting, Adrian E. Raftery (2007). Strictly Proper Scoring Rules, Prediction, and Estimation. https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

Parameters:
  • beta (float) – Exponent \(\beta\) from [1,2]. The loss function is strictly proper for distributions with finite \(\beta\)-absolute moment \(E[\|X\|^\beta]\). Thus, for heavy-tailed distributions, beta should be small; e.g. for Cauchy distributions, \(\beta<1\) is strictly proper. Defaults to 1. Must be in the open interval (0,2).
  • prior_scale (float) – Nonnegative scale for prior regularization. Model parameters are trained only if this is positive. If zero (default), then model log densities will not be computed (guide log densities are never computed).
  • num_particles (int) – The number of particles/samples used to form the gradient estimators. Must be at least 2.
  • max_plate_nesting (int) – Optional bound on max number of nested pyro.plate() contexts. If omitted, this will guess a valid value by running the (model,guide) pair once.

__call__(model, guide, *args, **kwargs)

Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters.

loss(*args, **kwargs)

Not implemented. Added for compatibility with unit tests only.

Discrete Inference

infer_discrete(fn=None, first_available_dim=None, temperature=1)

A poutine that samples discrete sites marked with site["infer"]["enumerate"] = "parallel" from the posterior, conditioned on observations.

Example:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

@infer_discrete(first_available_dim=-1, temperature=0)
@config_enumerate
def viterbi_decoder(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[states[-1]])))
        pyro.sample("obs_{}".format(t),
                    dist.Normal(means[states[-1]], 1.),
                    obs=data[t])
    return states  # with temperature=0, returns the MAP (Viterbi) state sequence

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by Pyro. This should be a negative integer.
  • temperature (int) – Either 1 (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).
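
With temperature=0, the decoder in the example above returns a maximum a posteriori assignment of the discrete states. For this HMM that is Viterbi decoding, which the plain-Python sketch below reproduces with the same transition matrix, unit-variance Normal emissions centered at the state index, and a fixed initial state 0. It illustrates the computation only and is not how infer_discrete is implemented; viterbi is a hypothetical helper. Unlike the decorated model, it returns only the decoded states, without the leading initial 0.

```python
import math


def viterbi(data, hidden_dim=10):
    """MAP state sequence for the example HMM (hypothetical helper)."""

    def log_trans(i, j):
        # Same matrix as the example: 0.3 / hidden_dim everywhere, +0.7 on the diagonal.
        return math.log(0.3 / hidden_dim + (0.7 if i == j else 0.0))

    def log_emit(state, y):
        # Normal(mean=state, scale=1) log density.
        return -0.5 * (y - state) ** 2 - 0.5 * math.log(2 * math.pi)

    states = range(hidden_dim)
    # score[s] = best log joint of any path that ends in state s at the current step.
    score = [log_trans(0, s) + log_emit(s, data[0]) for s in states]
    backptrs = []
    for y in data[1:]:
        new_score, ptr = [], []
        for s in states:
            best_prev = max(states, key=lambda p: score[p] + log_trans(p, s))
            new_score.append(score[best_prev] + log_trans(best_prev, s) + log_emit(s, y))
            ptr.append(best_prev)
        score = new_score
        backptrs.append(ptr)
    # Backtrack from the highest-scoring final state.
    state = max(states, key=lambda s: score[s])
    path = [state]
    for ptr in reversed(backptrs):
        state = ptr[state]
        path.append(state)
    return path[::-1]


print(viterbi([0.0, 0.0, 5.0, 5.0, 5.0]))  # [0, 0, 5, 5, 5]
```

A single jump from state 0 to state 5 costs log(0.03) once, which is far cheaper than three poorly fitting emissions at state 0, so the decoded path switches exactly when the data do.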

class TraceEnumSample_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

This extends TraceEnum_ELBO to make it cheaper to sample from discrete latent states during SVI.

The following are equivalent but the first is cheaper, sharing work between the computations of loss and z:

# Version 1.
elbo = TraceEnumSample_ELBO(max_plate_nesting=1)
loss = elbo.loss(*args, **kwargs)
z = elbo.sample_saved()

# Version 2.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(*args, **kwargs)
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
z = infer_discrete(poutine.replay(model, guide_trace),
                   first_available_dim=-2)(*args, **kwargs)

sample_saved()

Generate latent samples while reusing work from SVI.step().

Inference Utilities

class Predictive(model, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)

Bases: torch.nn.modules.module.Module

EXPERIMENTAL class used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples. If a guide is provided, then posterior samples from all the latent sites are also returned.

Warning

The interface for the Predictive class is experimental, and might change in the future.

Parameters:
  • model – Python callable containing Pyro primitives.
  • posterior_samples (dict) – dictionary of samples from the posterior.
  • guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
  • num_samples (int) – number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty, in which case the leading dimension size of samples in posterior_samples is used.
  • return_sites (list, tuple, or set) – sites to return; by default only sample sites not present in posterior_samples are returned.
  • parallel (bool) – predict in parallel by wrapping the existing model in an outermost plate messenger. Note that this requires that the model has all batch dims correctly annotated via plate. Default is False.
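
The sample-count rule for posterior_samples can be illustrated with a plain-Python analogue. predictive_sketch and model below are hypothetical stand-ins, not Pyro's API: each posterior draw is plugged into the model, which then samples a new observation, and a non-empty posterior_samples dict overrides num_samples as described above. Guides, return_sites, and plates are not modelled.

```python
import random


def predictive_sketch(model, posterior_samples, num_samples=None, seed=0):
    """Plain-Python analogue of posterior predictive sampling (not Pyro's API)."""
    rng = random.Random(seed)
    if posterior_samples:
        # A non-empty posterior_samples dict fixes the number of draws.
        num_samples = len(next(iter(posterior_samples.values())))
    draws = []
    for i in range(num_samples):
        # Condition on the i-th joint posterior draw of every latent variable.
        latents = {name: values[i] for name, values in posterior_samples.items()}
        draws.append(model(rng, **latents))
    return draws


def model(rng, loc=0.0):
    # Hypothetical likelihood y ~ Normal(loc, 1); loc is the only latent.
    return rng.gauss(loc, 1.0)


posterior_samples = {"loc": [0.0, 10.0, 20.0]}
preds = predictive_sketch(model, posterior_samples, num_samples=100)
print(len(preds))  # 3: the posterior draws set the count, not num_samples
```

With an empty posterior_samples dict, num_samples is used instead and the model falls back to its default latent value.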

call(*args, **kwargs)

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See issue.

forward(*args, **kwargs)

Returns dict of samples from the predictive distribution. By default, only sample sites not contained in posterior_samples are returned. This can be modified by changing the return_sites keyword argument of this Predictive instance.

Parameters:
  • args – model arguments.
  • kwargs – model keyword arguments.

get_samples(*args, **kwargs)
get_vectorized_trace(*args, **kwargs)

Returns a single vectorized trace from the predictive distribution. Note that this requires that the model has all batch dims correctly annotated via plate.

Parameters:
  • args – model arguments.
  • kwargs – model keyword arguments.

class EmpiricalMarginal(trace_posterior, sites=None, validate_args=None)

Bases: pyro.distributions.empirical.Empirical

Marginal distribution over a single site (or multiple, provided they have the same shape) from the TracePosterior’s model.

Note

If multiple sites are specified, they must have the same tensor shape. Samples from each site will be stacked and stored within a single tensor. See Empirical. To hold the marginal distribution of sites having different shapes, use Marginals instead.

Parameters:
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
  • sites (list) – optional list of sites for which we need to generate the marginal distribution.

class Marginals(trace_posterior, sites=None, validate_args=None)

Bases: object

Holds the marginal distribution over one or more sites from the TracePosterior’s model. This is a convenience container class, which can be extended by TracePosterior subclasses, e.g. for implementing diagnostics.

Parameters:
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
  • sites (list) – optional list of sites for which we need to generate the marginal distribution.

empirical

A dictionary of sites’ names and their corresponding EmpiricalMarginal distribution.

Type: OrderedDict

support(flatten=False)

Gets the support of this marginal distribution.

Parameters: flatten (bool) – A flag to decide if we want to flatten batch_shape when the marginal distribution is collected from the posterior with num_chains > 1. Defaults to False.
Returns: a dict whose keys are site names and whose values are the corresponding supports.
Return type: OrderedDict

class TracePosterior(num_chains=1)

Bases: object

Abstract TracePosterior object from which posterior inference algorithms inherit. When run, it collects a bag of execution traces from the approximate posterior. It is designed to be used by other utility classes, such as EmpiricalMarginal, that need access to the collected execution traces.

information_criterion(pointwise=False)

Computes an information criterion for the model. Currently, it returns only the “Widely Applicable/Watanabe-Akaike Information Criterion” (WAIC) [1] and the corresponding effective number of parameters.

Reference:

[1] Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC, Aki Vehtari, Andrew Gelman, and Jonah Gabry

Parameters: pointwise (bool) – a flag to decide if we want to get a vectorized WAIC or not. When pointwise=False, returns the sum.
Returns: a dictionary containing values of WAIC and its effective number of parameters.
Return type: OrderedDict
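
The quantities returned here can be sketched from a table of pointwise log-likelihoods, with one row per posterior sample and one column per data point. This minimal sketch follows the definitions in [1]: lppd is the sum over data points of the log posterior-mean likelihood, the effective number of parameters p_waic is the sum of the per-point sample variances of the log-likelihood, and WAIC is reported on the deviance scale -2 (lppd - p_waic). That scale and the variance estimator are assumptions, so Pyro's numbers may differ by convention; waic is a hypothetical helper.

```python
import math
import statistics


def waic(log_lik):
    """WAIC from a pointwise log-likelihood table log_lik[s][n] (hypothetical helper).

    s indexes posterior samples (at least 2), n indexes data points.
    Returns (waic, p_waic) with waic = -2 * (lppd - p_waic).
    """
    num_samples = len(log_lik)
    num_points = len(log_lik[0])
    lppd = 0.0
    p_waic = 0.0
    for n in range(num_points):
        column = [log_lik[s][n] for s in range(num_samples)]
        # Log of the posterior-mean likelihood, computed stably via log-sum-exp.
        m = max(column)
        lppd += m + math.log(sum(math.exp(v - m) for v in column) / num_samples)
        # Effective number of parameters: sample variance of the log-likelihood.
        p_waic += statistics.variance(column)
    return -2.0 * (lppd - p_waic), p_waic


print(waic([[-1.0, -2.0], [-1.0, -2.0]]))  # (6.0, 0.0)
```

When every posterior sample gives the same log-likelihood, p_waic is zero and WAIC reduces to -2 times the total log-likelihood.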

marginal(sites=None)

Generates the marginal distribution of this posterior.

Parameters: sites (list) – optional list of sites for which we need to generate the marginal distribution.
Returns: A Marginals class instance.
Return type: Marginals

run(*args, **kwargs)

Calls self._traces to populate execution traces from a stochastic Pyro model.

Parameters:
  • args – optional args taken by self._traces.
  • kwargs – optional keyword args taken by self._traces.

class TracePredictive(model, posterior, num_samples, keep_sites=None)

Bases: pyro.infer.abstract_infer.TracePosterior

Warning

This class is deprecated and will be removed in a future release. Use the Predictive class instead.

Generates and holds traces from the posterior predictive distribution, given model execution traces from the approximate posterior. This is achieved by constraining latent sites to randomly sampled parameter values from the model execution traces and running the model forward to generate traces with new response (“_RETURN”) sites.

Parameters:
  • model – arbitrary Python callable containing Pyro primitives.
  • posterior (TracePosterior) – trace posterior instance holding samples from the model’s approximate posterior.
  • num_samples (int) – number of samples to generate.
  • keep_sites – the sites which should be sampled from the posterior distribution (default: all).

marginal(sites=None)

Gets marginal distribution for this predictive posterior distribution.