Automatic Guide Generation¶
AutoGuide¶
class AutoGuide(model, *, create_plates=None)[source]¶
Bases: pyro.nn.module.PyroModule
Base class for automatic guides.
Derived classes must implement the forward() method, with the same *args, **kwargs as the base model. Auto guides can be used individually or combined in an AutoGuideList object.
Parameters:
- model (callable) – A Pyro model.
- create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
call(*args, **kwargs)[source]¶
Method that calls forward() and returns the parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
Warning: This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See https://github.com/pytorch/pytorch/issues/27743.
median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
model¶
AutoGuideList¶
class AutoGuideList(model, *, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList
Container class to combine multiple automatic guides.
Example usage:
guide = AutoGuideList(my_model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
Parameters: model (callable) – a Pyro model
append(part)[source]¶
Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict it to a subset of sample sites. No two parts should operate on the same sample site.
Parameters: part (AutoGuide or callable) – a partial guide to add
forward(*args, **kwargs)[source]¶
A composite guide with the same *args, **kwargs as the base model.
Note: This method is used internally by Module. Users should instead use __call__().
Returns: A dict mapping sample site name to sampled value.
Return type: dict
AutoCallable¶
class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.add(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, you can instead:
def my_local_median(*args, **kwargs):
    ...

guide.add(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.
Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
AutoNormal¶
class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.
In AutoDiagonalNormal, if your model has N named parameters with dimensions k_i and sum k_i = D, you get a single vector of length D for your means and a single vector of length D for your sigmas. This guide gives you N distinct Normals that you can call by name.
Usage:
guide = AutoNormal(model)
svi = SVI(model, guide, ...)
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
forward(*args, **kwargs)[source]¶
An automatic guide with the same *args, **kwargs as the base model.
Note: This method is used internally by Module. Users should instead use __call__().
Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)[source]¶
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a tensor of quantile values.
Return type: dict
scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoDelta¶
class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Note: This class does MAP inference in constrained space.
Usage:
guide = AutoDelta(model)
svi = SVI(model, guide, ...)
Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:
def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)
    return init_to_sample(site)
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
AutoContinuous¶
class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
[1] Automatic Differentiation Variational Inference,
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
Parameters:
- model (callable) – A Pyro model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
forward(*args, **kwargs)[source]¶
An automatic guide with the same *args, **kwargs as the base model.
Note: This method is used internally by Module. Users should instead use __call__().
Returns: A dict mapping sample site name to sampled value.
Return type: dict
get_base_dist()[source]¶
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model's *args, **kwargs.
posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns: a TorchDistribution instance representing the base distribution.
get_transform(*args, **kwargs)[source]¶
Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model's *args, **kwargs.
posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns: a Transform instance.
median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)[source]¶
Returns posterior quantiles for each latent variable. Example:
print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a tensor of quantile values.
Return type: dict
AutoMultivariateNormal¶
class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)
By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.
Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_tril_constraint = SoftplusLowerCholesky()¶
AutoDiagonalNormal¶
class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the scale is initialized to the identity times a small factor.
Parameters:
- model (callable) – A generative model.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoLowRankMultivariateNormal¶
class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a low-rank plus diagonal multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)
By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.
Parameters:
- model (callable) – A generative model.
- rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoNormalizingFlow¶
class AutoNormalizingFlow(model, init_transform_fn)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
transform_init = partial(iterated, block_autoregressive, repeats=2)
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model
- init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or TransformModule if the transform has trainable params.
AutoIAFNormal¶
class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[source]¶
Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow
This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an AffineAutoregressive transform to construct a guide over the entire latent space. The guide does not depend on the model's *args, **kwargs.
Usage:
guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
Parameters:
- model (callable) – a generative model
- hidden_dim (list[int]) – number of hidden dimensions in the IAF
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
  Warning: This argument is only to preserve backwards compatibility and has no effect in practice.
