SVI¶
- class SVI(model, guide, optim, loss, loss_and_grads=None, num_samples=0, num_steps=0, **kwargs)[source]¶
Bases:
pyro.infer.abstract_infer.TracePosterior
- Parameters
model – the model (callable containing Pyro primitives)
guide – the guide (callable containing Pyro primitives)
optim (PyroOptim) – a wrapper for a PyTorch optimizer
loss (pyro.infer.elbo.ELBO) – an instance of a subclass of ELBO. Pyro provides three built-in losses: Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO. See the ELBO docs to learn how to implement a custom loss.
num_samples – (DEPRECATED) the number of samples for Monte Carlo posterior approximation
num_steps – (DEPRECATED) the number of optimization steps to take in run()
A unified interface for stochastic variational inference in Pyro. The most commonly used loss is loss=Trace_ELBO(). See the tutorial SVI Part I for a discussion.
- evaluate_loss(*args, **kwargs)[source]¶
Evaluate the loss function. Any args or kwargs are passed to the model and guide.
- Returns
estimate of the loss
- run(*args, **kwargs)[source]¶
Warning
This method is deprecated, and will be removed in a future release. For inference, use step() directly, and for predictions, use the Predictive class.
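Example (an illustrative sketch, not part of the original docs; the Gaussian model, guide, and data below are hypothetical):

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(10) + 3.0  # hypothetical observations

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    scale_q = pyro.param("scale_q", torch.tensor(1.), constraint=constraints.positive)
    pyro.sample("loc", dist.Normal(loc_q, scale_q))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)  # one gradient step; returns an estimate of the loss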
ELBO¶
- class ELBOModule(model: torch.nn.modules.module.Module, guide: torch.nn.modules.module.Module, elbo: pyro.infer.elbo.ELBO)[source]¶
- class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
object
ELBO is the top-level interface for stochastic variational inference via optimization of the evidence lower bound.
Most users will not interact with this base class ELBO directly; instead they will create instances of derived classes: Trace_ELBO, TraceGraph_ELBO, or TraceEnum_ELBO.
Note
Derived classes now provide a more idiomatic PyTorch interface via __call__() for (model, guide) pairs that are Modules, which is useful for integrating Pyro’s variational inference tooling with standard PyTorch interfaces like Optimizers and the large ecosystem of libraries like PyTorch Lightning and the PyTorch JIT that work with these interfaces:

model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)
elbo_ = pyro.infer.Trace_ELBO(num_particles=10)

# Fix the model/guide pair
elbo = elbo_(model, guide)

# perform any data-dependent initialization
elbo(data)

optim = torch.optim.Adam(elbo.parameters(), lr=0.001)

for _ in range(100):
    optim.zero_grad()
    loss = elbo(data)
    loss.backward()
    optim.step()
Note that Pyro’s global parameter store may cause this new interface to behave unexpectedly relative to standard PyTorch when working with PyroModules.
Users are therefore strongly encouraged to use this interface in conjunction with pyro.settings.set(module_local_params=True), which will override the default implicit sharing of parameters across PyroModule instances.
- Parameters
num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
max_plate_nesting (int) – Optional bound on the max number of nested pyro.plate() contexts. This is only required when enumerating over sample sites in parallel, e.g. if a site sets infer={"enumerate": "parallel"}. If omitted, ELBO may guess a valid value by running the (model, guide) pair once; however, this guess may be incorrect if the model or guide structure is dynamic.
vectorize_particles (bool) – Whether to vectorize the ELBO computation over num_particles. Defaults to False. This requires static structure in the model and guide.
strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that pyro.infer.traceenum_elbo.TraceEnum_ELBO is used iff there are enumerated sample sites.
ignore_jit_warnings (bool) – Flag to ignore warnings from the JIT tracer. When this is True, all torch.jit.TracerWarning will be ignored. Defaults to False.
jit_options (dict) – Optional dict of options to pass to torch.jit.trace(), e.g. {"check_trace": True}.
retain_graph (bool) – Whether to retain the autograd graph during an SVI step. Defaults to None (False).
tail_adaptive_beta (float) – Exponent beta with -1.0 <= beta < 0.0 for use with TraceTailAdaptive_ELBO.
References
[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber
[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei
- class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.elbo.ELBO
A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide. The gradient estimator includes partial Rao-Blackwellization for reducing the variance of the estimator when non-reparameterizable random variables are present. The Rao-Blackwellization is partial in that it only uses conditional independence information that is marked by plate contexts. For more fine-grained Rao-Blackwellization, see TraceGraph_ELBO.
References
- [1] Automated Variational Inference in Probabilistic Programming,
David Wingate, Theo Weber
- [2] Black Box Variational Inference,
Rajesh Ranganath, Sean Gerrish, David M. Blei
- loss(model, guide, *args, **kwargs)[source]¶
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
- Returns
an estimate of the ELBO
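For example (a minimal sketch, assuming model, guide, and data are defined as in the SVI example above), the loss can be evaluated directly without taking a gradient step:

elbo = Trace_ELBO(num_particles=100)
loss_estimate = elbo.loss(model, guide, data)  # a number; no parameters are updated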
- class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.trace_elbo.Trace_ELBO
Like Trace_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().
This works only for a limited set of models:
Models must have static structure.
Models must not depend on any global data (except the param store).
All model inputs that are tensors must be passed in via *args.
All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
- class TrackNonReparam[source]¶
Bases:
pyro.poutine.messenger.Messenger
Track non-reparameterizable sample sites.
References:
- Nonstandard Interpretations of Probabilistic Programs for Efficient Inference,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
Example:
>>> import torch
>>> import pyro
>>> import pyro.distributions as dist
>>> from pyro.infer.tracegraph_elbo import TrackNonReparam
>>> from pyro.ops.provenance import get_provenance
>>> from pyro.poutine import trace

>>> def model():
...     probs_a = torch.tensor([0.3, 0.7])
...     probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
...     probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
...     a = pyro.sample("a", dist.Categorical(probs_a))
...     b = pyro.sample("b", dist.Categorical(probs_b[a]))
...     pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

>>> with TrackNonReparam():
...     model_tr = trace(model).get_trace()
>>> model_tr.compute_log_prob()

>>> print(get_provenance(model_tr.nodes["a"]["log_prob"]))
frozenset({'a'})
>>> print(get_provenance(model_tr.nodes["b"]["log_prob"]))
frozenset({'b', 'a'})
>>> print(get_provenance(model_tr.nodes["c"]["log_prob"]))
frozenset({'b', 'a'})
- class TraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.elbo.ELBO
A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide as well as baselines for non-reparameterizable random variables. Fine-grained conditional dependency information as recorded in the Trace is used to reduce the variance of the gradient estimator. In particular, provenance tracking [3] is used to find the cost terms that depend on each non-reparameterizable sample site.
References
- [1] Gradient Estimation Using Stochastic Computation Graphs,
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel
- [2] Neural Variational Inference and Learning in Belief Networks
Andriy Mnih, Karol Gregor
- [3] Nonstandard Interpretations of Probabilistic Programs for Efficient Inference,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
- loss(model, guide, *args, **kwargs)[source]¶
Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
- Returns
an estimate of the ELBO
- loss_and_grads(model, guide, *args, **kwargs)[source]¶
Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles many samples are used to form the estimators. If baselines are present, a baseline loss is also constructed and differentiated.
- Returns
an estimate of the ELBO
- class JitTraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.tracegraph_elbo.TraceGraph_ELBO
Like TraceGraph_ELBO but uses torch.jit.trace() to compile loss_and_grads().
This works only for a limited set of models:
Models must have static structure.
Models must not depend on any global data (except the param store).
All model inputs that are tensors must be passed in via *args.
All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
- class BackwardSampleMessenger(enum_trace, guide_trace)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Implements forward filtering / backward sampling for sampling from the joint posterior distribution.
- class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.elbo.ELBO
A trace implementation of ELBO-based SVI that supports
- exhaustive enumeration over discrete sample sites, and
- local parallel sampling over any sample site in the guide.
To enumerate over a sample site in the guide, mark the site with either infer={'enumerate': 'sequential'} or infer={'enumerate': 'parallel'}. To configure all guide sites at once, use config_enumerate(). To enumerate over a sample site in the model, mark the site infer={'enumerate': 'parallel'} and ensure the site does not appear in the guide.
This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.
- loss(model, guide, *args, **kwargs)[source]¶
Estimates the ELBO using num_particles many samples (particles).
- Returns
an estimate of the ELBO
- differentiable_loss(model, guide, *args, **kwargs)[source]¶
Estimates a differentiable ELBO using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).
- Returns
a differentiable estimate of the ELBO
- Raises
ValueError – if the ELBO is not differentiable (e.g. is identically zero)
- loss_and_grads(model, guide, *args, **kwargs)[source]¶
Estimates the ELBO using num_particles many samples (particles). Performs backward on the ELBO of each particle.
- Returns
an estimate of the ELBO
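Example (an illustrative sketch, not part of the original docs; the two-component mixture model below is hypothetical):

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.tensor([-1.2, -0.8, 1.1, 0.9])  # hypothetical observations

@config_enumerate  # marks discrete sites with infer={"enumerate": "parallel"}
def model(data):
    weights = pyro.param("weights", torch.ones(2) / 2, constraint=constraints.simplex)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(weights))       # enumerated out in parallel
        pyro.sample("obs", dist.Normal(locs[z], 1.), obs=data)

def guide(data):
    pass  # "z" is enumerated in the model, so it must not appear in the guide

elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.05}), elbo)
for step in range(200):
    svi.step(data)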
- class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.traceenum_elbo.TraceEnum_ELBO
Like TraceEnum_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().
This works only for a limited set of models:
Models must have static structure.
Models must not depend on any global data (except the param store).
All model inputs that are tensors must be passed in via *args.
All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
- class TraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.trace_elbo.Trace_ELBO
A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in Pyro that uses analytic KL divergences when those are available.
In contrast to, e.g., TraceGraph_ELBO and Trace_ELBO, this estimator places restrictions on the dependency structure of the model and guide. In particular it assumes that the guide has a mean-field structure, i.e. that it factorizes across the different latent variables present in the guide. It also assumes that all of the latent variables in the guide are reparameterized. This latter condition is satisfied for, e.g., the Normal distribution but is not satisfied for, e.g., the Categorical distribution.
Warning
This estimator may give incorrect results if the mean-field condition is not satisfied.
Note for advanced users:
The mean field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ is with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.
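For example (an illustrative sketch, not part of the original docs), an AutoNormal autoguide is diagonal-Normal, so it is both mean-field and fully reparameterized and satisfies the assumptions above:

from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

# assumes model and data are defined as in the SVI example above
guide = AutoNormal(model)  # mean-field, reparameterized guide
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceMeanField_ELBO())
for step in range(1000):
    svi.step(data)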
- class JitTraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO
Like TraceMeanField_ELBO but uses pyro.ops.jit.trace() to compile loss_and_grads().
This works only for a limited set of models:
Models must have static structure.
Models must not depend on any global data (except the param store).
All model inputs that are tensors must be passed in via *args.
All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
- class TraceTailAdaptive_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.trace_elbo.Trace_ELBO
Interface for Stochastic Variational Inference with an adaptive f-divergence as described in ref. [1]. Users should specify num_particles > 1 and vectorize_particles==True. The argument tail_adaptive_beta can be specified to modify how the adaptive f-divergence is constructed. See reference for details.
Note that this interface does not support computing the variational objective itself; rather it only supports computing gradients of the variational objective. Consequently, one might want to use another SVI interface (e.g. RenyiELBO) in order to monitor convergence.
Note that this interface only supports models in which all the latent variables are fully reparameterized. It also does not support data subsampling.
References
[1] “Variational Inference with Tail-adaptive f-Divergence”, Dilin Wang, Hao Liu, Qiang Liu, NeurIPS 2018 https://papers.nips.cc/paper/7816-variational-inference-with-tail-adaptive-f-divergence
- class RenyiELBO(alpha=0, num_particles=2, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True)[source]¶
Bases:
pyro.infer.elbo.ELBO
An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1].
In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].
Note
Setting \(\alpha < 1\) gives a better bound than the usual ELBO. For \(\alpha = 1\), it is better to use the Trace_ELBO class because it helps reduce the variance of gradient estimates.
- Parameters
alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.
strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.
References:
- [1] Renyi Divergence Variational Inference,
Yingzhen Li, Richard E. Turner
- [2] Importance Weighted Autoencoders,
Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
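Example (an illustrative sketch, not part of the original docs; assumes model, guide, and data are defined as in the SVI example above):

from pyro.infer import SVI, RenyiELBO
from pyro.optim import Adam

elbo = RenyiELBO(alpha=0.5, num_particles=10)
svi = SVI(model, guide, Adam({"lr": 0.01}), elbo)
for step in range(1000):
    svi.step(data)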
- class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.elbo.ELBO
A trace-based implementation of Tensor Monte Carlo [1] by way of Tensor Variable Elimination [2] that supports:
- local parallel sampling over any sample site in the model or guide
- exhaustive enumeration over any sample site in the model or guide
To take multiple samples, mark the site with infer={'enumerate': 'parallel', 'num_samples': N}. To configure all sites in a model or guide at once, use config_enumerate(). To enumerate or sample a sample site in the model, mark the site and ensure the site does not appear in the guide.
This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.
References
- [1] Tensor Monte Carlo: Particle Methods for the GPU Era,
Laurence Aitchison (2018)
- [2] Tensor Variable Elimination for Plated Factor Graphs,
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan, Alexander Rush, Noah Goodman (2019)
- differentiable_loss(model, guide, *args, **kwargs)[source]¶
Computes a differentiable TMC estimate using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).
- Returns
a differentiable estimate of the marginal log-likelihood
- Raises
ValueError – if the ELBO is not differentiable (e.g. is identically zero)
Importance¶
- class Importance(model, guide=None, num_samples=None)[source]¶
Bases:
pyro.infer.abstract_infer.TracePosterior, pyro.infer.importance.LogWeightsMixin
- Parameters
model – probabilistic model defined as a function
guide – guide used for sampling defined as a function
num_samples – number of samples to draw from the guide (default 10)
This method performs posterior inference by importance sampling using the guide as the proposal distribution. If no guide is provided, it defaults to proposing from the model’s prior.
- log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor]¶
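Example (an illustrative sketch, not part of the original docs; assumes model, guide, and data are defined as in the SVI example above):

from pyro.infer import EmpiricalMarginal, Importance

importance = Importance(model, guide=guide, num_samples=500)
posterior = importance.run(data)            # collects weighted execution traces
log_z = importance.get_log_normalizer()     # estimate of the log normalizing constant
marginal = EmpiricalMarginal(posterior, sites="loc")
loc_mean = marginal.mean                    # importance-weighted posterior mean of "loc"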
- class LogWeightsMixin[source]¶
Bases:
object
Mixin class to compute analytics from a .log_weights attribute.
- get_log_normalizer()[source]¶
Estimator of the normalizing constant of the target distribution. (mean of the unnormalized weights)
- log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor]¶
- psis_diagnostic(model, guide, *args, **kwargs)[source]¶
Computes the Pareto tail index k for a model/guide pair using the technique described in [1], which builds on previous work in [2]. If \(0 < k < 0.5\) the guide is a good approximation to the model posterior, in the sense described in [1]. If \(0.5 \le k \le 0.7\), the guide provides a suboptimal approximation to the posterior, but may still be useful in practice. If \(k > 0.7\) the guide program provides a poor approximation to the full posterior, and caution should be used when using the guide. Note, however, that a guide may be a poor fit to the full posterior while still yielding reasonable model predictions. If \(k < 0.0\) the importance weights corresponding to the model and guide appear to be bounded from above; this would be a bizarre outcome for a guide trained via ELBO maximization. Please see [1] for a more complete discussion of how the tail index k should be interpreted.
Please be advised that a large number of samples may be required for an accurate estimate of k.
Note that we assume that the model and guide are both vectorized and have static structure. As is canonical in Pyro, the args and kwargs are passed to the model and guide.
References
[1] ‘Yes, but Did It Work?: Evaluating Variational Inference.’ Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman
[2] ‘Pareto Smoothed Importance Sampling.’ Aki Vehtari, Andrew Gelman, Jonah Gabry
- Parameters
model (callable) – the model program.
guide (callable) – the guide program.
num_particles (int) – the total number of times we run the model and guide in order to compute the diagnostic. Defaults to 1000.
max_simultaneous_particles – the maximum number of simultaneous samples drawn from the model and guide. Defaults to num_particles. num_particles must be divisible by max_simultaneous_particles.
max_plate_nesting (int) – optional bound on max number of nested pyro.plate() contexts in the model/guide. Defaults to 7.
- Returns float
the PSIS diagnostic k
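Example (an illustrative sketch, not part of the original docs; assumes a vectorized model and a guide already fit with SVI, as in the SVI example above):

from pyro.infer.importance import psis_diagnostic

k = psis_diagnostic(model, guide, data, num_particles=1000)
if k < 0.5:
    print("guide appears to be a good approximation to the posterior; k =", k)
elif k <= 0.7:
    print("guide is a suboptimal but possibly useful approximation; k =", k)
else:
    print("guide appears to be a poor approximation to the posterior; k =", k)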
- vectorized_importance_weights(model, guide, *args, **kwargs)[source]¶
- Parameters
model – probabilistic model defined as a function
guide – guide used for sampling defined as a function
num_samples – number of samples to draw from the guide (default 1)
max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.
normalized (bool) – set to True to return self-normalized importance weights
- Returns
a (num_samples,)-shaped tensor of importance weights and the model and guide traces that produced them
Vectorized computation of importance weights for models with static structure:
log_weights, model_trace, guide_trace = \
    vectorized_importance_weights(model, guide, *args,
                                  num_samples=1000,
                                  max_plate_nesting=4,
                                  normalized=False)
Reweighted Wake-Sleep¶
- class ReweightedWakeSleep(num_particles=2, insomnia=1.0, model_has_params=True, num_sleep_particles=None, vectorize_particles=True, max_plate_nesting=inf, strict_enumeration_warning=True)[source]¶
Bases:
pyro.infer.elbo.ELBO
An implementation of Reweighted Wake Sleep following reference [1].
Note
Sampling and log_prob evaluation asymptotic complexity:
- 1) Using wake-theta and/or wake-phi
O(num_particles) samples from guide, O(num_particles) log_prob evaluations of model and guide
- 2) Using sleep-phi
O(num_sleep_particles) samples from model, O(num_sleep_particles) log_prob evaluations of guide
- If 1) and 2) are combined,
O(num_particles) samples from the guide, O(num_sleep_particles) from the model, O(num_particles + num_sleep_particles) log_prob evaluations of the guide, and O(num_particles) evaluations of the model
Note
This is particularly useful for models with stochastic branching, as described in [2].
Note
This returns _two_ losses, one each for (a) the model parameters (theta), computed using the iwae objective, and (b) the guide parameters (phi), computed using (a combination of) the csis objective and a self-normalized importance-sampled version of the csis objective.
Note
In order to enable computing the sleep-phi terms, the guide program must have its observations explicitly passed in through the keyword argument observations. Where the value of the observations is unknown during definition, such as for amortized variational inference, it may be given a default argument as observations=None, and the correct value supplied during learning through svi.step(observations=…).
Warning
Mini-batch training is not supported yet.
- Parameters
num_particles (int) – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
insomnia – The scaling between the wake-phi and sleep-phi terms. Default is 1.0 [wake-phi]
model_has_params (bool) – Indicate if model has learnable params. Useful in avoiding extra computation when running in pure sleep mode [csis]. Default is True.
num_sleep_particles (int) – The number of particles used to form the sleep-phi estimator. Matches num_particles by default.
vectorize_particles (bool) – Whether the traces should be vectorised across num_particles. Default is True.
max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.
strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.
References:
- [1] Reweighted Wake-Sleep,
Jörg Bornschein, Yoshua Bengio
- [2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow,
Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
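Example (an illustrative sketch, not part of the original docs; the toy model below is hypothetical, and it assumes ReweightedWakeSleep is used as an SVI loss with observations passed via svi.step as described above):

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, ReweightedWakeSleep
from pyro.optim import Adam

def model(observations=None):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", 5):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=observations)

def guide(observations=None):
    # observations is accepted so that sleep-phi terms could be enabled if desired
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 0.5))

data = torch.randn(5) + 2.0  # hypothetical observations
rws = ReweightedWakeSleep(num_particles=8, insomnia=1.0, max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=rws)
for step in range(500):
    svi.step(observations=data)  # per the note above, the loss is returned as a pair of losses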
Sequential Monte Carlo¶
- exception SMCFailed[source]¶
Bases:
ValueError
Exception raised when SMCFilter fails to find any hypothesis with nonzero probability.
- class SMCFilter(model, guide, num_particles, max_plate_nesting, *, ess_threshold=0.5)[source]¶
Bases:
object
SMCFilter is the top-level interface for filtering via Sequential Monte Carlo.
The model and guide should be objects with two methods: .init(state, ...) and .step(state, ...), intended to be called first with init(), then with step() repeatedly. These two methods should have the same signature as this class's init() and step(), but with an extra first argument state that should be used to store all tensors that depend on sampled variables. The state will be a dict-like object, SMCState, with arbitrary keys and torch.Tensor values. Models can read and write state, but guides can only read from it.
Inference complexity is O(len(state) * num_time_steps), so to avoid quadratic complexity in Markov models, ensure that state has fixed size.
- Parameters
model (object) – probabilistic model with init and step methods
guide (object) – guide used for sampling, with init and step methods
num_particles (int) – The number of particles used to form the distribution.
max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts.
ess_threshold (float) – Effective sample size threshold for deciding when to importance resample: resampling occurs when ess < ess_threshold * num_particles.
- get_empirical()[source]¶
- Returns
a marginal distribution over all state tensors.
- Return type
a dictionary with keys which are latent variables and values which are Empirical objects.
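Example (an illustrative sketch, not part of the original docs; the random-walk model and proposal guide below are hypothetical):

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter

class Model:
    def init(self, state, initial):
        self.t = 0
        state["z"] = pyro.sample("z_init", dist.Normal(initial, 1.))

    def step(self, state, y=None):
        self.t += 1
        state["z"] = pyro.sample("z_{}".format(self.t),
                                 dist.Normal(state["z"], 1.))
        return pyro.sample("y_{}".format(self.t),
                           dist.Normal(state["z"], 0.5), obs=y)

class Guide:
    def init(self, state, initial):
        self.t = 0
        pyro.sample("z_init", dist.Normal(initial, 1.))

    def step(self, state, y=None):
        self.t += 1
        # propose the next latent near the previous particle values (guides may read state)
        pyro.sample("z_{}".format(self.t), dist.Normal(state["z"], 1.))

data = torch.randn(10).cumsum(0)  # hypothetical observations
smc = SMCFilter(Model(), Guide(), num_particles=100, max_plate_nesting=0)
smc.init(initial=torch.tensor(0.))
for y in data:
    smc.step(y)
marginals = smc.get_empirical()  # dict mapping state keys to Empirical distributions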
- class SMCState(num_particles)[source]¶
Bases:
dict
Dictionary-like object to hold a vectorized collection of tensors to represent all state during inference with SMCFilter. During inference, the SMCFilter resamples these tensors.
Keys may have arbitrary hashable type. Values must be torch.Tensors.
- Parameters
num_particles (int) –
Stein Methods¶
- class IMQSteinKernel(alpha=0.5, beta=-0.5, bandwidth_factor=None)[source]¶
Bases:
pyro.infer.svgd.SteinKernel
An IMQ (inverse multi-quadratic) kernel for use in the SVGD inference algorithm [1]. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [2]. The kernel takes the form
\(K(x, y) = (\alpha + ||x-y||^2/h)^{\beta}\)
where \(\alpha\) and \(\beta\) are user-specified parameters and \(h\) is the bandwidth.
- Parameters
- Variables
bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.
References
[1] “Stein Points,” Wilson Ye Chen, Lester Mackey, Jackson Gorham, Francois-Xavier Briol, Chris. J. Oates.
[2] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,” Qiang Liu, Dilin Wang
- property bandwidth_factor¶
- class RBFSteinKernel(bandwidth_factor=None)[source]¶
Bases:
pyro.infer.svgd.SteinKernel
An RBF kernel for use in the SVGD inference algorithm. The bandwidth of the kernel is chosen from the particles using a simple heuristic as in reference [1].
- Parameters
bandwidth_factor (float) – Optional factor by which to scale the bandwidth, defaults to 1.0.
- Variables
bandwidth_factor (float) – Property that controls the factor by which to scale the bandwidth at each iteration.
References
- [1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,”
Qiang Liu, Dilin Wang
- property bandwidth_factor¶
- class SVGD(model, kernel, optim, num_particles, max_plate_nesting, mode='univariate')[source]¶
Bases:
object
A basic implementation of Stein Variational Gradient Descent as described in reference [1].
- Parameters
model – The model (callable containing Pyro primitives). Model must be fully vectorized and may only contain continuous latent variables.
kernel – an SVGD-compatible kernel like RBFSteinKernel.
optim (pyro.optim.PyroOptim) – A wrapper for a PyTorch optimizer.
num_particles (int) – The number of particles used in SVGD.
max_plate_nesting (int) – The max number of nested pyro.plate() contexts in the model.
mode (str) – Whether to use a Kernelized Stein Discrepancy that makes use of multivariate test functions (as in [1]) or univariate test functions (as in [2]). Defaults to univariate.
Example usage:
from pyro.infer import SVGD, RBFSteinKernel
from pyro.optim import Adam

kernel = RBFSteinKernel()
adam = Adam({"lr": 0.1})
svgd = SVGD(model, kernel, adam, num_particles=50, max_plate_nesting=0)

for step in range(500):
    svgd.step(model_arg1, model_arg2)

final_particles = svgd.get_named_particles()
References
- [1] “Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm,”
Qiang Liu, Dilin Wang
- [2] “Kernelized Complete Conditional Stein Discrepancy,”
Raghav Singhal, Saad Lahlou, Rajesh Ranganath
- class SteinKernel[source]¶
Bases:
object
Abstract class for kernels used in the SVGD inference algorithm.
- abstract log_kernel_and_grad(particles)[source]¶
Compute the component kernels and their gradients.
- Parameters
particles – a tensor with shape (N, D)
- Returns
A pair (log_kernel, kernel_grad) where log_kernel is a (N, N, D)-shaped tensor equal to the logarithm of the kernel and kernel_grad is a (N, N, D)-shaped tensor where the entry (n, m, d) represents the derivative of log_kernel w.r.t. x_{m,d}, where x_{m,d} is the d^th dimension of particle m.
Likelihood free methods¶
- class EnergyDistance(beta=1.0, prior_scale=0.0, num_particles=2, max_plate_nesting=inf)[source]¶
Bases:
object
Posterior predictive energy distance [1,2] with optional Bayesian regularization by the prior.
Let p(x,z)=p(z) p(x|z) be the model, q(z|x) be the guide. Then given data x and drawing an iid pair of samples \((Z,X)\) and \((Z',X')\) (where Z is latent and X is the posterior predictive),
\[\begin{split}& Z \sim q(z|x); \quad X \sim p(x|Z) \\
& Z' \sim q(z|x); \quad X' \sim p(x|Z') \\
& loss = \mathbb E_X \|X-x\|^\beta - \frac 1 2 \mathbb E_{X,X'}\|X-X'\|^\beta - \lambda \mathbb E_Z \log p(Z)\end{split}\]
This is a likelihood-free inference algorithm, and can be used for likelihoods without tractable density functions. The \(\beta\) energy distance is a robust loss function, and is well defined for any distribution with finite fractional moment \(\mathbb E[\|X\|^\beta]\).
This requires static model structure, a fully reparametrized guide, and reparametrized likelihood distributions in the model. Model latent distributions may be non-reparametrized.
References
- [1] Gabor J. Szekely, Maria L. Rizzo (2003)
Energy Statistics: A Class of Statistics Based on Distances.
- [2] Tilmann Gneiting, Adrian E. Raftery (2007)
Strictly Proper Scoring Rules, Prediction, and Estimation. https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
- Parameters
beta (float) – Exponent \(\beta\) from [1,2]. The loss function is strictly proper for distributions with finite \(\beta\)-absolute moment \(\mathbb E[\|X\|^\beta]\). Thus for heavy-tailed distributions beta should be small, e.g. for Cauchy distributions, \(\beta<1\) is strictly proper. Defaults to 1. Must be in the open interval (0, 2).
prior_scale (float) – Nonnegative scale for prior regularization. Model parameters are trained only if this is positive. If zero (default), then model log densities will not be computed (guide log densities are never computed).
num_particles (int) – The number of particles/samples used to form the gradient estimators. Must be at least 2.
max_plate_nesting (int) – Optional bound on max number of nested pyro.plate() contexts. If omitted, this will guess a valid value by running the (model, guide) pair once.
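Example (an illustrative sketch, not part of the original docs; the heavy-tailed model below is hypothetical, and it assumes EnergyDistance can be used as the loss in SVI like the ELBO classes above):

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SVI
from pyro.infer.autoguide import AutoNormal
from pyro.infer.energy_distance import EnergyDistance
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", len(data)):
        # reparametrized, possibly heavy-tailed likelihood
        pyro.sample("obs", dist.StudentT(2.0, loc, scale), obs=data)

guide = AutoNormal(model)           # fully reparametrized guide
data = torch.randn(20) * 3.0 + 1.0  # hypothetical observations

loss = EnergyDistance(beta=1.0, num_particles=8, max_plate_nesting=1)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=loss)
for step in range(1000):
    svi.step(data)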
Discrete Inference¶
- infer_discrete(fn=None, first_available_dim=None, temperature=1, *, strict_enumeration_warning=True)[source]¶
A poutine that samples discrete sites marked with site["infer"]["enumerate"] = "parallel" from the posterior, conditioned on observations.
Example:
@infer_discrete(first_available_dim=-1, temperature=0)
@config_enumerate
def viterbi_decoder(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[states[-1]])))
        pyro.sample("obs_{}".format(t),
                    dist.Normal(means[states[-1]], 1.),
                    obs=data[t])
    return states  # returns maximum likelihood states
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to the left of it may be used internally by Pyro. This should be a negative integer.
temperature (int) – Either 1 (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).
strict_enumeration_warning (bool) – Whether to warn in case no enumerated sample sites are found. Defaults to True.
- class TraceEnumSample_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]¶
Bases:
pyro.infer.traceenum_elbo.TraceEnum_ELBO
This extends TraceEnum_ELBO to make it cheaper to sample from discrete latent states during SVI.
The following are equivalent but the first is cheaper, sharing work between the computations of loss and z:

# Version 1.
elbo = TraceEnumSample_ELBO(max_plate_nesting=1)
loss = elbo.loss(model, guide, *args, **kwargs)
z = elbo.sample_saved()

# Version 2.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(model, guide, *args, **kwargs)
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
z = infer_discrete(poutine.replay(model, guide_trace),
                   first_available_dim=-2)(*args, **kwargs)
Prediction utilities¶
- class MHResampler(sampler: Callable, source_samples_slice: slice = slice(None, 0, None), stored_samples_slice: slice = slice(None, 0, None))[source]¶
Bases:
torch.nn.modules.module.Module
Resampler for weighed samples that generates equally weighed samples from the distribution specified by the weighed samples sampler.
The resampling is based on the Metropolis-Hastings algorithm. Given an initial sample \(x\), subsequent samples are generated by:
Sampling from the guide a new sample candidate \(x'\) with probability \(g(x')\).
Calculating an acceptance probability \(A(x', x) = \min\left(1, \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)\) with \(P\) being the model.
With probability \(A(x', x)\), accepting the new sample candidate \(x'\) as the next sample, otherwise keeping the current sample \(x\) as the next sample.
The above is the Metropolis-Hastings algorithm with the new sample candidate proposal distribution being equal to the guide and independent of the current sample, such that \(g(x') = g(x' \mid x)\).
- Parameters
sampler (callable) – When called, returns WeighedPredictiveResults.
source_samples_slice (slice) – Select source samples for storage (default is slice(0), i.e. none).
stored_samples_slice (slice) – Select output samples for storage (default is slice(0), i.e. none).
The typical use case of MHResampler would be to convert weighed samples generated by WeighedPredictive into equally weighed samples from the target distribution. Each time an instance of MHResampler is called it returns a new set of samples, with the samples generated by the first call being distributed according to the guide, and with each subsequent call the distribution of the samples becoming closer to that of the posterior predictive distribution. It might take some experimentation to find out in each case how many times one would need to call an instance of MHResampler in order to be close enough to the posterior predictive distribution.
Example:
def model():
    ...

def guide():
    ...

def conditioned_model():
    ...

# Fit guide
elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=3.0)), elbo)
for i in range(num_svi_steps):
    svi.step()

# Create callable that returns weighed samples
posterior_predictive = WeighedPredictive(model,
                                         guide=guide,
                                         num_samples=num_samples,
                                         parallel=parallel,
                                         return_sites=["_RETURN"])
prob = 0.95

weighed_samples = posterior_predictive(model_guide=conditioned_model)
# Calculate quantile directly from weighed samples
weighed_samples_quantile = weighed_quantile(weighed_samples.samples['_RETURN'],
                                            [prob],
                                            weighed_samples.log_weights)[0]

resampler = MHResampler(posterior_predictive)
num_mh_steps = 10
for mh_step_count in range(num_mh_steps):
    resampled_weighed_samples = resampler(model_guide=conditioned_model)
# Calculate quantile from resampled weighed samples (samples are equally weighed)
resampled_weighed_samples_quantile = quantile(resampled_weighed_samples.samples['_RETURN'],
                                              [prob])[0]

# Quantiles calculated using both methods should be identical
assert_close(weighed_samples_quantile, resampled_weighed_samples_quantile, rtol=0.01)
Notes on Sampler Behavior:
In case the guide perfectly tracks the model, this sampler will do nothing, as the acceptance probability \(A(x', x)\) will always be one.
Furthermore, if the guide is approximately separable, i.e. \(g(z_A, z_B) \approx g_A(z_A) g_B(z_B)\), with \(g_A(z_A)\) perfectly tracking the model and \(g_B(z_B)\) poorly tracking the model, quantiles of \(z_A\) calculated from samples taken from MHResampler will have much lower variance than quantiles of \(z_A\) calculated using weighed_quantile, as the effective sample size of the calculation using weighed_quantile will be low due to \(g_B(z_B)\) poorly tracking the model, whereas when using MHResampler the poor model tracking of \(g_B(z_B)\) has negligible effect on the effective sample size of the \(z_A\) samples.
- forward(*args, **kwargs)[source]¶
Performs a single resampling step. Returns WeighedPredictiveResults.
- get_min_sample_transition_count()[source]¶
Return the transition count of the sample with the smallest number of transitions.
- get_samples(samples)[source]¶
Return samples that were sampled during execution of the Metropolis-Hastings algorithm.
- get_source_samples()[source]¶
Return source samples that were the input to the Metropolis-Hastings algorithm.
- class Predictive(model, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)[source]¶
Bases:
torch.nn.modules.module.Module
EXPERIMENTAL class used to construct a predictive distribution. The predictive distribution is obtained by running the model conditioned on latent samples from posterior_samples. If a guide is provided, then posterior samples from all the latent sites are also returned.
Warning
The interface for the Predictive class is experimental, and might change in the future.
- Parameters
model – Python callable containing Pyro primitives.
posterior_samples (dict) – dictionary of samples from the posterior.
guide (callable) – optional guide to get posterior samples of sites not present in posterior_samples.
num_samples (int) – number of samples to draw from the predictive distribution. This argument has no effect if posterior_samples is non-empty, in which case the leading dimension size of samples in posterior_samples is used.
return_sites (list, tuple, or set) – sites to return; by default only sample sites not present in posterior_samples are returned.
parallel (bool) – predict in parallel by wrapping the existing model in an outermost plate messenger. Note that this requires that the model has all batch dims correctly annotated via plate. Default is False.
- call(*args, **kwargs)[source]¶
Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
Warning
This method may be removed once PyTorch JIT tracer starts accepting dict as valid return types. See issue.
- forward(*args, **kwargs)[source]¶
Returns a dict of samples from the predictive distribution. By default, only sample sites not contained in posterior_samples are returned. This can be modified by changing the return_sites keyword argument of this Predictive instance.
Note
This method is used internally by Module. Users should instead use __call__() as in Predictive(model)(*args, **kwargs).
- Parameters
args – model arguments.
kwargs – model keyword arguments.
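Example (an illustrative sketch, not part of the original docs; assumes model, guide, and data are defined as in the SVI example above, with the guide already fit):

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=200, return_sites=("loc",))
samples = predictive(data)     # dict mapping site names to tensors with a leading sample dim
loc_samples = samples["loc"]   # approximate posterior samples of "loc", shape (200, ...)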
- class WeighedPredictive(model, posterior_samples=None, guide=None, num_samples=None, return_sites=(), parallel=False)[source]¶
Bases:
pyro.infer.predictive.Predictive
Class used to construct a weighed predictive distribution that is based on the same initialization interface as Predictive.
The methods .forward and .call can be called with an additional keyword argument model_guide, which is the model used to create and optimize the guide (if not provided, model_guide defaults to self.model), and they return both samples and log_weights.
The weights are calculated as the per-sample gap between the model_guide log-probability and the guide log-probability (a guide must always be provided).
A typical use case would be based on a model \(p(x,z)=p(x|z)p(z)\) and a guide \(q(z)\) that has already been fitted to the model given observations \(p(X_{obs},z)\), both of which are provided at initialization of WeighedPredictive (same as you would do with Predictive). When calling an instance of WeighedPredictive we provide the model given observations \(p(X_{obs},z)\) as the keyword argument model_guide. The resulting output would be the usual samples \(p(x|z)q(z)\) returned by Predictive, along with per-sample weights \(p(X_{obs},z)/q(z)\). The samples and weights can be fed into weighed_quantile in order to obtain the true quantiles of the resulting distribution.
Note that the model can be more elaborate, with sample sites \(y\) that are not observed and are not part of the guide, if the sample sites \(y\) are sampled after the observations and the latent variables sampled by the guide, such that \(p(x,y,z)=p(y|x,z)p(x|z)p(z)\), where each element in the product represents a set of pyro.sample statements.
- call(*args, **kwargs)[source]¶
Method .call that is backwards compatible with the same method found in Predictive, but can be called with an additional keyword argument model_guide, which is the model used to create and optimize the guide.
Returns WeighedPredictiveResults, which has attributes .samples and per-sample weights .log_weights.
- forward(*args, **kwargs)[source]¶
Method .forward that is backwards compatible with the same method found in Predictive, but can be called with an additional keyword argument model_guide, which is the model used to create and optimize the guide.
Returns WeighedPredictiveResults, which has attributes .samples and per-sample weights .log_weights.
- class WeighedPredictiveResults(samples: Union[dict, tuple], log_weights: torch.Tensor, guide_log_prob: torch.Tensor, model_log_prob: torch.Tensor)[source]¶
Bases:
pyro.infer.importance.LogWeightsMixin, pyro.infer.util.CloneMixin
Return value of a call to an instance of WeighedPredictive.
- guide_log_prob: torch.Tensor¶
- log_weights: torch.Tensor¶
- model_log_prob: torch.Tensor¶
- class EmpiricalMarginal(trace_posterior, sites=None, validate_args=None)[source]¶
Bases:
pyro.distributions.distribution.Distribution, Callable
Marginal distribution over a single site (or multiple, provided they have the same shape) from the TracePosterior’s model.
Note
If multiple sites are specified, they must have the same tensor shape. Samples from each site will be stacked and stored within a single tensor. See Empirical. To hold the marginal distribution of sites having different shapes, use Marginals instead.
- Parameters
trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
sites (list) – optional list of sites for which we need to generate the marginal distribution.
- class Marginals(trace_posterior, sites=None, validate_args=None)[source]¶
Bases:
object
Holds the marginal distribution over one or more sites from the TracePosterior’s model. This is a convenience container class, which can be extended by TracePosterior subclasses, e.g. for implementing diagnostics.
- Parameters
trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
sites (list) – optional list of sites for which we need to generate the marginal distribution.
- property empirical¶
A dictionary of sites’ names and their corresponding EmpiricalMarginal distribution.
- Type
OrderedDict
- support(flatten=False)[source]¶
Gets support of this marginal distribution.
- Parameters
flatten (bool) – A flag to decide if we want to flatten batch_shape when the marginal distribution is collected from the posterior with num_chains > 1. Defaults to False.
- Returns
a dict whose keys are sites’ names and whose values are sites’ supports.
- Return type
OrderedDict
- class TracePosterior(num_chains=1)[source]¶
Bases:
object
Abstract TracePosterior object from which posterior inference algorithms inherit. When run, collects a bag of execution traces from the approximate posterior. This is designed to be used by other utility classes like EmpiricalMarginal, that need access to the collected execution traces.
- information_criterion(pointwise=False)[source]¶
Computes information criterion of the model. Currently, returns only “Widely Applicable/Watanabe-Akaike Information Criterion” (WAIC) and the corresponding effective number of parameters.
Reference:
[1] Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC, Aki Vehtari, Andrew Gelman, and Jonah Gabry
- Parameters
pointwise (bool) – a flag to decide if we want to get a vectorized WAIC or not. When pointwise=False, returns the sum.
- Returns
a dictionary containing values of WAIC and its effective number of parameters.
- Return type
OrderedDict
- class TracePredictive(model, posterior, num_samples, keep_sites=None)[source]¶
Bases:
pyro.infer.abstract_infer.TracePosterior
Warning
This class is deprecated and will be removed in a future release. Use the Predictive class instead.
Generates and holds traces from the posterior predictive distribution, given model execution traces from the approximate posterior. This is achieved by constraining latent sites to randomly sampled parameter values from the model execution traces and running the model forward to generate traces with new response (“_RETURN”) sites.
- Parameters
model – arbitrary Python callable containing Pyro primitives.
posterior (TracePosterior) – trace posterior instance holding samples from the model’s approximate posterior.
num_samples (int) – number of samples to generate.
keep_sites – The sites which should be sampled from posterior distribution (default: all)