Automatic Guide Generation


class AutoGuide(model, *, create_plates=None)[source]

Bases: pyro.nn.module.PyroModule

Base class for automatic guides.

Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.

Auto guides can be used individually or combined in an AutoGuideList object.

  • model (callable) – A pyro model.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
call(*args, **kwargs)[source]

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().


This method may be removed once the PyTorch JIT tracer starts accepting dict as a valid return type. See the corresponding PyTorch issue.
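The dict-to-tuple conversion described above can be illustrated without Pyro. This is a minimal sketch of the idea, not Pyro's implementation: `dict_to_tuple` and `tuple_to_dict` are hypothetical helper names, and the sorted key order is an assumption made so that the round trip is deterministic.

```python
def dict_to_tuple(values):
    """Flatten a dict into a tuple of names and a tuple of values, in sorted key order."""
    names = tuple(sorted(values))
    return names, tuple(values[name] for name in names)


def tuple_to_dict(names, flat):
    """Rebuild the dict from the recorded name order and the flat tuple."""
    return dict(zip(names, flat))


# Hypothetical guide output: one value per sample site.
site_values = {"scale": 0.5, "loc": 1.0}
names, flat = dict_to_tuple(site_values)
restored = tuple_to_dict(names, flat)
```

A tracer that only accepts tuples would see `flat`; the caller keeps `names` to recover the site-to-value mapping afterwards.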

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict

sample_latent(*args, **kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.


class AutoGuideList(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList

Container class to combine multiple automatic guides.

Example usage:

guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:model (callable) – a Pyro model

add(part)[source]

Deprecated alias for append().


append(part)[source]

Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict it to a subset of sample sites. No two parts should operate on the same sample site.

Parameters:part (AutoGuide or callable) – a partial guide to add
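The rule that no two parts may operate on the same sample site can be sketched in plain Python. This is an illustration of the constraint only, not how AutoGuideList is implemented; `combine_parts`, `continuous_part` and `discrete_part` are hypothetical names, and each part is assumed to return a dict keyed by site name.

```python
def combine_parts(parts, *args, **kwargs):
    """Call each partial guide and merge the results, rejecting overlapping sites."""
    result = {}
    for part in parts:
        values = part(*args, **kwargs)
        overlap = set(result) & set(values)
        if overlap:
            raise ValueError(f"sites handled by more than one part: {sorted(overlap)}")
        result.update(values)
    return result


def continuous_part(data):
    # Stand-in for a guide over the continuous sites.
    return {"loc": 0.0, "scale": 1.0}


def discrete_part(data):
    # Stand-in for a guide over the discrete "assignment" site.
    return {"assignment": [0] * len(data)}


combined = combine_parts([continuous_part, discrete_part], [1.2, 3.4])
```

Passing the same part twice raises a ValueError, which mirrors the requirement that the blocked sub-models partition the sample sites.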
forward(*args, **kwargs)[source]

A composite guide with the same *args, **kwargs as the base model.

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict


class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.add(my_local_guide)  # automatically wrapped in an AutoCallable

To specify a median callable, you can instead:

def my_local_median(*args, **kwargs):
    ...

guide.add(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

  • model (callable) – a Pyro model
  • guide (callable) – a Pyro guide (typically over only part of the model)
  • median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
forward(*args, **kwargs)[source]


class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses independent Normal distributions, one per sample site, to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.

In AutoDiagonalNormal, if your model has N named parameters with dimensions k_i and sum k_i = D, you get a single vector of length D for the means and a single vector of length D for the standard deviations. This guide instead gives you N distinct normals that you can access by name.
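The difference between one flat vector of length D and N named per-site values can be sketched as follows. This is a plain-Python illustration under assumed site names and sizes (`weights`, `bias`); `split_by_site` is a hypothetical helper, not part of Pyro.

```python
def split_by_site(flat, site_sizes):
    """Slice a length-D list into per-site chunks keyed by site name."""
    out, start = {}, 0
    for name, size in site_sizes.items():
        out[name] = flat[start:start + size]
        start += size
    if start != len(flat):
        raise ValueError("site sizes must sum to the length of the flat vector")
    return out


site_sizes = {"weights": 3, "bias": 1}   # k_i per site, so D = 4
flat_loc = [0.1, 0.2, 0.3, -1.0]         # single mean vector of length D
per_site_loc = split_by_site(flat_loc, site_sizes)
```

The flat layout corresponds to what the paragraph above describes for AutoDiagonalNormal; the per-site dict corresponds to addressing each normal by name.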


guide = AutoNormal(model)
svi = SVI(model, guide, ...)
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict
quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters:quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns:A dict mapping sample site name to a list of quantile values.
Return type:dict
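As a rough illustration of what a quantile of a transformed normal looks like, the sketch below computes quantiles of a one-dimensional Normal in unconstrained space and maps them through a monotone transform. It assumes a positive-valued site with an exp transform; `normal_quantiles` is a hypothetical helper built on the standard library, not the Pyro method.

```python
import math
from statistics import NormalDist


def normal_quantiles(loc, scale, quantiles, transform=lambda u: u):
    """Quantiles of transform(u) with u ~ Normal(loc, scale), for an increasing transform."""
    base = NormalDist()  # standard normal, used for its inverse CDF
    return [transform(loc + scale * base.inv_cdf(q)) for q in quantiles]


# Assumed example: unconstrained loc=0, scale=0.1, mapped to the positive reals by exp.
qs = normal_quantiles(0.0, 0.1, [0.05, 0.5, 0.95], transform=math.exp)
```

Because an increasing transform preserves the ordering of values, the quantiles of the constrained variable are the transformed quantiles of the unconstrained one.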


class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


This class does MAP inference in constrained space.


guide = AutoDelta(model)
svi = SVI(model, guide, ...)

Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)
    return init_to_sample(site)
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or an iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict


class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenate all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
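The transform-then-concatenate step can be sketched in plain Python. This is an illustration under simplifying assumptions, not Pyro's code: only two supports are handled ("real" with the identity and "positive" with log/exp), and `pack`, `unpack` and the site names are hypothetical.

```python
import math

# Assumed per-support transforms between constrained and unconstrained space.
TO_UNCONSTRAINED = {"real": lambda x: x, "positive": math.log}
TO_CONSTRAINED = {"real": lambda u: u, "positive": math.exp}


def pack(sites):
    """Concatenate all sites into one flat unconstrained list.

    sites is a list of (name, support, values) triples.
    """
    flat = []
    for _, support, values in sites:
        flat.extend(TO_UNCONSTRAINED[support](v) for v in values)
    return flat


def unpack(flat, layout):
    """Invert pack(); layout is a list of (name, support, size) triples."""
    out, start = {}, 0
    for name, support, size in layout:
        out[name] = [TO_CONSTRAINED[support](u) for u in flat[start:start + size]]
        start += size
    return out


sites = [("loc", "real", [0.5, -2.0]), ("scale", "positive", [1.0, 4.0])]
layout = [(name, support, len(values)) for name, support, values in sites]
latent = pack(sites)             # single unconstrained latent of length 4
restored = unpack(latent, layout)
```

The fixed `layout` reflects the assumption stated above that the model structure and latent dimension do not change between calls.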


[1] Automatic Differentiation Variational Inference,
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Returns:A dict mapping sample site name to sampled value.
Return type:dict

get_base_dist()[source]

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns:TorchDistribution instance representing the base distribution.
get_posterior(*args, **kwargs)[source]

Returns the posterior distribution.

get_transform(*args, **kwargs)[source]

Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns:a Transform instance.
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict
quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters:quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns:A dict mapping sample site name to a list of quantile values.
Return type:dict
sample_latent(*args, **kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.


class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.
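To make the role of the Cholesky factor concrete, the sketch below maps standard-normal noise z to loc + L z and computes the implied covariance L Lᵀ for a factor equal to the identity times a small constant. It is a plain-Python illustration; `affine_cholesky` and `covariance` are hypothetical helpers, and the values of `loc` and `z` are arbitrary.

```python
def affine_cholesky(loc, scale_tril, z):
    """Return loc + L @ z for a lower-triangular matrix L given as nested lists."""
    n = len(loc)
    return [loc[i] + sum(scale_tril[i][j] * z[j] for j in range(i + 1)) for i in range(n)]


def covariance(scale_tril):
    """Return L @ L^T, the covariance implied by the Cholesky factor L."""
    n = len(scale_tril)
    return [
        [sum(scale_tril[i][k] * scale_tril[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


init_scale = 0.1
dim = 3
# Identity times a small factor, as in the default initialization described above.
identity_tril = [[init_scale if i == j else 0.0 for j in range(dim)] for i in range(dim)]

x = affine_cholesky([1.0, 2.0, 3.0], identity_tril, [1.0, -1.0, 0.0])
cov = covariance(identity_tril)
```

With this initialization the covariance is diagonal with entries init_scale squared, so the guide starts as a narrow, uncorrelated normal around the initial mean.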

  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
get_posterior(*args, **kwargs)[source]

Returns a MultivariateNormal posterior distribution.

get_transform(*args, **kwargs)[source]


class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the scale vector is initialized to a small constant (init_scale).

  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
get_posterior(*args, **kwargs)[source]

Returns a diagonal Normal posterior distribution.

get_transform(*args, **kwargs)[source]


class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)

By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.

  • model (callable) – A generative model.
  • rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
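The low-rank-plus-diagonal structure can be written out directly: the covariance is cov_factor cov_factorᵀ plus a diagonal matrix built from cov_diag. The sketch below is plain Python for illustration; `low_rank_covariance` and `default_rank` are hypothetical helpers, and the rounding rule in `default_rank` is an assumption, since the text only says the default is approximately sqrt(latent dim).

```python
import math


def low_rank_covariance(cov_factor, cov_diag):
    """Return cov_factor @ cov_factor^T + diag(cov_diag).

    cov_factor is a list of D rows, each of length rank; cov_diag has length D.
    """
    d = len(cov_diag)
    return [
        [
            sum(a * b for a, b in zip(cov_factor[i], cov_factor[j]))
            + (cov_diag[i] if i == j else 0.0)
            for j in range(d)
        ]
        for i in range(d)
    ]


def default_rank(latent_dim):
    """Assumed rounding of sqrt(latent_dim); the exact rule is not stated in the text."""
    return max(1, round(math.sqrt(latent_dim)))


cov = low_rank_covariance([[1.0], [2.0]], [0.5, 0.5])  # D = 2, rank = 1
rank_for_100 = default_rank(100)
```

Storing a D x rank factor and a length-D diagonal needs far fewer parameters than a full D x D Cholesky factor when rank is much smaller than D.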
get_posterior(*args, **kwargs)[source]

Returns a LowRankMultivariateNormal posterior distribution.


class AutoNormalizingFlow(model, init_transform_fn)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


transform_init = partial(iterated, block_autoregressive,
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)
  • model (callable) – a generative model
  • init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or TransformModule if the transform has trainable params.
get_posterior(*args, **kwargs)[source]
get_transform(*args, **kwargs)[source]


class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[source]

Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an AffineAutoregressive transform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.


guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
  • model (callable) – a generative model
  • hidden_dim (int) – number of hidden dimensions in the IAF
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions. This argument is only to preserve backwards compatibility and has no effect in practice.

  • num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
  • init_transform_kwargs – other keyword arguments taken by affine_autoregressive().


class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

The Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
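The relationship between the Hessian at the MAP point and the approximate covariance can be checked on a one-dimensional toy problem. The sketch below is an illustration only: the model (a Normal(0, 10) prior on z with unit-variance Normal observations), the finite-difference Hessian and all function names are assumptions chosen so the answer is known in closed form.

```python
def neg_log_joint(z, data, prior_scale=10.0):
    """-log p(x, z) up to an additive constant for the assumed toy model."""
    prior = 0.5 * (z / prior_scale) ** 2
    likelihood = sum(0.5 * (x - z) ** 2 for x in data)
    return prior + likelihood


def second_derivative(f, z, eps=1e-3):
    """Central finite-difference estimate of f''(z)."""
    return (f(z + eps) - 2.0 * f(z) + f(z - eps)) / eps ** 2


data = [1.0, 2.0, 3.0]
precision = len(data) + 1.0 / 10.0 ** 2   # exact second derivative for this quadratic model
z_map = sum(data) / precision             # closed-form MAP point

hessian = second_derivative(lambda z: neg_log_joint(z, data), z_map)
laplace_variance = 1.0 / hessian          # 1-D analogue of the inverse Hessian
```

For this conjugate model the Laplace approximation is exact, so `laplace_variance` matches the true posterior variance 1 / precision up to numerical error.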


delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to an empirical prior median.

  • model (callable) – a generative model
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
get_posterior(*args, **kwargs)[source]

Returns a Delta posterior distribution for MAP inference.

laplace_approximation(*args, **kwargs)[source]

Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.


class AutoDiscreteParallel(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Returns:A dict mapping sample site name to sampled value.
Return type:dict


The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.

The standard interface for initialization is a function that takes a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.


Initialize to an arbitrary feasible point, ignoring distribution parameters.


Initialize to a random sample from the prior.

init_to_median(site=None, num_samples=15)[source]

Initialize to the prior median; fallback to a feasible point if median is undefined.


Initialize to the prior mean; fallback to median if mean is undefined.

init_to_uniform(site=None, radius=2)[source]

Initialize to a random point in the interval (-radius, radius) of the unconstrained domain.

Parameters:radius (float) – specifies the range to draw an initial point in the unconstrained domain.
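The idea of drawing in the unconstrained domain and then mapping back can be sketched with the standard library. This is not Pyro's implementation: `init_uniform_unconstrained` is a hypothetical helper, and the exp map to the positive reals is an assumed example of a constraining transform.

```python
import math
import random


def init_uniform_unconstrained(size, radius=2.0, to_constrained=lambda u: u, rng=random):
    """Draw each coordinate uniformly in [-radius, radius], then map to constrained space."""
    unconstrained = [rng.uniform(-radius, radius) for _ in range(size)]
    return [to_constrained(u) for u in unconstrained]


rng = random.Random(0)  # seeded generator so the draw is reproducible
positive_init = init_uniform_unconstrained(5, radius=2.0, to_constrained=math.exp, rng=rng)
```

For a positive-valued site this yields initial values between exp(-radius) and exp(radius), which keeps the starting point away from the boundary of the support.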
init_to_value(site=None, values={})[source]

Initialize to the value specified in values. Sites that do not appear in values fall back to the init_to_uniform() strategy.

Parameters:values (dict) – dictionary of initial values keyed by site name.
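The lookup-with-fallback behavior can be sketched as a small closure. This illustrates the dispatch pattern only; `make_init_to_value` and the constant fallback are hypothetical, and the only assumption about the site dict is that it has a "name" key.

```python
def make_init_to_value(values, fallback):
    """Build an init function that prefers values[name] and otherwise calls fallback(site)."""
    def init_fn(site):
        name = site["name"]
        if name in values:
            return values[name]
        return fallback(site)
    return init_fn


# Hypothetical fallback that always returns 0.0, standing in for another strategy.
init_fn = make_init_to_value({"x": 1.5}, fallback=lambda site: 0.0)
x_init = init_fn({"name": "x"})
y_init = init_fn({"name": "y"})
```

Only the sites named in the dictionary are pinned; every other site is delegated to the fallback strategy.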
init_to_generated(site=None, generate=<function <lambda>>)[source]

Initialize to another initialization strategy returned by the callback generate which is called once per model execution.

This is like init_to_value() but can produce different (e.g. random) values once per model execution. For example, to generate values and return an init_to_value() strategy, you could define:

def generate():
    values = {"x": torch.randn(100), "y": torch.rand(5)}
    return init_to_value(values=values)

my_init_fn = init_to_generated(generate=generate)
Parameters:generate (callable) – A callable returning another initialization function, e.g. returning an init_to_value(values={...}) populated with a dictionary of random samples.
class InitMessenger(init_fn)[source]

Bases: pyro.poutine.messenger.Messenger

Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.

Parameters:init_fn (callable) – An initialization function.