Automatic Guide Generation
AutoGuide
class AutoGuide(model)
Bases: pyro.nn.module.PyroModule
Base class for automatic guides.
Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.
Auto guides can be used individually or combined in an AutoGuideList object.
Parameters: model (callable) – a Pyro model
call(*args, **kwargs)
Method that calls forward() and returns the parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
Warning: This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See issue https://github.com/pytorch/pytorch/issues/27743.
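The dict-to-tuple conversion that call() performs can be pictured in plain Python. This is a minimal sketch, not Pyro's code: it assumes the tuple is ordered by sorted site name (so the ordering is deterministic), and the names guide_call_sketch and toy_guide are hypothetical.

```python
from typing import Any, Callable, Dict, Tuple


def guide_call_sketch(guide: Callable[..., Dict[str, Any]], *args, **kwargs) -> Tuple[Any, ...]:
    """Hypothetical helper: run a dict-returning guide and return its values as a tuple.

    Assumption: values are ordered by sorted site name, so repeated calls
    produce the same positional layout (something a tracer relies on).
    """
    result = guide(*args, **kwargs)
    return tuple(value for _, value in sorted(result.items()))


# Hypothetical stand-in for a guide's forward(): a dict keyed by site name.
def toy_guide(data):
    return {"scale": 0.5, "loc": sum(data) / len(data)}


print(guide_call_sketch(toy_guide, [1.0, 2.0, 3.0]))  # (2.0, 0.5): "loc" sorts before "scale"
```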
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor. Return type: dict
model
AutoGuideList
class AutoGuideList(model)
Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList
Container class to combine multiple automatic guides.
Example usage:
    guide = AutoGuideList(model)
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
    svi = SVI(model, guide, optim, Trace_ELBO())
Parameters: model (callable) – a Pyro model
append(part)
Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict it to a subset of sample sites. No two parts should operate on the same sample site.
Parameters: part (AutoGuide or callable) – a partial guide to add
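The contract of append() (disjoint parts whose outputs are merged) can be sketched without Pyro. This is an illustration under stated assumptions: GuideListSketch is a hypothetical class, each part declares the set of site names it covers, and an overlap raises ValueError. The real AutoGuideList derives site ownership from the blocked model instead.

```python
from typing import Callable, Dict, List, Set

PartFn = Callable[..., Dict[str, object]]


class GuideListSketch:
    """Hypothetical container mimicking the AutoGuideList contract.

    Each part covers a set of sample-site names; parts must be disjoint,
    and calling the container merges the dicts returned by all parts.
    """

    def __init__(self) -> None:
        self.parts: List[PartFn] = []
        self.claimed: Set[str] = set()

    def append(self, part: PartFn, sites: Set[str]) -> None:
        # Enforce "no two parts operate on the same sample site".
        overlap = self.claimed & sites
        if overlap:
            raise ValueError(f"sites already handled by another part: {sorted(overlap)}")
        self.parts.append(part)
        self.claimed |= sites

    def __call__(self, *args, **kwargs) -> Dict[str, object]:
        result: Dict[str, object] = {}
        for part in self.parts:
            result.update(part(*args, **kwargs))
        return result


guide = GuideListSketch()
guide.append(lambda: {"loc": 0.0, "scale": 1.0}, {"loc", "scale"})  # continuous part
guide.append(lambda: {"assignment": 2}, {"assignment"})            # discrete part
print(guide())  # {'loc': 0.0, 'scale': 1.0, 'assignment': 2}
```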
AutoCallable
class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
    def my_local_guide(*args, **kwargs):
        ...
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
    guide.append(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, you can instead write:
    def my_local_median(*args, **kwargs):
        ...
    guide.append(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need, e.g., access to plates, users should instead subclass AutoGuide.
Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
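The wrapping idea can be sketched as a small class that forwards calls to a plain guide and exposes an optional median. This is a hypothetical stand-in (CallableGuideSketch is not a Pyro name), and it assumes the default median returns an empty dict, which fits the lambda default in the signature but is an assumption here.

```python
from typing import Any, Callable, Dict

GuideFn = Callable[..., Dict[str, Any]]


class CallableGuideSketch:
    """Hypothetical stand-in for the AutoCallable idea: wrap a plain guide
    callable and an optional median callable behind one object."""

    def __init__(self, guide: GuideFn, median: GuideFn = lambda *args, **kwargs: {}) -> None:
        self.guide = guide
        self.median_fn = median  # assumed default: no medians known -> empty dict

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        return self.guide(*args, **kwargs)

    def median(self, *args, **kwargs) -> Dict[str, Any]:
        return self.median_fn(*args, **kwargs)


# Hypothetical user-defined guide and median over a "local" site.
def my_local_guide(data):
    return {"local": [x * 0.5 for x in data]}


def my_local_median(data):
    return {"local": [x * 0.5 for x in data]}


plain = CallableGuideSketch(my_local_guide)
with_median = CallableGuideSketch(my_local_guide, my_local_median)
print(plain.median([2.0]))        # {}
print(with_median.median([2.0]))  # {'local': [1.0]}
```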
AutoDelta
class AutoDelta(model, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note: This class does MAP inference in constrained space.
Usage:
    guide = AutoDelta(model)
    svi = SVI(model, guide, ...)
Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:
    def my_init_fn(site):
        if site["name"] == "level":
            return torch.tensor([-1., 0., 1.])
        if site["name"] == "concentration":
            return torch.ones(k)
        return init_to_sample(site)
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
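The note that MAP inference happens "in constrained space" matters for positive or bounded sites: maximizing the density of the constrained value gives a different point than maximizing the density of its unconstrained transform, which carries an extra log-Jacobian term. The sketch below shows this on a hypothetical conjugate model (Gamma prior on a Poisson rate, made-up counts) with plain gradient ascent over u = log(rate); it illustrates the distinction and is not how AutoDelta is implemented.

```python
import math

# Hypothetical conjugate model: rate ~ Gamma(shape=a, rate=b), counts ~ Poisson(rate).
a, b = 2.0, 1.0
counts = [3, 5, 4, 6]
S, n = sum(counts), len(counts)


def map_estimate(include_jacobian: bool, steps: int = 2000, lr: float = 0.05) -> float:
    """Gradient ascent over u = log(rate), so the iterate always stays positive.

    include_jacobian=False: objective is log p(x, rate) as a density in the
    constrained rate, i.e. MAP in constrained space.
    include_jacobian=True: objective is the density of u, which adds the
    log-Jacobian log|d rate / du| = u.
    """
    u = 0.0
    for _ in range(steps):
        rate = math.exp(u)
        # d/du of (a + S - 1) * log(rate) - (b + n) * rate
        grad = (a + S - 1) - (b + n) * rate
        if include_jacobian:
            grad += 1.0  # d/du of the log-Jacobian term u
        u += lr * grad
    return math.exp(u)


print(round(map_estimate(False), 6))  # 3.8 = (a + S - 1) / (b + n), the posterior mode
print(round(map_estimate(True), 6))   # 4.0 = (a + S) / (b + n), mode of log(rate)'s density
```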
AutoContinuous
class AutoContinuous(model, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.
Assumes the model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
[1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei. Automatic Differentiation Variational Inference.
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
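The "transform each site, then concatenate" step can be sketched with scalar bijections in plain Python. This is an assumption-laden illustration: the three-entry constraint table, the sorted site ordering, and the pack/unpack helpers are hypothetical, whereas Pyro selects transforms from each site's support via torch.distributions.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical constraint table: support name -> (to_unconstrained, to_constrained).
BIJECTIONS: Dict[str, Tuple[Callable[[float], float], Callable[[float], float]]] = {
    "real": (lambda x: x, lambda u: u),
    "positive": (math.log, math.exp),
    "unit_interval": (lambda x: math.log(x / (1 - x)), lambda u: 1 / (1 + math.exp(-u))),
}


def pack(values: Dict[str, List[float]], constraints: Dict[str, str]) -> List[float]:
    """Map every site to unconstrained space and concatenate into one flat vector."""
    flat: List[float] = []
    for name in sorted(values):  # fixed ordering so unpack can invert it
        to_unconstrained, _ = BIJECTIONS[constraints[name]]
        flat.extend(to_unconstrained(v) for v in values[name])
    return flat


def unpack(flat: List[float], sizes: Dict[str, int], constraints: Dict[str, str]) -> Dict[str, List[float]]:
    """Split the flat vector back into sites and map each to its constrained space."""
    out: Dict[str, List[float]] = {}
    pos = 0
    for name in sorted(sizes):
        _, to_constrained = BIJECTIONS[constraints[name]]
        out[name] = [to_constrained(u) for u in flat[pos:pos + sizes[name]]]
        pos += sizes[name]
    return out


constraints = {"loc": "real", "scale": "positive", "prob": "unit_interval"}
values = {"loc": [0.5, -1.0], "scale": [2.0], "prob": [0.25]}
z = pack(values, constraints)
print(len(z))  # 4: one unconstrained latent vector covering all three sites
```

A derived guide would then only need a distribution over z, which is why fixed structure and dimension are required.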
forward(*args, **kwargs)
An automatic guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value. Return type: dict
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor. Return type: dict
quantiles(quantiles, *args, **kwargs)
Returns posterior quantiles of each latent variable. Example:
    print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values. Return type: dict
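Because the constrained value is a monotone increasing transform of the unconstrained one, quantiles (including the median) can be computed in unconstrained space and mapped through the transform. The sketch assumes a hypothetical positive site whose unconstrained posterior is Normal(loc, scale) with exp as the transform; the numbers are made up.

```python
import math
from statistics import NormalDist

# Hypothetical unconstrained posterior for a positive site: u ~ Normal(loc, scale),
# with the constrained value given by exp(u).
loc, scale = 0.3, 0.5
posterior_u = NormalDist(mu=loc, sigma=scale)


def quantiles_sketch(qs):
    """Quantiles commute with monotone increasing maps: take the Normal
    quantile in unconstrained space, then push it through exp."""
    return [math.exp(posterior_u.inv_cdf(q)) for q in qs]


q05, q50, q95 = quantiles_sketch([0.05, 0.5, 0.95])
print(q05 < q50 < q95)                  # True
print(math.isclose(q50, math.exp(loc)))  # True: the median is exp(loc)
```

Note that the posterior mean would not commute this way; that is one reason the API reports medians and quantiles.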
AutoMultivariateNormal
class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoMultivariateNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor (init_scale).
Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
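The Cholesky parameterization means the guide stores a lower-triangular factor L, the covariance is L Lᵀ, and samples are drawn as loc + L·eps with standard normal eps. A stdlib sketch with hypothetical values, including the identity-times-init_scale starting factor described above:

```python
import random
from typing import List

Matrix = List[List[float]]


def covariance_from_scale_tril(L: Matrix) -> Matrix:
    """Sigma = L @ L^T for a lower-triangular Cholesky factor L."""
    d = len(L)
    return [[sum(L[i][k] * L[j][k] for k in range(d)) for j in range(d)] for i in range(d)]


def sample(loc: List[float], L: Matrix, rng: random.Random) -> List[float]:
    """Reparameterized draw z = loc + L @ eps with eps ~ N(0, I).
    Only k <= i contributes because L is lower-triangular."""
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    return [loc[i] + sum(L[i][k] * eps[k] for k in range(i + 1)) for i in range(len(loc))]


# Starting point: identity times a small factor, so every coordinate has std init_scale.
init_scale, d = 0.1, 3
L0 = [[init_scale if i == j else 0.0 for j in range(d)] for i in range(d)]

# A hypothetical learned factor with a correlation between the two coordinates.
L = [[1.0, 0.0], [0.5, 2.0]]
print(covariance_from_scale_tril(L))  # [[1.0, 0.5], [0.5, 4.25]]
z = sample([0.0, 1.0], L, random.Random(0))
```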
AutoDiagonalNormal
class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized by init_loc_fn() and the scale is initialized to the identity times a small factor (init_scale).
Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
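With a diagonal covariance the joint log-density factorizes into a sum of independent univariate Normal terms, and the parameter count grows linearly rather than quadratically in the latent dimension. A stdlib sketch with hypothetical helper names, comparing against the full-rank Cholesky parameterization used by AutoMultivariateNormal:

```python
import math
from statistics import NormalDist
from typing import Dict, List


def diag_normal_log_prob(z: List[float], loc: List[float], scale: List[float]) -> float:
    """Diagonal covariance: the joint log-density is a sum over coordinates."""
    return sum(math.log(NormalDist(m, s).pdf(x)) for x, m, s in zip(z, loc, scale))


def num_params(d: int) -> Dict[str, int]:
    """Variational parameters needed for a d-dimensional latent vector."""
    return {
        "diagonal": 2 * d,                  # loc + per-coordinate scale
        "full_rank": d + d * (d + 1) // 2,  # loc + lower-triangular Cholesky factor
    }


print(num_params(100))  # {'diagonal': 200, 'full_rank': 5150}
print(diag_normal_log_prob([0.5, -1.0], [0.0, 0.0], [1.0, 2.0]))
```

The trade-off is that a diagonal guide cannot represent posterior correlations between latent variables.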
AutoLowRankMultivariateNormal
class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a low-rank-plus-diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoLowRankMultivariateNormal(model, rank=10)
    svi = SVI(model, guide, ...)
By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.
Parameters:
- model (callable) – A generative model.
- rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
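The low-rank-plus-diagonal covariance is W Wᵀ + diag(cov_diag) with W of shape (latent_dim, rank). The sketch below illustrates this, plus one way to satisfy the "same scale on average" initialization: drawing W entries with variance target_var / rank gives E[(W Wᵀ)ᵢᵢ] = target_var. Both that scheme and the default rank round(sqrt(latent_dim)) are assumptions consistent with the text above, not necessarily Pyro's exact choices; all helper names are hypothetical.

```python
import random
from typing import List

Matrix = List[List[float]]


def default_rank(latent_dim: int) -> int:
    """Assumed default: round(sqrt(latent_dim)), but at least 1."""
    return max(1, round(latent_dim ** 0.5))


def low_rank_covariance(W: Matrix, cov_diag: List[float]) -> Matrix:
    """Sigma = W @ W^T + diag(cov_diag), with W of shape (d, rank)."""
    d = len(W)
    return [
        [sum(wi * wj for wi, wj in zip(W[i], W[j])) + (cov_diag[i] if i == j else 0.0) for j in range(d)]
        for i in range(d)
    ]


def init_cov_factor(d: int, rank: int, target_var: float, rng: random.Random) -> Matrix:
    """Entries ~ N(0, target_var / rank), so each diagonal entry of W @ W^T
    (a sum of `rank` squared entries) has expectation target_var."""
    sd = (target_var / rank) ** 0.5
    return [[rng.gauss(0.0, sd) for _ in range(rank)] for _ in range(d)]


print(default_rank(100))                              # 10
print(low_rank_covariance([[1.0], [2.0]], [0.5, 0.5]))  # [[1.5, 2.0], [2.0, 4.5]]
W0 = init_cov_factor(4, default_rank(4), 0.01, random.Random(0))
```

The payoff is roughly latent_dim * (rank + 2) parameters instead of the quadratic count of a full Cholesky factor.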
AutoIAFNormal
class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an AffineAutoregressive transform to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoIAFNormal(model, hidden_dim=latent_dim)
    svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model
- hidden_dim (int) – number of hidden dimensions in the IAF
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
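An affine autoregressive transform maps each coordinate as yᵢ = μᵢ(x₍<ᵢ₎) + exp(log σᵢ(x₍<ᵢ₎))·xᵢ, so its Jacobian is triangular and log|det| is just the sum of the log-scales. The sketch replaces the autoregressive neural network (with hidden_dim hidden units in the real guide) by a hypothetical linear conditioner, purely to show the structure and the cheap log-determinant.

```python
import math
from typing import List, Tuple


def conditioner(prefix: List[float]) -> Tuple[float, float]:
    """Hypothetical stand-in for the autoregressive network: the shift and
    log-scale for coordinate i depend only on x[:i]."""
    s = sum(prefix)
    return 0.5 * s, 0.1 * s  # (shift mu_i, log_scale_i)


def iaf_forward(x: List[float]) -> Tuple[List[float], float]:
    """y_i = mu_i(x_<i) + exp(log_scale_i(x_<i)) * x_i. Every y_i depends only on
    the given x, so a masked network can compute all coordinates in one pass."""
    y, log_det = [], 0.0
    for i, xi in enumerate(x):
        mu, log_scale = conditioner(x[:i])
        y.append(mu + math.exp(log_scale) * xi)
        log_det += log_scale  # triangular Jacobian: log|det| = sum of log-scales
    return y, log_det


def iaf_inverse(y: List[float]) -> List[float]:
    """The inverse must recover x one coordinate at a time."""
    x: List[float] = []
    for yi in y:
        mu, log_scale = conditioner(x)  # x currently holds x_<i
        x.append((yi - mu) / math.exp(log_scale))
    return x


x = [0.3, -1.2, 0.7]  # e.g. a draw from the base diagonal Normal
y, log_det = iaf_forward(x)
print([round(v, 12) for v in iaf_inverse(y)])  # [0.3, -1.2, 0.7]
```

Fast sampling with a cheap log-density of the samples is exactly what a variational guide needs, which is why the IAF direction is used here.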
AutoLaplaceApproximation
class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)
Bases: pyro.infer.autoguide.guides.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior \(p(z \mid x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of \(z\).
Usage:
    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, ...)
    # ...then train the delta_guide...
    guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to an empirical prior median.
Parameters:
- model (callable) – a generative model
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
laplace_approximation(*args, **kwargs)
Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.
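In one dimension the two stages (find the MAP point of z, then use the inverse of the second derivative of −log p(x, z) there as the variance) can be sketched with Newton steps and finite differences. The model is a hypothetical conjugate one (z ~ N(0, 1), xᵢ ~ N(z, 1)) whose posterior is exactly Normal, so the result can be checked in closed form; Pyro itself uses SVI for the MAP step and autograd for the Hessian over the full unconstrained latent vector.

```python
from typing import Callable, Tuple

data = [1.2, 0.8, 1.9, 1.1]  # hypothetical observations


def neg_log_joint(z: float) -> float:
    """-log p(x, z) up to a constant, for z ~ N(0, 1) and x_i ~ N(z, 1)."""
    return 0.5 * z ** 2 + 0.5 * sum((x - z) ** 2 for x in data)


def laplace_1d(f: Callable[[float], float], z0: float = 0.0, h: float = 1e-4) -> Tuple[float, float]:
    """Newton steps on f (central finite differences) to reach the MAP point,
    then return (loc, variance) with variance = 1 / f''(z_map)."""
    z = z0
    for _ in range(50):
        grad = (f(z + h) - f(z - h)) / (2 * h)
        hess = (f(z + h) - 2 * f(z) + f(z - h)) / h ** 2
        z -= grad / hess
    hess = (f(z + h) - 2 * f(z) + f(z - h)) / h ** 2
    return z, 1.0 / hess


loc, var = laplace_1d(neg_log_joint)
# Exact posterior here: N(sum(data) / (n + 1), 1 / (n + 1)) = N(1.0, 0.2).
print(round(loc, 4), round(var, 4))  # 1.0 0.2
```

For non-Gaussian posteriors the same recipe gives only a local quadratic fit around the mode.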
AutoDiscreteParallel
Initialization
The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.
The standard interface for initialization is a function that inputs a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
init_to_feasible(site)
Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_median(site, num_samples=15)
Initialize to the prior median; fall back to a feasible point if the median is undefined.
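The two strategies can be sketched over a simplified site dict. This is a hypothetical representation (a "support" name plus a zero-argument "sample" callable), not Pyro's trace-site format. The feasible point is taken to be the unconstrained origin mapped into the support, which ignores the distribution's parameters as described. Signalling an undefined median with ValueError is also an assumption.

```python
import math
import random
import statistics
from typing import Callable, Dict

# Hypothetical map from support name to its unconstrained -> constrained transform.
TO_CONSTRAINED: Dict[str, Callable[[float], float]] = {
    "real": lambda u: u,
    "positive": math.exp,
    "unit_interval": lambda u: 1.0 / (1.0 + math.exp(-u)),
}


def init_to_feasible_sketch(site: dict) -> float:
    """Map the unconstrained origin into the site's support."""
    return TO_CONSTRAINED[site["support"]](0.0)


def init_to_median_sketch(site: dict, num_samples: int = 15) -> float:
    """Median of num_samples prior draws; fall back to a feasible point if
    sampling fails (assumed here to raise ValueError)."""
    try:
        return statistics.median(site["sample"]() for _ in range(num_samples))
    except ValueError:
        return init_to_feasible_sketch(site)


def undefined_sampler() -> float:
    raise ValueError("median undefined for this hypothetical site")


rng = random.Random(0)
scale_site = {"support": "positive", "sample": lambda: rng.lognormvariate(0.0, 1.0)}
broken_site = {"support": "unit_interval", "sample": undefined_sampler}
print(init_to_feasible_sketch(scale_site))  # 1.0
print(init_to_median_sketch(broken_site))   # 0.5, via the feasible fallback
```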
class InitMessenger(init_fn)
Bases: pyro.poutine.messenger.Messenger
Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
Parameters: init_fn (callable) – An initialization function.