Automatic Guide Generation

AutoGuide

class AutoGuide(model, *, create_plates=None)[source]

Bases: pyro.nn.module.PyroModule

Base class for automatic guides.

Derived classes must implement the forward() method, with the same *args, **kwargs as the base model.

Auto guides can be used individually or combined in an AutoGuideList object.

Parameters
  • model (callable) – A pyro model.

  • create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling, as in the sketch below.
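
For example, a guide can reuse the model's subsampled plate so that guide-side sampling is subsampled consistently with the model. A minimal sketch (model, plate, and site names are illustrative):

import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data), subsample_size=10) as idx:
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data[idx])

def create_plates(data):
    # Return the plate(s) the model declares; the guide reuses them
    # instead of creating its own, keeping subsampling consistent.
    return pyro.plate("data", len(data), subsample_size=10)

guide = AutoNormal(model, create_plates=create_plates)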

property model
call(*args, **kwargs)[source]

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See https://github.com/pytorch/pytorch/issues/27743.

sample_latent(**kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

training: bool

AutoGuideList

class AutoGuideList(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide, torch.nn.modules.container.ModuleList

Container class to combine multiple automatic or custom guides.

Example usage:

guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
Parameters

model (callable) – a Pyro model

append(part)[source]

Add an automatic or custom guide for part of the model. The guide should have been created by blocking the model to restrict to a subset of sample sites. No two parts should operate on any one sample site.

Parameters

part (AutoGuide or callable) – a partial guide to add

add(part)[source]

Deprecated alias for append().

forward(*args, **kwargs)[source]

A composite guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

quantiles(quantiles, *args, **kwargs)[source]

Returns the posterior quantile values of each latent variable.

Parameters

quantiles (list) – A list of requested quantiles between 0 and 1.

Returns

A dict mapping sample site name to quantiles tensor.

Return type

dict

training: bool

AutoCallable

class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.append(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.append(my_local_guide)  # automatically wrapped in an AutoCallable

To specify a median callable, you can instead write:

def my_local_median(*args, **kwargs):
    ...

guide.append(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

Parameters
  • model (callable) – a Pyro model

  • guide (callable) – a Pyro guide (typically over only part of the model)

  • median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.

forward(*args, **kwargs)[source]
training: bool

AutoNormal

class AutoNormal(model, *, init_loc_fn=<function init_to_feasible>, init_scale=0.1, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

It should be equivalent to AutoDiagonalNormal, but with more convenient site names and better support for TraceMeanField_ELBO.

In AutoDiagonalNormal, if your model has N named parameters with dimensions k_i and sum k_i = D, you get a single vector of length D for your mean and a single vector of length D for your sigmas. This guide instead gives you N distinct Normals that you can access by name.
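
A sketch of the practical difference (run the guide once so parameters are initialized; the parameter names below are illustrative of how PyroModule attributes appear in the param store):

guide = AutoNormal(model)
guide(data)  # one forward pass initializes the per-site parameters
print(list(pyro.get_param_store().keys()))
# e.g. ["AutoNormal.locs.x", "AutoNormal.scales.x", ...] -- one loc/scale
# pair per site, versus a single length-D loc and scale in AutoDiagonalNormal.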

Usage:

guide = AutoNormal(model)
svi = SVI(model, guide, ...)
Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

scale_constraint = SoftplusPositive(lower_bound=0.0)
forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters

quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.

Returns

A dict mapping sample site name to a tensor of quantile values.

Return type

dict

training: bool

AutoDelta

class AutoDelta(model, init_loc_fn=<function init_to_median>, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note

This class does MAP inference in constrained space.

Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)

Latent variables are initialized using init_loc_fn(). To change the default behavior, create a custom init_loc_fn() as described in Initialization, for example:

def my_init_fn(site):
    if site["name"] == "level":
        return torch.tensor([-1., 0., 1.])
    if site["name"] == "concentration":
        return torch.ones(k)
    return init_to_sample(site)
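
The custom function is then passed in as init_loc_fn (continuing the snippet above; init_to_sample is importable from pyro.infer.autoguide):

guide = AutoDelta(model, init_loc_fn=my_init_fn)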
Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

training: bool

AutoContinuous

class AutoContinuous(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

This uses torch.distributions.transforms to transform each constrained latent variable to an unconstrained space, then concatenates all variables into a single unconstrained latent variable. Each derived class implements a get_posterior() method returning a distribution over this single unconstrained latent variable.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.

Reference:

[1] Automatic Differentiation Variational Inference, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei

Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

get_base_dist()[source]

Returns the base distribution of the posterior when reparameterized as a TransformedDistribution. This should not depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns

TorchDistribution instance representing the base distribution.

get_transform(*args, **kwargs)[source]

Returns the transform applied to the base distribution when the posterior is reparameterized as a TransformedDistribution. This may depend on the model’s *args, **kwargs.

posterior = TransformedDistribution(self.get_base_dist(), self.get_transform(*args, **kwargs))
Returns

a Transform instance.

get_posterior(*args, **kwargs)[source]

Returns the posterior distribution.
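
Derived classes implement this method. A minimal sketch of a diagonal-Normal implementation, modeled on AutoDiagonalNormal (it touches AutoContinuous internals such as latent_dim and _init_loc(), which are implementation details and may change between Pyro versions):

import torch
import pyro.distributions as dist
from pyro.infer.autoguide import AutoContinuous
from pyro.nn import PyroParam

class AutoToyNormal(AutoContinuous):
    def _setup_prototype(self, *args, **kwargs):
        super()._setup_prototype(*args, **kwargs)
        # One loc/scale pair over the single flattened unconstrained latent.
        self.loc = PyroParam(self._init_loc())
        self.scale = PyroParam(torch.full((self.latent_dim,), 0.1),
                               constraint=dist.constraints.positive)

    def get_posterior(self, *args, **kwargs):
        return dist.Normal(self.loc, self.scale).to_event(1)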

sample_latent(*args, **kwargs)[source]

Samples an encoded latent given the same *args, **kwargs as the base model.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict

median(*args, **kwargs)[source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

quantiles(quantiles, *args, **kwargs)[source]

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters

quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.

Returns

A dict mapping sample site name to a tensor of quantile values.

Return type

dict

training: bool

AutoMultivariateNormal

class AutoMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the Cholesky factor is initialized to the identity times a small factor.

Parameters
  • model (callable) – A generative model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

scale_constraint = SoftplusPositive(lower_bound=0.0)
scale_tril_constraint = UnitLowerCholesky()
get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]

Returns a MultivariateNormal posterior distribution.

training: bool

AutoDiagonalNormal

class AutoDiagonalNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized by init_loc_fn() and the scale is initialized to a small constant (init_scale).

Parameters
  • model (callable) – A generative model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

scale_constraint = SoftplusPositive(lower_bound=0.0)
get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]

Returns a diagonal Normal posterior distribution.

training: bool

AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, init_loc_fn=<function init_to_median>, init_scale=0.1, rank=None)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)

By default the cov_diag is initialized to a small constant and the cov_factor is initialized randomly such that on average cov_factor.matmul(cov_factor.t()) has the same scale as cov_diag.

Parameters
  • model (callable) – A generative model.

  • rank (int or None) – The rank of the low-rank part of the covariance matrix. Defaults to approximately sqrt(latent dim).

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Approximate initial scale for the standard deviation of each (unconstrained transformed) latent variable.

scale_constraint = SoftplusPositive(lower_bound=0.0)
get_posterior(*args, **kwargs)[source]

Returns a LowRankMultivariateNormal posterior distribution.

training: bool

AutoNormalizingFlow

class AutoNormalizingFlow(model, init_transform_fn)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via a sequence of bijective transforms (e.g. various TransformModule subclasses) to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

from functools import partial
from pyro.distributions.transforms import iterated, block_autoregressive

transform_init = partial(iterated, block_autoregressive,
                         repeats=2)
guide = AutoNormalizingFlow(model, transform_init)
svi = SVI(model, guide, ...)
Parameters
  • model (callable) – a generative model

  • init_transform_fn – a callable which, when provided with the latent dimension, returns an instance of Transform, or of TransformModule if the transform has trainable params.

get_base_dist()[source]
get_transform(*args, **kwargs)[source]
get_posterior(*args, **kwargs)[source]
training: bool

AutoIAFNormal

class AutoIAFNormal(model, hidden_dim=None, init_loc_fn=None, num_transforms=1, **init_transform_kwargs)[source]

Bases: pyro.infer.autoguide.guides.AutoNormalizingFlow

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an AffineAutoregressive to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
Parameters
  • model (callable) – a generative model

  • hidden_dim (list[int]) – number of hidden dimensions in the IAF

  • init_loc_fn (callable) –

    A per-site initialization function. See Initialization section for available functions.

    Warning

    This argument is only to preserve backwards compatibility and has no effect in practice.

  • num_transforms (int) – number of AffineAutoregressive transforms to use in sequence.

  • init_transform_kwargs – other keyword arguments taken by affine_autoregressive().

training: bool

AutoLaplaceApproximation

class AutoLaplaceApproximation(model, init_loc_fn=<function init_to_median>)[source]

Bases: pyro.infer.autoguide.guides.AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(p(z \mid x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.

Usage:

delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to an empirical prior median.

Parameters
  • model (callable) – a generative model

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

get_posterior(*args, **kwargs)[source]

Returns a Delta posterior distribution for MAP inference.

laplace_approximation(*args, **kwargs)[source]

Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.

training: bool

AutoDiscreteParallel

class AutoDiscreteParallel(model, *, create_plates=None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.

forward(*args, **kwargs)[source]

An automatic guide with the same *args, **kwargs as the base model.

Note

This method is used internally by Module. Users should instead use __call__().

Returns

A dict mapping sample site name to sampled value.

Return type

dict

training: bool

AutoStructured

class AutoStructured(model, *, conditionals: Union[str, Dict[str, Union[str, Callable]]] = 'mvn', dependencies: Union[str, Dict[str, Dict[str, Union[str, Callable]]]] = 'linear', init_loc_fn: Callable = <function init_to_feasible>, init_scale: float = 0.1, create_plates: Optional[Callable] = None)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Structured guide whose conditional distributions are Delta, Normal, or MultivariateNormal, or are given by a callable, and whose latent variables can depend on each other either linearly (in unconstrained space) or via shearing by a callable.

Usage:

def model(data):
    x = pyro.sample("x", dist.LogNormal(0, 1))
    with pyro.plate("plate", len(data)):
        y = pyro.sample("y", dist.Normal(0, 1))
        pyro.sample("z", dist.Normal(y, x), obs=data)

# Either fully automatic...
guide = AutoStructured(model)

# ...or with specified conditional and dependency types...
guide = AutoStructured(
    model, conditionals="normal", dependencies="linear"
)

# ...or with custom dependency structure and distribution types.
guide = AutoStructured(
    model=model,
    conditionals={"x": "normal", "y": "delta"},
    dependencies={"x": {"y": "linear"}},
)

Once trained, this guide can be used with StructuredReparam to precondition a model for use in HMC and NUTS inference.
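
A sketch of that workflow, assuming the guide has been trained with SVI (the import path follows current Pyro releases but may vary by version):

from pyro.infer.reparam.structured import StructuredReparam

# Rewrite the model so its latent sites are preconditioned by the guide.
reparam_model = StructuredReparam(guide).reparam(model)
# reparam_model can now be passed to NUTS or HMC.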

Note

If you declare a dependency of a high-dimensional downstream variable on a low-dimensional upstream variable, you may want to use a lower learning rate for that weight, e.g.:

def optim_config(param_name):
    config = {"lr": 0.01}
    if "deps.my_downstream.my_upstream" in param_name:
        config["lr"] *= 0.1
    return config

adam = pyro.optim.Adam(optim_config)
Parameters
  • model (callable) – A Pyro model.

  • conditionals – Either a single distribution type or a dict mapping each latent variable name to a distribution type. A distribution type is either a string in {“delta”, “normal”, “mvn”} or a callable that returns a sample from a zero mean (or approximately centered) noise distribution (such callables typically call pyro.param() and pyro.sample() internally).

  • dependencies – Dependency type, or a dict mapping each site name to a dict mapping its upstream dependencies to dependency types. If only a dependency type is provided, dependency structure will be inferred. A dependency type is either the string “linear” or a callable that maps a flattened upstream perturbation to a flattened downstream perturbation. The string “linear” is equivalent to nn.Linear(upstream.numel(), downstream.numel(), bias=False). Dependencies must not contain cycles or self-loops.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • create_plates (callable) – An optional function inputing the same *args,**kwargs as model() and returning a pyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.

scale_constraint = SoftplusPositive(lower_bound=0.0)
scale_tril_constraint = SoftplusLowerCholesky()
get_deltas()
forward(*args, **kwargs)[source]
median(*args, **kwargs)[source]
training: bool

AutoGaussian

class AutoGaussian(*args, **kwargs)[source]

Bases: pyro.infer.autoguide.guides.AutoGuide

Gaussian guide with optimal conditional independence structure.

This is equivalent to a full rank AutoMultivariateNormal guide, but with a sparse precision matrix determined by dependencies and plates in the model [1]. Depending on model structure, this can have asymptotically better statistical efficiency than AutoMultivariateNormal .

This guide implements multiple backends for computation. All backends use the same statistically optimal parametrization. The default “dense” backend has computational complexity similar to AutoMultivariateNormal . The experimental “funsor” backend can be asymptotically cheaper in terms of time and space (using Gaussian tensor variable elimination [2,3]), but incurs large constant overhead. The “funsor” backend requires funsor which can be installed via pip install pyro-ppl[funsor].

The guide currently does not depend on the model’s *args, **kwargs.

Example:

guide = AutoGaussian(model)
svi = SVI(model, guide, ...)

Example using experimental funsor backend:

!pip install pyro-ppl[funsor]
guide = AutoGaussian(model, backend="funsor")
svi = SVI(model, guide, ...)

References

[1] S. Webb, A. Goliński, R. Zinkov, N. Siddharth, T. Rainforth, Y. W. Teh, F. Wood (2018) “Faithful inversion of generative models for effective amortized inference” https://dl.acm.org/doi/10.5555/3327144.3327229

[2] F. Obermeyer, E. Bingham, M. Jankowiak, J. Chiu, N. Pradhan, A. M. Rush, N. Goodman (2019) “Tensor Variable Elimination for Plated Factor Graphs” http://proceedings.mlr.press/v97/obermeyer19a/obermeyer19a.pdf

[3] F. Obermeyer, E. Bingham, M. Jankowiak, D. Phan, J. P. Chen (2019) “Functional Tensors for Probabilistic Programming” https://arxiv.org/abs/1910.10775

Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • backend (str) – Back end for performing Gaussian tensor variable elimination. Defaults to “dense”; other options include “funsor”.

scale_constraint = SoftplusPositive(lower_bound=0.0)
forward(*args, **kwargs) → Dict[str, torch.Tensor][source]
median(*args, **kwargs) → Dict[str, torch.Tensor][source]

Returns the posterior median value of each latent variable.

Returns

A dict mapping sample site name to median tensor.

Return type

dict

training: bool

AutoMessenger

class AutoMessenger(model: Callable, *, amortized_plates: Tuple[str, ...] = ())[source]

Bases: pyro.poutine.guide.GuideMessenger, pyro.nn.module.PyroModule

Base class for GuideMessenger autoguides.

Parameters
  • model (callable) – A Pyro model.

  • amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.
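
For example (the plate name is illustrative), a guide for a subsampled model can share one parameter tensor across the subsampled plate:

guide = AutoNormalMessenger(model, amortized_plates=("data",))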

call(*args, **kwargs)[source]

Method that calls forward() and returns parameter values of the guide as a tuple instead of a dict, which is a requirement for JIT tracing. Unlike forward(), this method can be traced by torch.jit.trace_module().

Warning

This method may be removed once the PyTorch JIT tracer starts accepting dicts as valid return types. See https://github.com/pytorch/pytorch/issues/27743.

training: bool

AutoNormalMessenger

class AutoNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[source]

Bases: pyro.infer.autoguide.effect.AutoMessenger

AutoMessenger with mean-field normal posterior.

The mean-field posterior at any site is a transformed normal distribution. This posterior is equivalent to AutoNormal or AutoDiagonalNormal, but allows customization via subclassing.

Derived classes may override the get_posterior() behavior at particular sites and use the mean-field normal behavior simply as a default, e.g.:

def model(data):
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.Normal(0, 1))
    c = pyro.sample("c", dist.Normal(a + b, 1))
    pyro.sample("obs", dist.Normal(c, 1), obs=data)

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "c":
            # Use a custom distribution at site c.
            bias = pyro.param("c_bias", lambda: torch.zeros(()))
            weight = pyro.param("c_weight", lambda: torch.ones(()),
                                constraint=constraints.positive)
            scale = pyro.param("c_scale", lambda: torch.ones(()),
                               constraint=constraints.positive)
            a = self.upstream_value("a")
            b = self.upstream_value("b")
            loc = bias + weight * (a + b)
            return dist.Normal(loc, scale)
        # Fall back to mean field.
        return super().get_posterior(name, prior)

Note that above we manually computed loc = bias + weight * (a + b). Alternatively we could reuse the model-side computation by setting loc = bias + weight * prior.loc:

class MyGuideMessenger_v2(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "c":
            # Use a custom distribution at site c.
            bias = pyro.param("c_bias", lambda: torch.zeros(()))
            scale = pyro.param("c_scale", lambda: torch.ones(()),
                               constraint=constraints.positive)
            weight = pyro.param("c_weight", lambda: torch.ones(()),
                                constraint=constraints.positive)
            loc = bias + weight * prior.loc
            return dist.Normal(loc, scale)
        # Fall back to mean field.
        return super().get_posterior(name, prior)
Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.

get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]
median(*args, **kwargs)[source]
training: bool

AutoHierarchicalNormalMessenger

class AutoHierarchicalNormalMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = (), init_weight: float = 1.0, hierarchical_sites: Optional[list] = None)[source]

Bases: pyro.infer.autoguide.effect.AutoNormalMessenger

AutoMessenger with mean-field normal posterior conditional on all dependencies.

The mean-field posterior at any site is a transformed normal distribution, the mean of which depends on the value of that site given its dependencies in the model:

loc_total = loc + transform.inv(prior.mean) * weight

where prior.mean is conditional on upstream sites in the model, loc is the independent component of the mean in the untransformed space, and weight is an element-wise factor that scales the prior mean. This approach does not work for distributions that have no mean.

Derived classes may override particular sites and use this simply as a default, see AutoNormalMessenger documentation for example.
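
Usage follows the other autoguides:

guide = AutoHierarchicalNormalMessenger(model)
svi = SVI(model, guide, ...)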

Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • init_weight (float) – Initial value for the weight of the contribution of hierarchical sites to posterior mean for each latent variable.

  • hierarchical_sites (list) – List of latent variables (model sites) that have hierarchical dependencies. If None, all sites are assumed to have hierarchical dependencies; in that case, for sites that have no upstream sites, the guide’s loc and weight represent and learn the deviation from the prior.

weight_type = 'element-wise'
get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]
median(*args, **kwargs)[source]
training: bool

AutoRegressiveMessenger

class AutoRegressiveMessenger(model: Callable, *, init_loc_fn: Callable = functools.partial(<function init_to_mean>, fallback=<function init_to_feasible>), init_scale: float = 0.1, amortized_plates: Tuple[str, ...] = ())[source]

Bases: pyro.infer.autoguide.effect.AutoMessenger

AutoMessenger with recursively affine-transformed priors using prior dependency structure.

The posterior at any site is a learned affine transform of the prior, conditioned on upstream posterior samples. The affine transform operates in unconstrained space. This supports only continuous latent variables.

Derived classes may override the get_posterior() behavior at particular sites and use the regressive behavior simply as a default, e.g.:

class MyGuideMessenger(AutoRegressiveMessenger):
    def get_posterior(self, name, prior):
        if name == "x":
            # Use a custom distribution at site x.
            loc = pyro.param("x_loc", lambda: torch.zeros(prior.shape()))
            scale = pyro.param("x_scale", lambda: torch.ones(prior.shape()),
                               constraint=constraints.positive)
            return dist.Normal(loc, scale).to_event(prior.event_dim())
        # Fall back to autoregressive.
        return super().get_posterior(name, prior)

Warning

This guide currently does not support JIT-based ELBOs.

Parameters
  • model (callable) – A Pyro model.

  • init_loc_fn (callable) – A per-site initialization function. See Initialization section for available functions.

  • init_scale (float) – Initial scale for the standard deviation of each (unconstrained transformed) latent variable.

  • amortized_plates (tuple) – A tuple of names of plates over which guide parameters should be shared. This is useful for subsampling, where a guide parameter can be shared across all plates.

get_posterior(name: str, prior: pyro.distributions.distribution.Distribution) → Union[pyro.distributions.distribution.Distribution, torch.Tensor][source]
training: bool

Initialization

The pyro.infer.autoguide.initialization module contains initialization functions for automatic guides.

The standard interface for initialization is a function that takes a Pyro trace site dict and returns an appropriately sized value to serve as an initial constrained value for a guide estimate.
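
A minimal sketch of a custom initializer conforming to this interface (the site name is hypothetical):

import torch
from pyro.infer.autoguide import init_to_median

def my_init(site):
    # site is a Pyro trace site dict, e.g. with site["name"] and site["fn"].
    if site["name"] == "scale":
        return torch.tensor(1.0)  # initial value in constrained space
    return init_to_median(site)  # delegate all other sites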

init_to_feasible(site=None)[source]

Initialize to an arbitrary feasible point, ignoring distribution parameters.

init_to_sample(site=None)[source]

Initialize to a random sample from the prior.

init_to_median(site=None, num_samples=15, *, fallback: Optional[Callable] = <function init_to_feasible>)[source]

Initialize to the prior median; fall back to fallback (defaults to init_to_feasible()) if the median is undefined.

Parameters

fallback (callable) – Fallback init strategy, used when the prior median is undefined.

Raises

ValueError – If fallback=None and the prior median is undefined.

init_to_mean(site=None, *, fallback: Optional[Callable] = <function init_to_median>)[source]

Initialize to the prior mean; fall back to fallback (defaults to init_to_median()) if the mean is undefined.

Parameters

fallback (callable) – Fallback init strategy, used when the prior mean is undefined.

Raises

ValueError – If fallback=None and the prior mean is undefined.

init_to_uniform(site: Optional[dict] = None, radius: float = 2.0)[source]

Initialize to a random point in the region (-radius, radius) of the unconstrained domain.

Parameters

radius (float) – specifies the range to draw an initial point in the unconstrained domain.

init_to_value(site: Optional[dict] = None, values: dict = {}, *, fallback: Optional[Callable] = <function init_to_uniform>)[source]

Initialize to the values specified in values. Fall back to the fallback strategy (defaults to init_to_uniform()) for sites not appearing in values.

Parameters
  • values (dict) – dictionary of initial values keyed by site name.

  • fallback (callable) – Fallback init strategy, for sites not specified in values.

Raises

ValueError – If fallback=None and no value for a site is given in values.
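
A usage sketch (the site name "x" is hypothetical); calling init_to_value() without a site returns a partially applied initializer that can be passed directly as init_loc_fn:

guide = AutoNormal(model, init_loc_fn=init_to_value(values={"x": torch.zeros(2)}))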

init_to_generated(site=None, generate=<function <lambda>>)[source]

Initialize to another initialization strategy returned by the callback generate, which is called once per model execution.

This is like init_to_value() but can produce different (e.g. random) values once per model execution. For example, to generate values and return an init_to_value strategy, you could define:

def generate():
    values = {"x": torch.randn(100), "y": torch.rand(5)}
    return init_to_value(values=values)

my_init_fn = init_to_generated(generate=generate)
Parameters

generate (callable) – A callable returning another initialization function, e.g. returning an init_to_value(values={...}) populated with a dictionary of random samples.

class InitMessenger(init_fn)[source]

Bases: pyro.poutine.messenger.Messenger

Initializes a site by replacing .sample() calls with values drawn from an initialization strategy. This is mainly for internal use by autoguide classes.

Parameters

init_fn (callable) – An initialization function.