Automatic Guide Generation

AutoGuide

class AutoGuide(model, *, create_plates=None)[source]

Bases: pyro.nn.module.PyroModule

Base class for automatic guides.

Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.

Auto guides can be used individually or combined in an AutoGuideList object.

Parameters:
  • model (callable) – A pyro model.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
model
call(*args, **kwargs)[source]

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once the PyTorch JIT tracer starts accepting dict as a valid return type. See https://github.com/pytorch/pytorch/issues/27743.

sample_latent(**kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict

AutoGuideList

class AutoGuideList(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList

Container class to combine multiple automatic guides.

Example usage:

guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:model (callable) – a Pyro model
append(part)[source]

Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.

Parameters:part (AutoGuide or callable) – a partial guide to add
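
The requirement that no two parts operate on the same sample site can be checked with a small helper. The following is a plain-Python sketch that does not use Pyro; the part names, site names, and the overlapping_sites helper are hypothetical and only illustrate the disjointness condition.

```python
from itertools import combinations

# Hypothetical sample sites handled by each part of a composite guide.
parts = {
    "continuous_part": {"loc", "scale"},
    "discrete_part": {"assignment"},
}


def overlapping_sites(parts):
    """Return the site names claimed by more than one part."""
    overlap = set()
    for (_, sites_a), (_, sites_b) in combinations(parts.items(), 2):
        overlap |= sites_a & sites_b
    return overlap


# An empty result means the parts are disjoint and can be combined.
print(overlapping_sites(parts))
```

A non-empty result would indicate that the blocking of the model (via hide or expose) needs to be adjusted before the parts are appended.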
add(part)[source]

Deprecated alias for append().

forward(*args, **kwargs)[source]

A composite guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict
quantiles(quantiles, *args, **kwargs)[source]

Returns the posterior quantile values of each latent variable.

Parameters:quantiles (list) – A list of requested quantiles between 0 and 1.
Returns:A dict mapping sample site name to quantiles tensor.
Return type:dict

AutoCallable

class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.append(my_local_guide)  # automatically wrapped in an AutoCallable

To specify a median callable, wrap the guide explicitly instead:

def my_local_median(*args, **kwargs):
    ...

guide.append(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

Parameters:
  • model (callable) – a Pyro model
  • guide (callable) – a Pyro guide (typically over only part of the model)
  • median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
forward(*args, **kwargs)[source]

AutoNormal

class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.

In AutoDiagonalNormal, if your model has N named parameters with dimensions k_i and sum k_i = D, you get a single vector of length D for your mean, and a single vector of length D for sigmas. This guide gives you N distinct normals that you can call by name.
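
The difference between the two layouts can be sketched in plain Python, without Pyro. The site names and shapes below are hypothetical; the point is only that a flat vector of length D needs offset bookkeeping, whereas a per-site layout is addressed by name.

```python
from math import prod

# Hypothetical latent sites and their shapes (not taken from a real model).
site_shapes = {"loc": (3,), "scale": (), "weights": (2, 2)}

# Flat layout: one vector of length D = sum k_i, plus the offsets
# needed to locate each site inside it.
offsets = {}
start = 0
for name, shape in site_shapes.items():
    size = prod(shape)  # prod(()) == 1, so a scalar site takes one slot
    offsets[name] = (start, start + size)
    start += size
D = start
flat_loc = [0.0] * D

# Per-site layout: one entry per site, addressed by name.
per_site_loc = {name: [0.0] * prod(shape) for name, shape in site_shapes.items()}

begin, end = offsets["weights"]
print(D, flat_loc[begin:end], per_site_loc["weights"])
```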

Usage:

guide = AutoNormal(model)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
scale_constraint = SoftplusPositive(lower_bound=0.0)
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict
quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters:quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns:A dict mapping sample site name to a tensor of quantile values.
Return type:dict
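
For a single site, a quantile of a Normal distribution in unconstrained space can be obtained from the inverse CDF and then mapped through the site's transform; because such a transform is monotone increasing, quantiles map to quantiles. The following plain-Python sketch illustrates the idea under stated assumptions: loc and scale are hypothetical stand-ins for learned guide parameters, and exp stands in for a positive-support transform. It is not the library's implementation.

```python
from math import exp
from statistics import NormalDist

# Hypothetical learned parameters of one site, in unconstrained space.
loc, scale = 0.5, 0.2
requested = [0.05, 0.5, 0.95]

# Quantiles of Normal(loc, scale) via the inverse CDF.
unconstrained = [NormalDist(loc, scale).inv_cdf(q) for q in requested]

# Map to the constrained space; exp stands in for a positive-support transform.
constrained = [exp(u) for u in unconstrained]
print(constrained)
```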

AutoDelta

class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note

This class does MAP inference in constrained space.

Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)

Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)
    return init_to_sample(site)
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict

AutoContinuous

class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenate all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
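
The transform-and-concatenate step can be illustrated in plain Python. This is a sketch only: the site names and values are hypothetical, and log/logit stand in for the bijections that torch.distributions.transforms would supply for positive and unit-interval supports.

```python
from math import exp, log

# Hypothetical constrained latent values: "rate" must be positive,
# "prob" must lie in (0, 1), and "offset" is unconstrained.
constrained = {"rate": 2.0, "prob": 0.25, "offset": -1.5}

# Bijections to unconstrained space and their inverses.
to_unconstrained = {
    "rate": log,
    "prob": lambda p: log(p / (1.0 - p)),  # logit
    "offset": lambda x: x,
}
to_constrained = {
    "rate": exp,
    "prob": lambda u: 1.0 / (1.0 + exp(-u)),  # sigmoid
    "offset": lambda u: u,
}

# Pack: transform each site and concatenate into a single latent vector.
names = list(constrained)
latent = [to_unconstrained[n](constrained[n]) for n in names]

# Unpack: split the vector by site and map back to the constrained space.
recovered = {n: to_constrained[n](u) for n, u in zip(names, latent)}
print(latent, recovered)
```

A derived class then only has to define a distribution over the single vector (latent here); the unpacking step recovers per-site values.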

Reference:

[1] Automatic Differentiation Variational Inference,
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
get_base_dist()[source]

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns:TorchDistribution instance representing the base distribution.
get_transform(*args, **kwargs)[source]

Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns:a Transform instance.
get_posterior(*args, **kwargs)[source]

Returns the posterior distribution.

sample_latent(*args, **kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns:A dict mapping sample site name to sampled value.
Return type:dict
median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict
quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters:quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns:A dict mapping sample site name to a tensor of quantile values.
Return type:dict

AutoMultivariateNormal

class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.

Parameters:
  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_tril_constraint = SoftplusLowerCholesky()
get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]

Returns a MultivariateNormal posterior distribution.

AutoDiagonalNormal

class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the scale is initialized to the identity times a small factor.

Parameters:
  • model (callable) – A generative model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_constraint = SoftplusPositive(lower_bound=0.0)
get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]

Returns a diagonal Normal posterior distribution.

AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)

By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.
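
For reference, the low-rank-plus-diagonal covariance structure can be written as follows, where W plays the role of cov_factor, d the role of cov_diag, D is the latent dimension, and r is the rank. The symbols are introduced here for illustration.

```latex
\[
\Sigma \;=\; W W^{\top} + \operatorname{diag}(d),
\qquad W \in \mathbb{R}^{D \times r},
\quad d \in \mathbb{R}_{>0}^{D}
\]
```

This form needs D·r + D covariance parameters, compared with D(D+1)/2 for a full Cholesky factor, which is why a small rank can be attractive when D is large.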

Parameters:
  • model (callable) – A generative model.
  • rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_constraint = SoftplusPositive(lower_bound=0.0)
get_posterior(*args, **kwargs)[source]

Returns a LowRankMultivariateNormal posterior distribution.

AutoNormalizingFlow

class AutoNormalizingFlow(model, init_transform_fn)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

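from functools import partial
from pyro.distributions.transforms import block_autoregressive, iterated

# iterated(repeats, base_fn, ...) composes `repeats` copies of the transform
transform_init = partial(iterated, 2, block_autoregressive)
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)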
Parameters:
  • model (callable) – a generative model
  • init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or TransformModule if the transform has trainable params.
get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]

AutoIAFNormal

class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[source]

Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an AffineAutoregressive to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – a generative model
  • hidden_dim (list[int]) – number of hidden dimensions in the IAF
  • init_loc_fn (callable) –

    A per-site initialization function. See Initialization section for available functions.

    Warning

    This argument is only to preserve backwards compatibility and has no effect in practice.

  • num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
  • init_transform_kwargs – other keyword arguments taken by affine_autoregressive().

AutoLaplaceApproximation

class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
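
A one-dimensional example may make the construction concrete. The sketch below is plain Python and does not use Pyro. It assumes a hypothetical log joint \(\log p(x, z) = a z - b e^{z}\) (a Gamma(a, b) density on \(e^{z}\), including the log-Jacobian of the exp transform), and it uses Newton's method as a stand-in for the SVI training that would normally locate the MAP point.

```python
from math import exp, log

# Hypothetical unnormalized log joint in unconstrained space:
# log p(x, z) = a * z - b * exp(z)
a, b = 4.0, 2.0


def neg_log_joint_grad(z):
    """First derivative of -log p(x, z) with respect to z."""
    return -a + b * exp(z)


def neg_log_joint_hess(z):
    """Second derivative of -log p(x, z) with respect to z."""
    return b * exp(z)


# Newton's method to find the MAP point of z.
z = 0.0
for _ in range(50):
    z -= neg_log_joint_grad(z) / neg_log_joint_hess(z)

z_map = z
# The Laplace variance is the inverse Hessian of -log p(x, z) at the MAP point.
variance = 1.0 / neg_log_joint_hess(z_map)
print(z_map, variance)
```

For this log joint the MAP point is log(a / b) and the variance is 1 / a, so the approximating distribution is a Normal centered at log(a / b) with standard deviation sqrt(1 / a).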

Usage:

delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to an empirical prior median.

Parameters:
  • model (callable) – a generative model
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
get_posterior(*args, **kwargs)[source]

Returns a Delta posterior distribution for MAP inference.

laplace_approximation(*args, **kwargs)[source]

Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by Laplace approximation.

AutoDiscreteParallel

class AutoDiscreteParallel(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns:A dict mapping sample site name to sampled value.
Return type:dict

AutoStructured

class AutoStructured(model, *, conditionals: Union[str, Dict[str, Union[str, Callable]]] = 'mvn', dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = 'linear', init_loc_fn: Callable = <function init_to_feasible>, init_scale: float = 0.1, create_plates: Optional[Callable] = None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Structured guide whose conditional distributions are Delta, Normal, MultivariateNormal, or given by a callable, and whose latent variables can depend on each other either linearly (in unconstrained space) or via shearing by a callable.

Usage:

def model(data):
    x = pyro.sample("x", dist.LogNormal(0, 1))
    with pyro.plate("plate", len(data)):
        y = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z", dist.Normal(y, x), obs=data)

# Either fully automatic...
guide = AutoStructured(model)

# ...or with specified conditional and dependency types...
guide = AutoStructured(
    model, conditionals="normal", dependencies="linear"
)

# ...or with custom dependency structure and distribution types.
guide = AutoStructured(
    model=model,
    conditionals={"x": "normal", "y": "delta"},
    dependencies={"x": {"y": "linear"}},
)

Once trained, this guide can be used with StructuredReparam to precondition a model for use in HMC and NUTS inference.

Note

If you declare a dependency of a high-dimensional downstream variable on a low-dimensional upstream variable, you may want to use a lower learning rate for that weight, e.g.:

def optim_config(param_name):
    config = {"lr": 0.01}
    if "deps.my_downstream.my_upstream" in param_name:
        config["lr"] *= 0.1
    return config

adam = pyro.optim.Adam(optim_config)
Parameters:
  • model (callable) – A Pyro model.
  • conditionals – Either a single distribution type or a dict mapping each latent variable name to a distribution type. A distribution type is either a string in {“delta”, “normal”, “mvn”} or a callable that returns a sample from a zero mean (or approximately centered) noise distribution (such callables typically call pyro.param() and pyro.sample() internally).
  • dependencies – Dependency type, or a dict mapping each site name to a dict mapping its upstream dependencies to dependency types. If only a dependency type is provided, dependency structure will be inferred. A dependency type is either the string “linear” or a callable that maps a flattened upstream perturbation to a flattened downstream perturbation. The string “linear” is equivalent to nn.Linear(upstream.numel(), downstream.numel(), bias=False). Dependencies must not contain cycles or self-loops.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
  • create_plates (callable) – An optional function taking the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
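
The constraint that dependencies contain no cycles or self-loops can be checked before constructing the guide. The has_cycle helper below is a hypothetical plain-Python sketch (not part of Pyro) that runs a depth-first search over a dict of the same shape as the dependencies argument, mapping each site name to its upstream sites.

```python
def has_cycle(dependencies):
    """Return True if the site -> upstream mapping has a cycle or self-loop."""
    # Per-site state: 1 = on the current DFS path, 2 = fully explored.
    state = {}

    def visit(site):
        if state.get(site) == 1:
            return True  # reached a site that is already on the current path
        if state.get(site) == 2:
            return False
        state[site] = 1
        for upstream in dependencies.get(site, {}):
            if visit(upstream):
                return True
        state[site] = 2
        return False

    return any(visit(site) for site in dependencies)


print(has_cycle({"x": {"y": "linear"}}))
print(has_cycle({"x": {"y": "linear"}, "y": {"x": "linear"}}))
```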
scale_constraint = SoftplusPositive(lower_bound=0.0)
scale_tril_constraint = SoftplusLowerCholesky()
get_deltas
forward(*args, **kwargs)[source]
median(*args, **kwargs)[source]

AutoGaussian

class AutoGaussian(model: Callable, *, init_loc_fn: Callable = <function init_to_feasible>, init_scale: float = 0.1, backend: Optional[str] = None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Gaussian guide with optimal conditional independence structure.

This is equivalent to a full rank AutoMultivariateNormal guide, but with a sparse precision matrix determined by dependencies and plates in the model [1]. Depending on model structure, this can have asymptotically better statistical efficiency than AutoMultivariateNormal.

This guide implements multiple backends for computation. All backends use the same statistically optimal parametrization. The default “dense” backend has computational complexity similar to AutoMultivariateNormal. The experimental “funsor” backend can be asymptotically cheaper in terms of time and space (using Gaussian tensor variable elimination [2,3]), but incurs large constant overhead. The “funsor” backend requires funsor, which can be installed via pip install pyro-ppl[funsor].

The guide currently does not depend on the model’s *args, **kwargs.

Example:

guide = AutoGaussian(model)
svi = SVI(model, guide, ...)

Example using experimental funsor backend:

!pip install pyro-ppl[funsor]
guide = AutoGaussian(model, backend="funsor")
svi = SVI(model, guide, ...)

References

[1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)
“Faithful inversion of generative models for effective amortized inference” https://dl.acm.org/doi/10.5555/3327144.3327229
[2] F.Obermeyer, E.Bingham, M.Jankowiak, J.Chiu, N.Pradhan, A.M.Rush, N.Goodman
(2019) “Tensor Variable Elimination for Plated Factor Graphs” http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf
[3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen
(2019) “Functional Tensors for Probabilistic Programming” https://arxiv.org/abs/1910.10775
Parameters:
  • model (callable) – A Pyro model.
  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
  • backend (str) – Back end for performing Gaussian tensor variable elimination. Defaults to “dense”; other options include “funsor”.
scale_constraint = SoftplusPositive(lower_bound=0.0)
forward(*args, **kwargs) → Dict[str, torch.Tensor][source]
median(*args, **kwargs) → Dict[str, torch.Tensor][source]

Returns the posterior median value of each latent variable.

Returns:A dict mapping sample site name to median tensor.
Return type:dict

Initialization

The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.

The standard interface for initialization is a function that inputs a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.

init_to_feasible(site=None)[source]

Initialize to an arbitrary feasible point, ignoring distribution parameters.

init_to_sample(site=None)[source]

Initialize to a random sample from the prior.

init_to_median(site=None, num_samples=15, *, fallback: Optional[Callable] = <function init_to_feasible>)[source]

Initialize to the prior median; fall back to fallback (defaults to init_to_feasible()) if the median is undefined.

Parameters:fallback (callable) – Fallback init strategy, for sites not specified in values.
Raises:ValueError – If fallback=None and no value for a site is given in values.
init_to_mean(site=None, *, fallback: Optional[Callable] = <function init_to_median>)[source]

Initialize to the prior mean; fall back to fallback (defaults to init_to_median()) if the mean is undefined.

Parameters:fallback (callable) – Fallback init strategy, for sites not specified in values.
Raises:ValueError – If fallback=None and no value for a site is given in values.
init_to_uniform(site: Optional[dict] = None, radius: float = 2.0)[source]

Initialize to a random point in the interval (-radius, radius) of the unconstrained domain.

Parameters:radius (float) – specifies the range to draw an initial point in the unconstrained domain.
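
As an illustration of what drawing in the unconstrained domain means, the plain-Python sketch below draws a uniform value in (-radius, radius) and maps it to a positive-support site with exp. The uniform_init_positive helper is hypothetical, and exp is only a stand-in for whichever transform applies to a site's support.

```python
import random
from math import exp


def uniform_init_positive(radius=2.0, rng=random):
    """Hypothetical helper: draw u ~ Uniform(-radius, radius) in unconstrained
    space and map it to a positive value with exp."""
    u = rng.uniform(-radius, radius)
    return exp(u)


rng = random.Random(0)  # seeded for reproducibility
value = uniform_init_positive(radius=2.0, rng=rng)
print(value)
```

With radius=2.0 the resulting positive value lies between exp(-2) and exp(2), so a larger radius spreads initial values over more orders of magnitude.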
init_to_value(site: Optional[dict] = None, values: dict = {}, *, fallback: Optional[Callable] = <function init_to_uniform>)[source]

Initialize to the value specified in values. Fall back to the fallback strategy (defaults to init_to_uniform()) for sites not appearing in values.

Parameters:
  • values (dict) – dictionary of initial values keyed by site name.
  • fallback (callable) – Fallback init strategy, for sites not specified in values.
Raises:ValueError – If fallback=None and no value for a site is given in values.

init_to_generated(site=None, generate=<function <lambda>>)[source]

Initialize to another initialization strategy returned by the callback generate which is called once per model execution.

This is like init_to_value() but can produce different (e.g. random) values once per model execution. For example to generate values and return init_to_value you could define:

def generate():
    values = {"x": torch.randn(100), "y": torch.rand(5)}
    return init_to_value(values=values)

my_init_fn = init_to_generated(generate=generate)
Parameters:generate (callable) – A callable returning another initialization function, e.g. returning an init_to_value(values={...}) populated with a dictionary of random samples.
class InitMessenger(init_fn)[source]

Bases: pyro.poutine.messenger.Messenger

Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.

Parameters:init_fn (callable) – An initialization function.