SVI

class SVI(model, guide, optim, loss, loss_and_grads=None, **kwargs)[source]

Bases: object

Parameters:
  • model – the model (callable containing Pyro primitives)
  • guide – the guide (callable containing Pyro primitives)
  • optim (pyro.optim.PyroOptim) – a wrapper for a PyTorch optimizer
  • loss (pyro.infer.elbo.ELBO) – an instance of a subclass of ELBO. Pyro provides three built-in losses: Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO. See the ELBO docs to learn how to implement a custom loss.

A unified interface for stochastic variational inference in Pyro. The most commonly used loss is loss=Trace_ELBO(). See the tutorial SVI Part I for a discussion.

evaluate_loss(*args, **kwargs)[source]
Returns: estimate of the loss
Return type: float

Evaluate the loss function. Any args or kwargs are passed to the model and guide.

step(*args, **kwargs)[source]
Returns: estimate of the loss
Return type: float

Take a gradient step on the loss function (and any auxiliary loss functions generated under the hood by loss_and_grads). Any args or kwargs are passed to the model and guide.
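
For orientation, here is a minimal end-to-end sketch of the SVI loop; the model, guide, data, and optimizer settings below are hypothetical, not part of the API:

    import torch
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    data = torch.tensor([0.1, 0.2, 0.3])  # toy observations

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.iarange("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    def guide(data):
        q_loc = pyro.param("q_loc", torch.tensor(0.))
        q_scale = pyro.param("q_scale", torch.tensor(1.),
                             constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(q_loc, q_scale))

    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
    for _ in range(1000):
        loss = svi.step(data)        # one gradient step; args forwarded to model/guide
    final = svi.evaluate_loss(data)  # loss estimate without taking a gradient step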

ELBO

class ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: object

ELBO is the top-level interface for stochastic variational inference via optimization of the evidence lower bound. Most users will not interact with ELBO directly; instead they will interact with SVI. ELBO dispatches to its subclasses, such as Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO, where the internal implementations live.

Parameters:
  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
  • max_iarange_nesting (int) – Optional bound on max number of nested pyro.iarange() contexts. This is only required to enumerate over sample sites in parallel, e.g. if a site sets infer={"enumerate": "parallel"}.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that pyro.infer.traceenum_elbo.TraceEnum_ELBO is used iff there are enumerated sample sites.
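
For example, averaging over several particles lowers the variance of the loss and gradient estimates at proportionally higher cost per step. A brief sketch, reusing hypothetical model and guide callables like the ones above:

    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    # Ten samples per step: lower-variance ELBO estimates, roughly 10x the cost.
    elbo = Trace_ELBO(num_particles=10)
    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=elbo)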

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

class Trace_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide. The gradient estimator includes partial Rao-Blackwellization for reducing the variance of the estimator when non-reparameterizable random variables are present. The Rao-Blackwellization is partial in that it only uses conditional independence information that is marked by iarange contexts. For more fine-grained Rao-Blackwellization, see TraceGraph_ELBO.
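
Concretely, wrapping conditionally independent sample statements in iarange is what exposes that structure to the estimator. A sketch with subsampling (the model and names are hypothetical):

    import pyro
    import pyro.distributions as dist

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        # iarange marks the observations as conditionally independent given loc;
        # Trace_ELBO's partial Rao-Blackwellization uses exactly this marking.
        with pyro.iarange("data", len(data), subsample_size=16) as idx:
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data[idx])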

References

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

loss(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles many samples are used to form the estimators.
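
These methods can also be called directly, outside of SVI. A sketch, assuming model, guide, and data as in the earlier hypothetical example:

    elbo = Trace_ELBO(num_particles=4)
    loss = elbo.loss(model, guide, data)            # evaluation only, no gradients
    loss = elbo.loss_and_grads(model, guide, data)  # also calls backward() on the
                                                    # surrogate, populating param grads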

class JitTrace_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.trace_elbo.Trace_ELBO

Like Trace_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via *kwargs, and these will be fixed to their values on the first call to loss_and_grads().

Warning

Experimental. Interface subject to change.
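
Subject to those restrictions, usage matches Trace_ELBO. A sketch (the model, guide, and data are hypothetical, with data passed as a tensor positional argument as the JIT requires):

    from pyro.infer import SVI, JitTrace_ELBO
    from pyro.optim import Adam

    svi = SVI(model, guide, Adam({"lr": 0.01}), loss=JitTrace_ELBO())
    for _ in range(1000):
        svi.step(data)  # compiled on the first call, reused afterwards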

loss_and_grads(model, guide, *args, **kwargs)[source]

class TraceGraph_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.elbo.ELBO

A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide as well as baselines for non-reparameterizable random variables. Where possible, conditional dependency information as recorded in the Trace is used to reduce the variance of the gradient estimator. In particular, three kinds of conditional dependency information are used to reduce variance:

  • the sequential order of samples (z is sampled after y => y does not depend on z)
  • iarange generators
  • irange generators
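
As an illustration of the baseline support, a guide with a non-reparameterizable (discrete) site can attach a decaying-average baseline via the infer dict. A sketch; the guide and names are hypothetical:

    import torch
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist

    def guide(data):
        p = pyro.param("p", torch.tensor(0.5),
                       constraint=constraints.unit_interval)
        # Score-function gradients for this discrete site are high-variance;
        # a decaying-average baseline helps reduce that variance.
        pyro.sample("z", dist.Bernoulli(p),
                    infer=dict(baseline=dict(use_decaying_avg_baseline=True,
                                             baseline_beta=0.95)))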

References

[1] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[2] Neural Variational Inference and Learning in Belief Networks, Andriy Mnih, Karol Gregor

loss(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. num_particles many samples are used to form the estimators. If baselines are present, a baseline loss is also constructed and differentiated.

class JitTraceGraph_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.tracegraph_elbo.TraceGraph_ELBO

Like TraceGraph_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via *kwargs, and these will be fixed to their values on the first call to loss_and_grads().

Warning

Experimental. Interface subject to change.

loss_and_grads(model, guide, *args, **kwargs)[source]

class TraceEnum_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI that supports enumeration over discrete sample sites.

To enumerate over a sample site, the guide’s sample site must specify either infer={'enumerate': 'sequential'} or infer={'enumerate': 'parallel'}. To configure all sites at once, use config_enumerate().

This assumes restricted dependency structure on the model and guide: variables outside of an iarange can never depend on variables inside that iarange.
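
A sketch of a guide whose discrete site is enumerated in parallel, configured all at once with config_enumerate; the model, data, and names are hypothetical:

    import torch
    from torch.distributions import constraints

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
    from pyro.optim import Adam

    @config_enumerate(default="parallel")  # tags every discrete site in the guide
    def guide(data):
        probs = pyro.param("probs", torch.ones(3) / 3.,
                           constraint=constraints.simplex)
        pyro.sample("z", dist.Categorical(probs))  # enumerated, not sampled

    # max_iarange_nesting must bound the model's iarange nesting depth
    svi = SVI(model, guide, Adam({"lr": 0.01}),
              loss=TraceEnum_ELBO(max_iarange_nesting=1))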

loss(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Estimates the ELBO using num_particles many samples (particles).

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns: an estimate of the ELBO
Return type: float

Estimates the ELBO using num_particles many samples (particles). Performs backward on the ELBO of each particle.

class JitTraceEnum_ELBO(num_particles=1, max_iarange_nesting=inf, strict_enumeration_warning=True)[source]

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

Like TraceEnum_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via *kwargs, and these will be fixed to their values on the first call to loss_and_grads().

Warning

Experimental. Interface subject to change.

loss_and_grads(model, guide, *args, **kwargs)[source]

Importance

class Importance(model, guide=None, num_samples=None)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

Parameters:
  • model – probabilistic model defined as a function
  • guide – guide used for sampling defined as a function
  • num_samples – number of samples to draw from the guide (default 10)

This method performs posterior inference by importance sampling using the guide as the proposal distribution. If no guide is provided, it defaults to proposing from the model’s prior.
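
A sketch of prior-proposal importance sampling; the model and data are hypothetical:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import Importance

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    data = torch.tensor(0.5)  # a single toy observation

    # guide=None, so proposals are drawn from the model's prior
    posterior = Importance(model, guide=None, num_samples=100).run(data)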

Inference Utilities

class EmpiricalMarginal(trace_posterior, sites=None, validate_args=None)[source]

Bases: pyro.distributions.empirical.Empirical

Marginal distribution that wraps a TracePosterior object to provide a marginal over one or more latent sites or over the return values of the TracePosterior’s model. If multiple sites are specified, they must have the same tensor shape.

Parameters:
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
  • sites (list) – optional list of sites for which we need to generate the marginal distribution. Note that for multiple sites, the shape for the site values must match (needed by the underlying Empirical class).
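
Continuing the hypothetical Importance sketch above, the marginal for a single site might be extracted as follows:

    from pyro.infer import EmpiricalMarginal

    marginal = EmpiricalMarginal(posterior, sites="loc")
    print(marginal.mean)      # Monte Carlo estimate of the posterior mean
    print(marginal.sample())  # draw from the weighted empirical distribution
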
class TracePosterior[source]

Bases: object

Abstract TracePosterior object from which posterior inference algorithms inherit. When run, collects a bag of execution traces from the approximate posterior. This is designed to be used by other utility classes, like EmpiricalMarginal, that need access to the collected execution traces.

run(*args, **kwargs)[source]

Calls self._traces to populate execution traces from a stochastic Pyro model.

Parameters:
  • args – optional args taken by self._traces.
  • kwargs – optional keyword args taken by self._traces.

class TracePredictive(model, posterior, num_samples)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

Generates and holds traces from the posterior predictive distribution, given model execution traces from the approximate posterior. This is achieved by constraining latent sites to randomly sampled parameter values from the model execution traces and running the model forward to generate traces with new response (“_RETURN”) sites.

Parameters:
  • model – arbitrary Python callable containing Pyro primitives.
  • posterior (TracePosterior) – trace posterior instance holding samples from the model’s approximate posterior.
  • num_samples (int) – number of samples to generate.
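
A sketch of drawing posterior predictive traces, reusing the hypothetical Importance posterior from above and assuming the model returns the quantity to predict, so that the rerun traces carry it at their "_RETURN" site:

    from pyro.infer import EmpiricalMarginal, TracePredictive

    predictive = TracePredictive(model, posterior, num_samples=100).run(data)
    post_pred = EmpiricalMarginal(predictive)  # marginal over "_RETURN" values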