Reparameterizers¶
The pyro.infer.reparam module contains reparameterization strategies for the pyro.poutine.handlers.reparam() effect. These are useful for altering the geometry of a poorly-conditioned parameter space to make the posterior better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.
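As a quick orientation, here is a minimal sketch of the typical workflow: wrap a model with poutine.reparam and hand the wrapped model to an inference algorithm. The toy model and site names are illustrative, not part of the API.

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import MCMC, NUTS
from pyro.infer.reparam import LocScaleReparam

def model():
    # A funnel-like geometry: "x" is tightly coupled to "loc".
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    pyro.sample("x", dist.Normal(loc, 0.1))

# Decenter "x" so NUTS explores a better-conditioned space.
reparam_model = poutine.reparam(model, config={"x": LocScaleReparam(centered=0.0)})
mcmc = MCMC(NUTS(reparam_model), num_samples=100)
mcmc.run()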
- class ReparamMessage[source]¶
- fn: Callable¶
- value: Optional[torch.Tensor]¶
- class Reparam[source]¶
Abstract base class for reparameterizers.
Derived classes should implement apply().
- apply(msg: pyro.infer.reparam.reparam.ReparamMessage) → pyro.infer.reparam.reparam.ReparamResult[source]¶
Abstract method to apply a reparameterizer.
- Parameters
msg (dict) – A simplified Pyro message with fields:
- name: str – the sample site's name
- fn: Callable – a distribution
- value: Optional[torch.Tensor] – an observed or initial value
- is_observed: bool – whether value is an observation
- Returns
A simplified Pyro message with fields fn, value, and is_observed.
- Return type
ReparamResult
Automatic Strategies¶
These reparametrization strategies are registered with register_reparam_strategy() and are accessed by name via poutine.reparam(config=name_of_strategy). See reparam() for usage.
- class Strategy[source]¶
Bases:
abc.ABC
Abstract base class for reparametrizer configuration strategies.
Derived classes must implement the configure() method.
- Variables
config (dict) – A dictionary configuration. This will be populated the first time the model is run. Thereafter it can be used as an argument to poutine.reparam(config=___).
- abstract configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam][source]¶
Inputs a sample site and returns either None or a Reparam instance.
This will be called only on the first model execution; subsequent executions will use the reparametrizer stored in self.config.
- Parameters
msg (dict) – A sample site to possibly reparametrize.
- Returns
An optional reparametrizer instance.
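To make the configure() contract concrete, here is a hedged sketch of a custom strategy. The class name DecenterAll is hypothetical, and a production strategy may need to unwrap distributions (e.g. Independent or masked wrappers) before inspecting them.

from typing import Optional

from pyro.infer.reparam import LocScaleReparam
from pyro.infer.reparam.reparam import Reparam
from pyro.infer.reparam.strategies import Strategy

class DecenterAll(Strategy):  # hypothetical example strategy
    def configure(self, msg: dict) -> Optional[Reparam]:
        if msg.get("is_observed"):
            return None  # reparametrize latent sites only
        fn = msg["fn"]
        params = getattr(fn, "arg_constraints", {})
        if "loc" in params and "scale" in params:
            return LocScaleReparam(centered=0.0)
        return None  # leave all other sites alone

An instance can then be used like the built-in strategies below, either as a decorator @DecenterAll() or via poutine.reparam(config=DecenterAll()).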
- class MinimalReparam[source]¶
Bases:
pyro.infer.reparam.strategies.Strategy
Minimal reparametrization strategy that reparametrizes only those sites that would otherwise lead to error, e.g. Stable and ProjectedNormal random variables.
Example:
@MinimalReparam()
def model(...):
    ...
which is equivalent to:
@poutine.reparam(config=MinimalReparam())
def model(...):
    ...
- configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam][source]¶
- class AutoReparam(*, centered: Optional[float] = None)[source]¶
Bases:
pyro.infer.reparam.strategies.Strategy
Applies a recommended set of reparametrizers. These currently include: MinimalReparam, TransformReparam, a fully-learnable LocScaleReparam, and GumbelSoftmaxReparam.
Example:
@AutoReparam()
def model(...):
    ...
which is equivalent to:
@poutine.reparam(config=AutoReparam())
def model(...):
    ...
Warning
This strategy may change behavior across Pyro releases. To inspect or save a given behavior, extract the .config dict after running the model at least once.
- Parameters
centered – Optional centering parameter for LocScaleReparam reparametrizers. If None (default), centering will be learned. If a float in [0.0, 1.0], centering is fixed at that value. To completely decenter (e.g. in MCMC), set to 0.0.
- configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam][source]¶
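To illustrate the warning above, the following sketch pins down a given release's behavior by running the model once and then reusing the populated .config dict; the toy model is illustrative.

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import AutoReparam

def model():
    pyro.sample("x", dist.Normal(0.0, 1.0))

strategy = AutoReparam()
reparam_model = poutine.reparam(model, config=strategy)
reparam_model()  # first execution populates strategy.config
# Reuse the captured config for reproducible behavior across Pyro releases.
frozen_model = poutine.reparam(model, config=strategy.config)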
Conjugate Updating¶
- class ConjugateReparam(guide)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
EXPERIMENTAL Reparameterize to a conjugate updated distribution.
This updates a prior distribution fn using the conjugate_update() method. The guide may be either a distribution object or a callable inputting model *args, **kwargs and returning a distribution object. The guide may be approximate or learned.
For example, consider the model and naive variational guide:
total = torch.tensor(10.)
count = torch.tensor(2.)

def model():
    prob = pyro.sample("prob", dist.Beta(0.5, 1.5))
    pyro.sample("count", dist.Binomial(total, prob), obs=count)

guide = AutoDiagonalNormal(model)  # learns the posterior over prob
Instead of using this learned guide, we can hand-compute the conjugate posterior distribution over “prob”, and then use a simpler guide during inference, in this case an empty guide:
reparam_model = poutine.reparam(model, {
    "prob": ConjugateReparam(dist.Beta(1 + count, 1 + total - count))
})

def reparam_guide():
    pass  # nothing remains to be modeled!
- Parameters
guide (Distribution or callable) – A likelihood distribution or a callable returning a guide distribution. Only a few distributions are supported, depending on the prior distribution's conjugate_update() implementation.
Loc-Scale Decentering¶
- class LocScaleReparam(centered=None, shape_params=None)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).
This reparameterization works only for latent variables, not likelihoods.
- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf
- Parameters
centered (float) – Optional centering parameter. If None (default), learn a per-site per-element centering parameter in [0,1]. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.
shape_params (tuple or list) – Optional list of additional parameter names to copy unchanged from the centered to the decentered distribution. If absent, all params in a distribution's .arg_constraints will be copied.
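For concreteness, a minimal sketch of decentering a hierarchical site; the toy "eight schools"-style model is an illustration, not part of the API.

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def model():
    mu = pyro.sample("mu", dist.Normal(0.0, 5.0))
    tau = pyro.sample("tau", dist.LogNormal(0.0, 1.0))
    with pyro.plate("groups", 8):
        # "theta" is loc-scale parameterized, so it qualifies for decentering.
        pyro.sample("theta", dist.Normal(mu, tau))

# centered=None (default) learns per-element centering in [0, 1];
# centered=0.0 would fully decenter, e.g. for MCMC.
reparam_model = poutine.reparam(model, config={"theta": LocScaleReparam()})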
Gumbel-Softmax¶
- class GumbelSoftmaxReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Reparametrizer for RelaxedOneHotCategorical latent variables.
This is useful for transforming multimodal posteriors to unimodal posteriors. Note this increases the latent dimension by 1 per event.
This reparameterization works only for latent variables, not likelihoods.
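A minimal usage sketch (the toy model, temperature, and site name are illustrative assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import GumbelSoftmaxReparam

def model():
    temperature = torch.tensor(0.5)
    probs = torch.tensor([0.2, 0.3, 0.5])
    # A relaxed (continuous) categorical latent with a multimodal posterior.
    pyro.sample("z", dist.RelaxedOneHotCategorical(temperature, probs=probs))

reparam_model = poutine.reparam(model, config={"z": GumbelSoftmaxReparam()})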
Transformed Distributions¶
- class TransformReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Reparameterizer for pyro.distributions.torch.TransformedDistribution latent variables.
This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has simple shape in the space of base_dist.
This reparameterization works only for latent variables, not likelihoods.
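A hedged sketch: a log-normal latent expressed as a transformed Normal, so that inference runs in the Normal base space. The toy model is an illustration.

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import TransformReparam

def model():
    fn = dist.TransformedDistribution(
        dist.Normal(0.0, 1.0), [dist.transforms.ExpTransform()])
    # The posterior over "x" is skewed; its base-space posterior is simpler.
    pyro.sample("x", fn)

reparam_model = poutine.reparam(model, config={"x": TransformReparam()})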
Discrete Cosine Transform¶
- class DiscreteCosineReparam(dim=-1, smooth=0.0, *, experimental_allow_batch=False)[source]¶
Bases:
pyro.infer.reparam.unit_jacobian.UnitJacobianReparam
Discrete Cosine reparameterizer, using a DiscreteCosineTransform.
This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC.
When reparameterizing variables that are approximately continuous along the time dimension, set smooth=1. For variables that are approximately continuously differentiable along the time axis, set smooth=2.
This reparameterization works only for latent variables, not likelihoods.
- Parameters
dim (int) – Dimension along which to transform. Must be negative. This is an absolute dim counting from the right.
smooth (float) – Smoothing parameter. When 0, this transforms white noise to white noise; when 1 this transforms Brownian noise to white noise; when -1 this transforms violet noise to white noise; etc. Any real number is allowed. See https://en.wikipedia.org/wiki/Colors_of_noise.
experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.
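A hedged sketch of a smooth latent time series moved to the frequency domain; the toy model and data are assumptions, and smooth=1 follows the guidance above for approximately continuous variables.

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import DiscreteCosineReparam

data = torch.randn(100)  # toy observations

def model():
    # A length-100 latent drift with long-range posterior correlation.
    drift = pyro.sample("drift", dist.Normal(torch.zeros(100), 0.1).to_event(1))
    pyro.sample("obs", dist.Normal(drift.cumsum(-1), 1.0).to_event(1), obs=data)

dct_model = poutine.reparam(model, config={"drift": DiscreteCosineReparam(smooth=1.0)})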
Haar Transform¶
- class HaarReparam(dim=-1, flip=False, *, experimental_allow_batch=False)[source]¶
Bases:
pyro.infer.reparam.unit_jacobian.UnitJacobianReparam
Haar wavelet reparameterizer, using a HaarTransform.
This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC.
This reparameterization works only for latent variables, not likelihoods.
- Parameters
dim (int) – Dimension along which to transform. Must be negative. This is an absolute dim counting from the right.
flip (bool) – Whether to flip the time axis before applying the Haar transform. Defaults to False.
experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.
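A hedged sketch analogous to the discrete cosine example above, using a random-walk latent (toy model):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import HaarReparam

def model():
    # A length-16 latent with local coupling along the time axis.
    pyro.sample("z", dist.GaussianRandomWalk(torch.tensor(1.0), num_steps=16))

haar_model = poutine.reparam(model, config={"z": HaarReparam(dim=-1)})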
Unit Jacobian Transforms¶
- class UnitJacobianReparam(transform, suffix='transformed', *, experimental_allow_batch=False)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Reparameterizer for Transform objects whose Jacobian determinant is one.
- Parameters
transform (Transform) – A transform whose Jacobian has determinant 1.
suffix (str) – A suffix to append to the transformed site.
experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.
StudentT Distributions¶
- class StudentTReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Auxiliary variable reparameterizer for StudentT random variables.
This is useful in combination with LinearHMMReparam because it allows StudentT processes to be treated as conditionally Gaussian processes, permitting cheap inference via GaussianHMM.
This reparameterizes a StudentT by introducing an auxiliary Gamma variable conditioned on which the result is Normal.
Stable Distributions¶
- class LatentStableReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Auxiliary variable reparameterizer for Stable latent variables.
This is useful in inference of latent Stable variables because its log_prob() is not implemented.
This uses the Chambers-Mallows-Stuck method [1], creating a pair of parameter-free auxiliary distributions (Uniform(-pi/2, pi/2) and Exponential(1)) with well-defined .log_prob() methods, thereby permitting use of reparameterized stable distributions in likelihood-based inference algorithms like SVI and MCMC.
This reparameterization works only for latent variables, not likelihoods. For likelihood-compatible reparameterization see SymmetricStableReparam or StableReparam.
.- [1] J.P. Nolan (2017).
Stable Distributions: Models for Heavy Tailed Data. https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf
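A hedged sketch of a latent Stable site made tractable for SVI or MCMC (the toy model, parameter values, and site names are assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LatentStableReparam

def model():
    # Stable.log_prob() is unimplemented, so the latent "z" needs reparameterization.
    z = pyro.sample("z", dist.Stable(stability=1.8, skew=-0.5))
    pyro.sample("x", dist.Normal(z, 1.0), obs=torch.tensor(0.5))

reparam_model = poutine.reparam(model, config={"z": LatentStableReparam()})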
- class SymmetricStableReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Auxiliary variable reparameterizer for symmetric Stable random variables (i.e. those for which skew=0).
This is useful in inference of symmetric Stable variables because its log_prob() is not implemented.
This reparameterizes a symmetric Stable random variable as a totally-skewed (skew=1) Stable scale mixture of Normal random variables. See Proposition 3 of [1] (but note we differ since Stable uses Nolan's continuous S0 parameterization).
uses Nolan’s continuous S0 parameterization).- [1] Alvaro Cartea and Sam Howison (2009)
“Option Pricing with Levy-Stable Processes” https://pdfs.semanticscholar.org/4d66/c91b136b2a38117dd16c2693679f5341c616.pdf
- class StableReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Auxiliary variable reparameterizer for arbitrary Stable random variables.
This is useful in inference of non-symmetric Stable variables because its log_prob() is not implemented.
This reparameterizes a Stable random variable as the sum of two other stable random variables, one symmetric and the other totally skewed (applying Property 2.3.a of [1]). The totally skewed variable is sampled as in LatentStableReparam, and the symmetric variable is decomposed as in SymmetricStableReparam.
- [1] V. M. Zolotarev (1986)
“One-dimensional stable distributions”
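Unlike LatentStableReparam, this reparameterizer is likelihood-compatible, so observed Stable sites can be reparameterized directly. A hedged sketch with toy data (model and values are assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import StableReparam

data = torch.randn(10)  # toy observations

def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        # Heavy-tailed, skewed observation noise around loc.
        pyro.sample("obs", dist.Stable(1.5, 0.1, 1.0, loc), obs=data)

reparam_model = poutine.reparam(model, config={"obs": StableReparam()})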
Projected Normal Distributions¶
- class ProjectedNormalReparam[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Reparametrizer for ProjectedNormal latent variables.
This reparameterization works only for latent variables, not likelihoods.
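A hedged sketch of a directional latent on the sphere (the toy model is an assumption):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import ProjectedNormalReparam

def model():
    # A zero concentration gives a uniform-ish prior over directions.
    pyro.sample("direction", dist.ProjectedNormal(torch.zeros(3)))

reparam_model = poutine.reparam(model, config={"direction": ProjectedNormalReparam()})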
Hidden Markov Models¶
- class LinearHMMReparam(init=None, trans=None, obs=None)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Auxiliary variable reparameterizer for LinearHMM random variables.
This defers to component reparameterizers to create auxiliary random variables conditioned on which the process becomes a GaussianHMM. If the observation_dist is a TransformedDistribution, this reorders those transforms so that the result is a TransformedDistribution of GaussianHMM.
This is useful for training the parameters of a LinearHMM distribution, whose log_prob() method is undefined. To perform inference in the presence of non-Gaussian factors such as Stable(), StudentT() or LogNormal(), configure with StudentTReparam, StableReparam, SymmetricStableReparam, etc. component reparameterizers for init, trans, and obs. For example:
hmm = LinearHMM(
    init_dist=Stable(1, 0, 1, 0).expand([2]).to_event(1),
    trans_matrix=torch.eye(2),
    trans_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
    obs_matrix=torch.eye(2),
    obs_dist=TransformedDistribution(
        Stable(1.5, -0.5, 1.0).expand([2]).to_event(1),
        ExpTransform()))
rep = LinearHMMReparam(init=SymmetricStableReparam(), obs=StableReparam())
with poutine.reparam(config={"hmm": rep}):
    pyro.sample("hmm", hmm, obs=data)
- Parameters
init (Reparam) – Optional reparameterizer for the initial distribution.
trans (Reparam) – Optional reparameterizer for the transition distribution.
obs (Reparam) – Optional reparameterizer for the observation distribution.
Site Splitting¶
- same_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args)[source]¶
Returns the support of the fn distribution. Used in SplitReparam to determine the support of the split value.
- Parameters
fn – distribution class
- Returns
distribution support
- real_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args)[source]¶
Returns a real support with the same event dimension as that of the fn distribution. Used in SplitReparam to determine the support of the split value.
- Parameters
fn – distribution class
- Returns
distribution support
- default_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, slice, dim)[source]¶
Returns the support of the fn distribution, corrected for split stacking and concatenation transforms. Used in SplitReparam to determine the support of the split value.
- Parameters
fn – distribution class
slice – slice for which to return support
dim – dimension for which to return support
- Returns
distribution support
- class SplitReparam(sections, dim, support_fn=<function default_support>)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Reparameterizer to split a random variable along a dimension, similar to torch.split().
This is useful for treating different parts of a tensor with different reparameterizers or inference methods. For example, when performing HMC inference on a time series, you can first apply DiscreteCosineReparam or HaarReparam, then apply SplitReparam to split into low-frequency and high-frequency components, and finally add the low-frequency components to the full_mass matrix together with globals.
- Parameters
sections – Size of a single chunk or list of sizes for each chunk.
dim (int) – Dimension along which to split. Defaults to -1.
support_fn (callable) – Function which derives the split support from the site's sampling function, split size, and split dimension. The default is default_support(), which correctly handles stacking and concatenation transforms. Other options are same_support(), which returns the same support as that of the sampling function, and real_support(), which returns a real support.
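A hedged sketch splitting one latent vector into two separately-treated pieces (the toy model is an assumption; after the split, each piece appears as its own sample site):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import SplitReparam

def model():
    pyro.sample("x", dist.Normal(torch.zeros(10), 1.0).to_event(1))

# Split the length-10 event dimension into chunks of sizes 2 and 8,
# e.g. to give only the first chunk a full mass matrix in HMC.
split_model = poutine.reparam(model, config={"x": SplitReparam([2, 8], dim=-1)})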
Neural Transport¶
- class NeuTraReparam(guide)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Neural Transport reparameterizer [1] of multiple latent variables.
This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use e.g. in MCMC. Example usage:
# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = poutine.reparam(model, config=lambda _: neutra)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...
This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.
- [1] Hoffman, M. et al. (2019)
“NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport” https://arxiv.org/abs/1903.03704
- Parameters
guide (AutoContinuous) – A trained guide.
- transform_sample(latent)[source]¶
Given latent samples from the warped posterior (with a possible batch dimension), return a dict of samples from the latent sites in the model.
- Parameters
latent – sample from the warped posterior (possibly batched). Note that the batch dimension must not collide with plate dimensions in the model, i.e. any batch dims d < -max_plate_nesting.
- Returns
a dict of samples keyed by latent sites in the model.
- Return type
dict
Structured Preconditioning¶
- class StructuredReparam(guide: pyro.infer.autoguide.structured.AutoStructured)[source]¶
Bases:
pyro.infer.reparam.reparam.Reparam
Preconditioning reparameterizer of multiple latent variables.
This uses a trained AutoStructured guide to alter the geometry of a model, typically for use e.g. in MCMC. Example usage:
# Step 1. Train a guide
guide = AutoStructured(model, ...)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in preconditioned MCMC
model = StructuredReparam(guide).reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...
This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common StructuredReparam instance, and that the model must have static structure.
Note
This can be seen as a restricted structured version of NeuTraReparam [1] combined with poutine.condition on MAP-estimated sites (the NeuTra transform is an exact reparameterizer, but the conditioning to point estimates introduces model approximation).
- [1] Hoffman, M. et al. (2019)
“NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport” https://arxiv.org/abs/1903.03704
- Parameters
guide (AutoStructured) – A trained guide.
- transform_samples(aux_samples, save_params=None)[source]¶
Given latent samples from the warped posterior (with a possible batch dimension), return a dict of samples from the latent sites in the model.
- Parameters
aux_samples (dict) – Dict mapping site name to tensor value for each latent auxiliary site (or, if save_params is specified, only those latent auxiliary sites needed to compute the requested params).
save_params (list) – An optional list of site names to save. This is useful in models with large nuisance variables. Defaults to None, saving all params.
- Returns
a dict of samples keyed by latent sites in the model.
- Return type
dict