Automatic Guide Generation¶
AutoGuide¶
- class AutoGuide(model, *, create_plates=None)[source]¶
Bases: pyro.nn.module.PyroModule
Base class for automatic guides.
Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.
Auto guides can be used individually or combined in an AutoGuideList object.
- Parameters
model (callable) – A Pyro model.
create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
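For example, a guide can reuse a model's subsampling plate rather than creating its own (a minimal sketch; the plate name "data", the subsample size, and the concrete AutoNormal subclass are illustrative choices, not part of the API):

def create_plates(data):
    # Return the same plate the model uses, so the guide subsamples
    # consistently with the model.
    return pyro.plate("data", len(data), subsample_size=100)

guide = AutoNormal(model, create_plates=create_plates)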
- property model¶
- call(*args, **kwargs)[source]¶
Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
Warning
This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See https://github.com/pytorch/pytorch/issues/27743.
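A minimal sketch of the intended use (the guide instance and the example input data are assumptions for illustration):

# Trace the tuple-returning call() method rather than forward(),
# since the JIT tracer rejects dict return values.
traced_guide = torch.jit.trace_module(guide, {"call": (data,)})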
- sample_latent(**kwargs)[source]¶
Samples an encoded latent given the same *args, **kwargs as the base model.
AutoGuideList¶
- class AutoGuideList(model, *, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList
Container class to combine multiple automatic or custom guides.
Example usage:
guide = AutoGuideList(my_model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
- Parameters
model (callable) – a Pyro model
- append(part)[source]¶
Add an automatic or custom guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.
- Parameters
part (AutoGuide or callable) – a partial guide to add
- forward(*args, **kwargs)[source]¶
A composite guide with the same *args, **kwargs as the base model.
Note
This method is used internally by Module. Users should instead use __call__().
- Returns
A dict mapping sample site name to sampled value.
- Return type
dict
- median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
- Returns
A dict mapping sample site name to median tensor.
- Return type
dict
AutoCallable¶
- class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.append(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, you can instead:
def my_local_median(*args, **kwargs):
    ...

guide.append(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.
- Parameters
model (callable) – a Pyro model
guide (callable) – a Pyro guide (typically over only part of the model)
median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
AutoNormal¶
- class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
It should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for TraceMeanField_ELBO.
In AutoDiagonalNormal, if your model has N named parameters with dimensions k_i and sum k_i = D, you get a single vector of length D for your means and a single vector of length D for your sigmas. This guide instead gives you N distinct Normal distributions that you can access by name.
Usage:
guide = AutoNormal(model)
svi = SVI(model, guide, ...)
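As a rough illustration of the per-site naming (a sketch: the site name "x" and the "AutoNormal.locs.x" / "AutoNormal.scales.x" parameter names follow the usual param-store naming convention and are assumptions here):

guide = AutoNormal(model)
svi = SVI(model, guide, optim, Trace_ELBO())
svi.step(data)
for name, value in pyro.get_param_store().items():
    # e.g. "AutoNormal.locs.x", "AutoNormal.scales.x"
    print(name, value.shape)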
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- forward(*args, **kwargs)[source]¶
An automatic guide with the same *args, **kwargs as the base model.
Note
This method is used internally by Module. Users should instead use __call__().
- Returns
A dict mapping sample site name to sampled value.
- Return type
dict
- median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
- Returns
A dict mapping sample site name to median tensor.
- Return type
dict
- quantiles(quantiles, *args, **kwargs)[source]¶
Returns posterior quantiles for each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))

- Parameters
quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
- Returns
A dict mapping sample site name to a tensor of quantile values.
- Return type
dict
AutoDelta¶
- class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note
This class does MAP inference in constrained space.
Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)
Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)
    return init_to_sample(site)
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- forward(*args, **kwargs)[source]¶
An automatic guide with the same *args, **kwargs as the base model.
Note
This method is used internally by Module. Users should instead use __call__().
- Returns
A dict mapping sample site name to sampled value.
- Return type
dict
AutoContinuous¶
- class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Reference:
- [1] Automatic Differentiation Variational Inference,
Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- get_base_dist()[source]¶
Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

- Returns
TorchDistribution instance representing the base distribution.
- get_transform(*args, **kwargs)[source]¶
Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))

- Returns
a Transform instance.
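Putting the two methods together (a minimal sketch assuming a trained AutoContinuous subclass such as AutoDiagonalNormal, and a model taking no arguments):

# Reconstruct the posterior over the single flattened, unconstrained latent.
base = guide.get_base_dist()
transform = guide.get_transform()
posterior = dist.TransformedDistribution(base, transform)
z = posterior.sample()  # one flattened, unconstrained latent sample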
- sample_latent(*args, **kwargs)[source]¶
Samples an encoded latent given the same *args, **kwargs as the base model.
- forward(*args, **kwargs)[source]¶
An automatic guide with the same *args, **kwargs as the base model.
Note
This method is used internally by Module. Users should instead use __call__().
- Returns
A dict mapping sample site name to sampled value.
- Return type
dict
- median(*args, **kwargs)[source]¶
Returns the posterior median value of each latent variable.
- Returns
A dict mapping sample site name to median tensor.
- Return type
dict
- quantiles(quantiles, *args, **kwargs)[source]¶
Returns posterior quantiles for each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))

- Parameters
quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
- Returns
A dict mapping sample site name to a tensor of quantile values.
- Return type
dict
AutoMultivariateNormal¶
- class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)
By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.
- Parameters
model (callable) – A generative model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- scale_tril_constraint = UnitLowerCholesky()¶
AutoDiagonalNormal¶
- class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the scale is initialized to the identity times a small factor.
- Parameters
model (callable) – A generative model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoLowRankMultivariateNormal¶
- class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a low-rank-plus-diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)
By default cov_diag is initialized to a small constant and cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.
- Parameters
model (callable) – A generative model.
rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
AutoNormalizingFlow¶
- class AutoNormalizingFlow(model, init_transform_fn)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
This implementation of AutoContinuous uses a diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
transform_init = partial(iterated, block_autoregressive, repeats=2)
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)
- Parameters
model (callable) – a generative model
init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or of TransformModule if the transform has trainable params.
AutoIAFNormal¶
- class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[source]¶
Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow
This implementation of AutoContinuous uses a diagonal Normal distribution transformed via an AffineAutoregressive to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
- Parameters
model (callable) – a generative model
hidden_dim (list[int]) – number of hidden dimensions in the IAF
init_loc_fn (callable) –
A per-site initialization function. See Initialization section for available functions.
Warning
This argument is only to preserve backwards compatibility and has no effect in practice.
num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.
init_transform_kwargs – other keyword arguments taken by affine_autoregressive().
AutoLaplaceApproximation¶
- class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[source]¶
Bases: pyro.infer.autoguide.guides.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior log p(z|x) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of -log p(x, z) at the MAP point of z.
Usage:
delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to an empirical prior median.
- Parameters
model (callable) – a generative model
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
- laplace_approximation(*args, **kwargs)[source]¶
Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.
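A minimal end-to-end sketch (the optimizer settings, step count, and the model argument data are illustrative):

delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(1000):
    svi.step(data)
# Only after training, derive the Gaussian approximation at the MAP point.
guide = delta_guide.laplace_approximation(data)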
AutoDiscreteParallel¶
- class AutoDiscreteParallel(model, *, create_plates=None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.
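A minimal sketch, mirroring the AutoGuideList example above (the site name "assignment" is illustrative):

guide = AutoDiscreteParallel(poutine.block(model, expose=["assignment"]))
svi = SVI(model, guide, optim, Trace_ELBO())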
AutoStructured¶
- class AutoStructured(model, *, conditionals: Union[str, Dict[str, Union[str, Callable]]] = 'mvn', dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = 'linear', init_loc_fn: Callable = <function init_to_feasible>, init_scale: float = 0.1, create_plates: Optional[Callable] = None)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
Structured guide whose conditional distributions are Delta, Normal, or MultivariateNormal distributions, or are given by a callable, and whose latent variables can depend on each other either linearly (in unconstrained space) or via shearing by a callable.
Usage:
def model(data):
    x = pyro.sample("x", dist.LogNormal(0, 1))
    with pyro.plate("plate", len(data)):
        y = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z", dist.Normal(y, x), obs=data)

# Either fully automatic...
guide = AutoStructured(model)

# ...or with specified conditional and dependency types...
guide = AutoStructured(
    model, conditionals="normal", dependencies="linear"
)

# ...or with custom dependency structure and distribution types.
guide = AutoStructured(
    model=model,
    conditionals={"x": "normal", "y": "delta"},
    dependencies={"x": {"y": "linear"}},
)
Once trained, this guide can be used with StructuredReparam to precondition a model for use in HMC and NUTS inference.
Note
If you declare a dependency of a high-dimensional downstream variable on a low-dimensional upstream variable, you may want to use a lower learning rate for that weight, e.g.:
def optim_config(param_name):
    config = {"lr": 0.01}
    if "deps.my_downstream.my_upstream" in param_name:
        config["lr"] *= 0.1
    return config

adam = pyro.optim.Adam(optim_config)
- Parameters
model (callable) – A Pyro model.
conditionals – Either a single distribution type or a dict mapping each latent variable name to a distribution type. A distribution type is either a string in {“delta”, “normal”, “mvn”} or a callable that returns a sample from a zero mean (or approximately centered) noise distribution (such callables typically call pyro.param() and pyro.sample() internally).
dependencies – Dependency type, or a dict mapping each site name to a dict mapping its upstream dependencies to dependency types. If only a dependency type is provided, dependency structure will be inferred. A dependency type is either the string “linear” or a callable that maps a flattened upstream perturbation to a flattened downstream perturbation. The string “linear” is equivalent to nn.Linear(upstream.numel(), downstream.numel(), bias=False). Dependencies must not contain cycles or self-loops.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- scale_tril_constraint = SoftplusLowerCholesky()¶
- get_deltas()¶
AutoGaussian¶
- class AutoGaussian(*args, **kwargs)[source]¶
Bases: pyro.infer.autoguide.guides.AutoGuide
Gaussian guide with optimal conditional independence structure.
This is equivalent to a full-rank AutoMultivariateNormal guide, but with a sparse precision matrix determined by dependencies and plates in the model [1]. Depending on model structure, this can have asymptotically better statistical efficiency than AutoMultivariateNormal.
This guide implements multiple backends for computation. All backends use the same statistically optimal parametrization. The default “dense” backend has computational complexity similar to AutoMultivariateNormal. The experimental “funsor” backend can be asymptotically cheaper in terms of time and space (using Gaussian tensor variable elimination [2,3]), but incurs large constant overhead. The “funsor” backend requires funsor, which can be installed via pip install pyro-ppl[funsor].
The guide currently does not depend on the model’s *args, **kwargs.
Example:
guide = AutoGaussian(model)
svi = SVI(model, guide, ...)
Example using experimental funsor backend:
!pip install pyro-ppl[funsor]
guide = AutoGaussian(model, backend="funsor")
svi = SVI(model, guide, ...)
References
- [1] S. Webb, A. Goliński, R. Zinkov, N. Siddharth, T. Rainforth, Y. W. Teh, F. Wood (2018)
“Faithful inversion of generative models for effective amortized inference” https://dl.acm.org/doi/10.5555/3327144.3327229
- [2] F. Obermeyer, E. Bingham, M. Jankowiak, J. Chiu, N. Pradhan, A. M. Rush, N. Goodman (2019)
“Tensor Variable Elimination for Plated Factor Graphs” http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf
- [3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen
(2019) “Functional Tensors for Probabilistic Programming” https://arxiv.org/abs/1910.10775
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
backend (str) – Back end for performing Gaussian tensor variable elimination. Defaults to “dense”; other options include “funsor”.
- scale_constraint = SoftplusPositive(lower_bound=0.0)¶
- forward(*args, **kwargs) → Dict[str, torch.Tensor][source]¶
- median(*args, **kwargs) → Dict[str, torch.Tensor][source]¶
Returns the posterior median value of each latent variable.
- Returns
A dict mapping sample site name to median tensor.
- Return type
Dict[str, torch.Tensor]
AutoMessenger¶
- class AutoMessenger(model: Callable, *, amortized_plates: Tuple[str, ...] = ())[source]¶
Bases: pyro.poutine.guide.GuideMessenger, pyro.nn.module.PyroModule
Base class for GuideMessenger autoguides.
- Parameters
model (callable) – A Pyro model.
amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.
- call(*args, **kwargs)[source]¶
Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().
Warning
This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See https://github.com/pytorch/pytorch/issues/27743.
AutoNormalMessenger¶
- class AutoNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[source]¶
Bases: pyro.infer.autoguide.effect.AutoMessenger
AutoMessenger with mean-field normal posterior.
The mean-field posterior at any site is a transformed normal distribution. This posterior is equivalent to AutoNormal or AutoDiagonalNormal, but allows customization via subclassing.
, but allows customization via subclassing.Derived classes may override the
get_posterior()
behavior at particular sites and use the mean-field normal behavior simply as a default, e.g.:def model(data): a = pyro.sample("a", dist.Normal(0, 1)) b = pyro.sample("b", dist.Normal(0, 1)) c = pyro.sample("c", dist.Normal(a + b, 1)) pyro.sample("obs", dist.Normal(c, 1), obs=data) class MyGuideMessenger(AutoNormalMessenger): def get_posterior(self, name, prior): if name == "c": # Use a custom distribution at site c. bias = pyro.param("c_bias", lambda: torch.zeros(())) weight = pyro.param("c_weight", lambda: torch.ones(()), constraint=constraints.positive) scale = pyro.param("c_scale", lambda: torch.ones(()), constraint=constraints.positive) a = self.upstream_value("a") b = self.upstream_value("b") loc = bias + weight * (a + b) return dist.Normal(loc, scale) # Fall back to mean field. return super().get_posterior(name, prior)
Note that above we manually computed loc = bias + weight * (a + b). Alternatively we could reuse the model-side computation by setting loc = bias + weight * prior.loc:

class MyGuideMessenger_v2(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "c":
            # Use a custom distribution at site c.
            bias = pyro.param("c_bias", lambda: torch.zeros(()))
            scale = pyro.param("c_scale", lambda: torch.ones(()),
                               constraint=constraints.positive)
            weight = pyro.param("c_weight", lambda: torch.ones(()),
                                constraint=constraints.positive)
            loc = bias + weight * prior.loc
            return dist.Normal(loc, scale)
        # Fall back to mean field.
        return super().get_posterior(name, prior)
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]¶
AutoHierarchicalNormalMessenger¶
- class AutoHierarchicalNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = (), init_weight: float = 1.0, hierarchical_sites: Optional[list] = None)[source]¶
Bases: pyro.infer.autoguide.effect.AutoNormalMessenger
AutoMessenger with mean-field normal posterior conditional on all dependencies.
The mean-field posterior at any site is a transformed normal distribution, the mean of which depends on the value of that site given its dependencies in the model:
loc_total = loc + transform.inv(prior.mean) * weight
where the value of prior.mean is conditional on upstream sites in the model, loc is the independent component of the mean in the untransformed space, and weight is an element-wise factor that scales the prior mean. This approach does not work for distributions that do not have a mean.
Derived classes may override particular sites and use this simply as a default; see the AutoNormalMessenger documentation for an example.
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
init_weight (float) – Initial value for the weight of the contribution of hierarchical sites to posterior mean for each latent variable.
hierarchical_sites (list) – List of latent variables (model sites) that have hierarchical dependencies. If None, all sites are assumed to have hierarchical dependencies; in that case, for sites that have no upstream sites, the loc and weight of the guide represent/learn the deviation from the prior.
- weight_type = 'element-wise'¶
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]¶
AutoRegressiveMessenger¶
- class AutoRegressiveMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[source]¶
Bases: pyro.infer.autoguide.effect.AutoMessenger
AutoMessenger with recursively affine-transformed priors using prior dependency structure.
The posterior at any site is a learned affine transform of the prior, conditioned on upstream posterior samples. The affine transform operates in unconstrained space. This supports only continuous latent variables.
Derived classes may override the get_posterior() behavior at particular sites and use the regressive behavior simply as a default, e.g.:

class MyGuideMessenger(AutoRegressiveMessenger):
    def get_posterior(self, name, prior):
        if name == "x":
            # Use a custom distribution at site x.
            loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
            scale = pyro.param("x_scale", lambda: torch.ones(prior.shape()),
                               constraint=constraints.positive)
            return dist.Normal(loc, scale).to_event(prior.event_dim)
        # Fall back to autoregressive.
        return super().get_posterior(name, prior)
Warning
This guide currently does not support jit-based elbos.
- Parameters
model (callable) – A Pyro model.
init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.
init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.
amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.
- get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]¶
Initialization¶
The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.
The standard interface for initialization is a function that inputs a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
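A minimal sketch of a function conforming to this interface (the site name "weight" is illustrative; site["fn"] is assumed to be the site's prior distribution, as in a standard Pyro trace):

def my_init_fn(site):
    # `site` is a trace-site dict; delegate to a library strategy by default.
    if site["name"] == "weight":
        return torch.zeros(site["fn"].shape())
    return init_to_median(site)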
- init_to_feasible(site=None)[source]¶
Initialize to an arbitrary feasible point, ignoring distribution parameters.
- init_to_median(site=None, num_samples=15, *, fallback: Optional[Callable] = <function init_to_feasible>)[source]¶
Initialize to the prior median; fall back to fallback (defaults to init_to_feasible()) if the median is undefined.
- Parameters
fallback (callable) – Fallback init strategy, for sites where the median is undefined.
- Raises
ValueError – If fallback=None and the median is undefined for a site.
- init_to_mean(site=None, *, fallback: Optional[Callable] = <function init_to_median>)[source]¶
Initialize to the prior mean; fall back to fallback (defaults to init_to_median()) if the mean is undefined.
- Parameters
fallback (callable) – Fallback init strategy, for sites where the mean is undefined.
- Raises
ValueError – If fallback=None and the mean is undefined for a site.
- init_to_uniform(site: Optional[dict] = None, radius: float = 2.0)[source]¶
Initialize to a random point in the area (-radius, radius) of the unconstrained domain.
- Parameters
radius (float) – specifies the range from which to draw an initial point in the unconstrained domain.
- init_to_value(site: Optional[dict] = None, values: dict = {}, *, fallback: Optional[Callable] = <function init_to_uniform>)[source]¶
Initialize to the values specified in values; fall back to the fallback strategy (defaults to init_to_uniform()) for sites not appearing in values.
- Parameters
values (dict) – dictionary of initial values keyed by site name.
fallback (callable) – Fallback init strategy, for sites not specified in values.
- Raises
ValueError – If fallback=None and no value for a site is given in values.
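When called without a site, these strategies return a partially applied version of themselves, so they can be passed directly as an init_loc_fn (a sketch; the site name "x" is illustrative):

guide = AutoNormal(model, init_loc_fn=init_to_value(values={"x": torch.tensor(1.0)}))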
- init_to_generated(site=None, generate=<function <lambda>>)[source]¶
Initialize to another initialization strategy returned by the callback generate, which is called once per model execution.
This is like init_to_value() but can produce different (e.g. random) values once per model execution. For example, to generate values and return init_to_value you could define:

def generate():
    values = {"x": torch.randn(100), "y": torch.rand(5)}
    return init_to_value(values=values)

my_init_fn = init_to_generated(generate=generate)
- Parameters
generate (callable) – A callable returning another initialization function, e.g. returning an
init_to_value(values={...})
populated with a dictionary of random samples.
- class InitMessenger(init_fn)[source]¶
Bases: pyro.poutine.messenger.Messenger
Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.
- Parameters
init_fn (callable) – An initialization function.
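A minimal sketch (this class is mainly of internal interest; data is an illustrative model argument):

# Run the model with each pyro.sample statement replaced by a value drawn
# from the init strategy, then inspect the resulting trace.
initialized_model = InitMessenger(init_to_mean)(model)
trace = poutine.trace(initialized_model).get_trace(data)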