Automatic Guide Generation

AutoGuide

class AutoGuide(model)

Bases: pyro.nn.module.PyroModule

Base class for automatic guides.

Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.

Auto guides can be used individually or combined in an AutoGuideList object.

Parameters: model (callable) – a Pyro model
call(*args, **kwargs)

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once the PyTorch JIT tracer accepts dict as a valid return type. See issue https://github.com/pytorch/pytorch/issues/27743.

median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
model
sample_latent(**kwargs)

Samples an encoded latent given the same *args, **kwargs as the base model.

AutoGuideList

class AutoGuideList(model)

Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList

Container class to combine multiple automatic guides.

Example usage:

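from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, AutoDiscreteParallel, AutoGuideList
# model and optim are assumed to be defined elsewhere; the discrete part is enumerated
guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, TraceEnum_ELBO())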
Parameters: model (callable) – a Pyro model
add(part)

Deprecated alias for append().

append(part)

Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.

Parameters: part (AutoGuide or callable) – a partial guide to add
forward(*args, **kwargs)

A composite guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict

AutoCallable

class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)

Bases: pyro.infer.autoguide.guides.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoGuideList

def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.append(my_local_guide)  # automatically wrapped in an AutoCallable

To specify a median callable, you can instead:

from pyro.infer.autoguide import AutoCallable

def my_local_median(*args, **kwargs):
    ...

guide.append(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

Parameters:
  • model (callable) – a Pyro model
  • guide (callable) – a Pyro guide (typically over only part of the model)
  • median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
forward(*args, **kwargs)

AutoNormal

class AutoNormal(model, init_loc_fn=<function init_to_feasible>, init_scale=0.1)

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses a separate Normal distribution with learnable loc and scale for each latent sample site (in unconstrained space) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.

In AutoDiagonalNormal, if your model has N named latent sample sites with dimensions k_i and sum k_i = D, you get a single vector of length D for your means and a single vector of length D for your scales. This guide instead gives you N distinct Normal distributions that you can access by site name.

Usage:

guide = AutoNormal(model)
svi = SVI(model, guide, ...)
forward(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict

AutoDelta

class AutoDelta(model, init_loc_fn=<function init_to_median>)

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note

This class does MAP inference in constrained space.

Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)

Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

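import torch
from pyro.infer.autoguide import AutoDelta
from pyro.infer.autoguide.initialization import init_to_sample

def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)  # k: size of the "concentration" site, defined elsewhere
    return init_to_sample(site)

guide = AutoDelta(model, init_loc_fn=my_init_fn)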
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
forward(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict

AutoContinuous

class AutoContinuous(model, init_loc_fn=<function init_to_median>)

Bases: pyro.infer.autoguide.guides.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.

Reference:

[1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei. Automatic Differentiation Variational Inference.
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
forward(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
get_base_dist()

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns: TorchDistribution instance representing the base distribution.
get_posterior(*args, **kwargs)

Returns the posterior distribution.

get_transform(*args, **kwargs)

Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns: A Transform instance.
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
sample_latent(*args, **kwargs)

Samples an encoded latent given the same *args, **kwargs as the base model.

AutoMultivariateNormal

class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.

Parameters:
  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
get_base_dist()
get_posterior(*args, **kwargs)

Returns a MultivariateNormal posterior distribution.

get_transform(*args, **kwargs)

AutoDiagonalNormal

class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and each entry of the scale vector is initialized to init_scale.

Parameters:
  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
get_base_dist()
get_posterior(*args, **kwargs)

Returns a diagonal Normal posterior distribution.

get_transform(*args, **kwargs)

AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a low-rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)

By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.

Parameters:
  • model (callable) – A generative model.
  • rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
get_posterior(*args, **kwargs)

Returns a LowRankMultivariateNormal posterior distribution.

AutoNormalizingFlow

class AutoNormalizingFlow(model, init_transform_fn)

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

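from functools import partial
from pyro.distributions.transforms import block_autoregressive, iterated

# iterated(repeats, base_fn, *args) composes `repeats` instances of base_fn(*args)
transform_init = partial(iterated, 2, block_autoregressive)
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)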
Parameters:
  • model (callable) – a generative model
  • init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or TransformModule if the transform has trainable params.
get_base_dist()
get_posterior(*args, **kwargs)
get_transform(*args, **kwargs)

AutoIAFNormal

class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)

Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow

This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an AffineAutoregressive transform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – a generative model
  • hidden_dim (int) – number of hidden dimensions in the IAF
  • init_loc_fn (callable) –

    A per-site initialization function. See Initialization section for available functions.

    Warning

    This argument is only to preserve backwards compatibility and has no effect in practice.

  • num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
  • init_transform_kwargs – other keyword arguments taken by affine_autoregressive().

AutoLaplaceApproximation

class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)

Bases: pyro.infer.autoguide.guides.AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(p(z \mid x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of \(z\).

Usage:

delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to an empirical prior median.

Parameters:
  • model (callable) – a generative model
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
get_posterior(*args, **kwargs)

Returns a Delta posterior distribution for MAP inference.

laplace_approximation(*args, **kwargs)

Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.

AutoDiscreteParallel

class AutoDiscreteParallel(model)

Bases: pyro.infer.autoguide.guides.AutoGuide

A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.

forward(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict

Initialization

The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.

The standard interface for initialization is a function that inputs a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.

init_to_feasible(site)

Initialize to an arbitrary feasible point, ignoring distribution parameters.

init_to_sample(site)

Initialize to a random sample from the prior.

init_to_median(site, num_samples=15)

Initialize to the prior median, falling back to a feasible point if the median is undefined.

init_to_mean(site)

Initialize to the prior mean, falling back to the median if the mean is undefined.

class InitMessenger(init_fn)

Bases: pyro.poutine.messenger.Messenger

Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.

Parameters: init_fn (callable) – An initialization function.