- num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
- init_transform_kwargs – other keyword arguments taken by affine_autoregressive().
AutoLaplaceApproximation¶
class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
The Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to an empirical prior median.
Parameters:
- model (callable) – a generative model
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
laplace_approximation(*args, **kwargs)[source]¶
Returns an AutoMultivariateNormal instance whose posterior's loc and scale_tril are given by the Laplace approximation.
AutoDiscreteParallel¶
AutoStructured¶
class AutoStructured(model, *, conditionals: Dict[str, Union[str, Callable]] = 'normal', dependencies: Dict[str, Dict[str, Union[str, Callable]]] = 'linear', init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
Structured guide whose conditional distributions are Delta, Normal, MultivariateNormal, or given by a callable, and whose latent variables can depend on each other either linearly (in unconstrained space) or via shearing by a callable.
Usage:
def model(data):
    x = pyro.sample("x", dist.LogNormal(0, 1))
    with pyro.plate("plate", len(data)):
        y = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z", dist.Normal(y, x), obs=data)

guide = AutoStructured(
    model=model,
    conditionals={"x": "normal", "y": "normal"},
    dependencies={"x": {"y": "linear"}},
)
Once trained, this guide can be used with StructuredReparam to precondition a model for use in HMC and NUTS inference.
Note: If you declare a dependency of a high-dimensional downstream variable on a low-dimensional upstream variable, you may want to use a lower learning rate for that weight, e.g.:
def optim_config(param_name):
    config = {"lr": 0.01}
    if "deps.my_downstream.my_upstream" in param_name:
        config["lr"] *= 0.1
    return config

adam = pyro.optim.Adam(optim_config)
Parameters:
- model (callable) – A Pyro model.
- conditionals – Family of distribution with which to model each latent variable's conditional posterior. This should be a dict mapping each latent variable name to either a string in ("delta", "normal", or "mvn") or to a callable that returns a sample from a zero-mean (or approximately centered) noise distribution (such callables typically call pyro.param() and pyro.sample() internally).
- dependencies – Dict mapping each site name to a dict of its upstream dependencies; each inner dict maps an upstream site name to either the string "linear" or a callable that maps a flattened upstream perturbation to a flattened downstream perturbation. The string "linear" is equivalent to nn.Linear(upstream.numel(), downstream.numel(), bias=False). Dependencies must not contain cycles or self-loops.
- init_loc_fn (callable) – A per-site initialization function. See the Initialization section for available functions.
- init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
get_deltas¶
scale_constraint = SoftplusPositive(lower_bound=0.0)¶
scale_tril_constraint = SoftplusLowerCholesky()¶
Initialization¶
The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.
The standard interface for initialization is a function that inputs a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
init_to_feasible(site=None)[source]¶
Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_median(site=None, num_samples=15, *, fallback: Optional[Callable] = <function init_to_feasible>)[source]¶
Initialize to the prior median; fall back to fallback (defaults to init_to_feasible()) if the median is undefined.
Parameters: fallback (callable) – Fallback init strategy, used when the median is undefined.
Raises: ValueError – If fallback=None and the median is undefined.
init_to_mean(site=None, *, fallback: Optional[Callable] = <function init_to_median>)[source]¶
Initialize to the prior mean; fall back to fallback (defaults to init_to_median()) if the mean is undefined.
Parameters: fallback (callable) – Fallback init strategy, used when the mean is undefined.
Raises: ValueError – If fallback=None and the mean is undefined.
init_to_uniform(site: Optional[dict] = None, radius: float = 2.0)[source]¶
Initialize to a random point in the region (-radius, radius) of the unconstrained domain.
Parameters: radius (float) – specifies the range from which to draw an initial point in the unconstrained domain.
init_to_value(site: Optional[dict] = None, values: dict = {}, *, fallback: Optional[Callable] = <function init_to_uniform>)[source]¶
Initialize to the value specified in values. Fall back to the fallback strategy (defaults to init_to_uniform()) for sites not appearing in values.
Parameters:
- values (dict) – dictionary of initial values keyed by site name.
- fallback (callable) – Fallback init strategy, for sites not specified in values.
Raises: ValueError – If fallback=None and no value for a site is given in values.
init_to_generated(site=None, generate=<function <lambda>>)[source]¶
Initialize to another initialization strategy returned by the callback generate, which is called once per model execution.
This is like init_to_value() but can produce different (e.g. random) values once per model execution. For example, to generate values and return init_to_value you could define:
def generate():
    values = {"x": torch.randn(100), "y": torch.rand(5)}
    return init_to_value(values=values)

my_init_fn = init_to_generated(generate=generate)
Parameters: generate (callable) – A callable returning another initialization function, e.g. returning an init_to_value(values={...}) populated with a dictionary of random samples.
class InitMessenger(init_fn)[source]¶
Bases: pyro.poutine.messenger.Messenger
Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
Parameters: init_fn (callable) – An initialization function.