Reparameterizers

The pyro.infer.reparam module contains reparameterization strategies for the pyro.poutine.handlers.reparam() effect. These are useful for altering the geometry of a poorly-conditioned parameter space to make the posterior better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.

class ReparamMessage
name: str
fn: Callable
value: Optional[torch.Tensor]
is_observed: Optional[bool]
class ReparamResult
fn: Callable
value: Optional[torch.Tensor]
is_observed: bool
class Reparam

Abstract base class for reparameterizers.

Derived classes should implement apply().

apply(msg: pyro.infer.reparam.reparam.ReparamMessage) → pyro.infer.reparam.reparam.ReparamResult

Abstract method to apply reparameterizer.

Parameters

msg (dict) – A simplified Pyro message with fields:
  • name: str, the sample site’s name
  • fn: Callable, a distribution
  • value: Optional[torch.Tensor], an observed or initial value
  • is_observed: bool, whether value is an observation

Returns

A simplified Pyro message with fields fn, value, and is_observed.

Return type

dict

__call__(name, fn, obs)

DEPRECATED. Subclasses should implement apply() instead. This will be removed in a future release.

Automatic Strategies

These reparametrization strategies are registered with register_reparam_strategy() and are accessed by name via poutine.reparam(config=name_of_strategy). See reparam() for usage.

class Strategy

Bases: abc.ABC

Abstract base class for reparametrizer configuration strategies.

Derived classes must implement the configure() method.

Variables

config (dict) – A dictionary configuration. This will be populated the first time the model is run. Thereafter it can be used as an argument to poutine.reparam(config=___).

abstract configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam]

Takes a sample site and returns either None or a Reparam instance.

This will be called only on the first model execution; subsequent executions will use the reparametrizer stored in self.config.

Parameters

msg (dict) – A sample site to possibly reparametrize.

Returns

An optional reparametrizer instance.

__call__(msg_or_fn: Union[dict, Callable])

Strategies can be used as decorators to reparametrize a model.

Parameters

msg_or_fn – Public use: a model to be decorated. (Internal use: a site to be configured for reparametrization).
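The caching behaviour described for configure() and config can be sketched in plain Python. This is a hypothetical toy, not Pyro's Strategy class: ToyStrategy, run_model, and the string labels standing in for Reparam instances are all illustrative, and only the dict-message path of __call__() is modeled (the decorator path is omitted).

```python
from typing import Callable, Dict, List, Optional


class ToyStrategy:
    """Hypothetical stand-in for a Strategy: caches one decision per site name."""

    def __init__(self, rule: Callable[[dict], Optional[str]]):
        self.rule = rule
        self.config: Dict[str, Optional[str]] = {}  # filled in on the first model run
        self.num_configure_calls = 0

    def configure(self, msg: dict) -> Optional[str]:
        """Decide how to reparametrize one site (here: a string label or None)."""
        self.num_configure_calls += 1
        return self.rule(msg)

    def __call__(self, msg: dict) -> Optional[str]:
        name = msg["name"]
        if name not in self.config:
            # First execution: ask configure() and remember the answer.
            self.config[name] = self.configure(msg)
        # Later executions reuse the stored decision.
        return self.config[name]


def run_model(strategy: ToyStrategy) -> List[Optional[str]]:
    """Simulate one model execution that visits three sample sites."""
    sites = [
        {"name": "x", "fn": "Stable"},
        {"name": "y", "fn": "Normal"},
        {"name": "z", "fn": "ProjectedNormal"},
    ]
    return [strategy(msg) for msg in sites]


# Illustrative rule, loosely in the spirit of MinimalReparam: only "problem" site types get a label.
labels = {"Stable": "StableReparam", "ProjectedNormal": "ProjectedNormalReparam"}
strategy = ToyStrategy(lambda msg: labels.get(msg["fn"]))

first = run_model(strategy)
second = run_model(strategy)
```

On the second run configure() is not called again; every decision comes from config, which is why a populated config can be reused to freeze a strategy's behaviour.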

class MinimalReparam

Bases: pyro.infer.reparam.strategies.Strategy

Minimal reparametrization strategy that reparametrizes only those sites that would otherwise lead to error, e.g. Stable and ProjectedNormal random variables.

Example:

from pyro import poutine
from pyro.infer.reparam import MinimalReparam

@MinimalReparam()
def model(*args, **kwargs):
    ...

which is equivalent to:

@poutine.reparam(config=MinimalReparam())
def model(*args, **kwargs):
    ...

configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam]
class AutoReparam(*, centered: Optional[float] = None)

Bases: pyro.infer.reparam.strategies.Strategy

Applies a recommended set of reparametrizers. These currently include: MinimalReparam, TransformReparam, a fully-learnable LocScaleReparam, and GumbelSoftmaxReparam.

Example:

from pyro import poutine
from pyro.infer.reparam import AutoReparam

@AutoReparam()
def model(*args, **kwargs):
    ...

which is equivalent to:

@poutine.reparam(config=AutoReparam())
def model(*args, **kwargs):
    ...

Warning

This strategy may change behavior across Pyro releases. To inspect or save a given behavior, extract the .config dict after running the model at least once.

Parameters

centered – Optional centering parameter for LocScaleReparam reparametrizers. If None (default), centering will be learned. If a float in [0.0,1.0], it is used as a fixed centering. To completely decenter (e.g. in MCMC), set to 0.0.

configure(msg: dict) → Optional[pyro.infer.reparam.reparam.Reparam]

Conjugate Updating

class ConjugateReparam(guide)

Bases: pyro.infer.reparam.reparam.Reparam

EXPERIMENTAL Reparameterize to a conjugate updated distribution.

This updates a prior distribution fn using the conjugate_update() method. The guide may be either a distribution object or a callable that takes the model’s *args, **kwargs and returns a distribution object. The guide may be approximate or learned.

For example consider the model and naive variational guide:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer.reparam import ConjugateReparam

total = torch.tensor(10.)
count = torch.tensor(2.)

def model():
    prob = pyro.sample("prob", dist.Beta(0.5, 1.5))
    pyro.sample("count", dist.Binomial(total, prob), obs=count)

guide = AutoDiagonalNormal(model)  # learns the posterior over prob

Instead of using this learned guide, we can hand-compute the likelihood of “prob” as a Beta distribution, which ConjugateReparam combines with the prior via conjugate_update() to obtain the exact posterior. A simpler guide then suffices during inference, in this case an empty guide:

reparam_model = poutine.reparam(model, {
    "prob": ConjugateReparam(dist.Beta(1 + count, 1 + total - count))
})

def reparam_guide():
    pass  # nothing remains to be modeled!

Parameters

guide (Distribution or callable) – A likelihood distribution or a callable returning a guide distribution. Only a few distributions are supported, depending on the prior distribution’s conjugate_update() implementation.
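The arithmetic in the ConjugateReparam example above can be checked with the standard library. This sketch assumes the Beta case of conjugate_update() multiplies the two densities in prob, so a Beta(a, b) prior combined with a Beta(a2, b2) likelihood gives Beta(a + a2 - 1, b + b2 - 1); the helper names are hypothetical. It also confirms that the Binomial likelihood of prob is proportional to the Beta(1 + count, 1 + total - count) density passed to ConjugateReparam.

```python
import math


def beta_logpdf(p: float, a: float, b: float) -> float:
    """Log density of Beta(a, b) at p."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1) * math.log(p) + (b - 1) * math.log(1 - p)


def binomial_logpmf(k: int, n: int, p: float) -> float:
    """Log probability of k successes in n Bernoulli(p) trials."""
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_choose + k * math.log(p) + (n - k) * math.log(1 - p)


def beta_conjugate_update(prior, likelihood):
    """Beta parameters of the (renormalized) product of two Beta densities in p."""
    (a, b), (a2, b2) = prior, likelihood
    return a + a2 - 1, b + b2 - 1


total, count = 10, 2
likelihood = (1 + count, 1 + total - count)   # Beta(3, 9), as passed to ConjugateReparam
posterior = beta_conjugate_update((0.5, 1.5), likelihood)

# The Binomial likelihood and the Beta(3, 9) density differ only by a constant in p.
log_ratios = [
    binomial_logpmf(count, total, p) - beta_logpdf(p, *likelihood)
    for p in (0.1, 0.3, 0.5, 0.9)
]
```

The result, Beta(2.5, 9.5), matches the textbook conjugate posterior Beta(0.5 + count, 1.5 + total - count).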

apply(msg)

Loc-Scale Decentering

class LocScaleReparam(centered=None, shape_params=None)

Bases: pyro.infer.reparam.reparam.Reparam

Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).

This reparameterization works only for latent variables, not likelihoods.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019). “Automatic Reparameterisation of Probabilistic Programs.” https://arxiv.org/pdf/1906.03028.pdf

Parameters
  • centered (float) – Optional centering parameter. If None (default), learn a per-site per-element centering parameter in [0,1]. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.

  • shape_params (tuple or list) – Optional list of additional parameter names to copy unchanged from the centered to decentered distribution. If absent, all params in a distribution’s .arg_constraints will be copied.
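The partial decentering can be written out for a Normal site. The sketch below assumes the transform takes the form used by Gorinova et al. [1]: for a centering c in [0, 1], an auxiliary site z is drawn from Normal(c * loc, scale ** c), and the original value is recovered as loc + scale ** (1 - c) * (z - c * loc). The helper names are hypothetical. It checks that the recovered variable keeps the original mean and standard deviation for every c, so only the geometry seen by the inference algorithm changes.

```python
def decentered_params(loc: float, scale: float, c: float):
    """Mean and standard deviation of the auxiliary (decentered) Normal site."""
    return c * loc, scale ** c


def recenter(z: float, loc: float, scale: float, c: float) -> float:
    """Map a decentered sample z back to the original parameterization."""
    return loc + scale ** (1 - c) * (z - c * loc)


loc, scale = 3.0, 0.2
results = {}
for c in (0.0, 0.5, 1.0):
    z_mean, z_sd = decentered_params(loc, scale, c)
    # recenter() is affine in z: it maps the mean to the mean and multiplies the sd by its slope.
    x_mean = recenter(z_mean, loc, scale, c)
    x_sd = scale ** (1 - c) * z_sd
    results[c] = (z_mean, z_sd, x_mean, x_sd)
```

With c = 0 the auxiliary site is a standard Normal (fully decentered); with c = 1 it is the original site itself.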

apply(msg)

Gumbel-Softmax

class GumbelSoftmaxReparam

Bases: pyro.infer.reparam.reparam.Reparam

Reparametrizer for RelaxedOneHotCategorical latent variables.

This is useful for transforming multimodal posteriors to unimodal posteriors. Note this increases the latent dimension by 1 per event.

This reparameterization works only for latent variables, not likelihoods.
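The standard Gumbel-softmax construction, which this reparametrizer is assumed to follow (Pyro's internal details may differ), draws parameter-free Uniform(0, 1) noise with one coordinate per category, converts it to Gumbel noise -log(-log u), adds it to the logits, and applies a softmax at the distribution's temperature. The sketch uses fixed noise values for reproducibility and hypothetical helper names. It also shows where the extra latent dimension comes from: K noise coordinates parameterize a point on the simplex, which has only K - 1 free coordinates.

```python
import math


def gumbel_softmax(logits, temperature, uniforms):
    """Map K Uniform(0, 1) draws to a point on the probability simplex."""
    gumbels = [-math.log(-math.log(u)) for u in uniforms]
    scores = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.5, -1.0, 2.0]
uniforms = [0.9, 0.5, 0.2]  # fixed here; in a model these are the auxiliary latent coordinates

soft = gumbel_softmax(logits, temperature=1.0, uniforms=uniforms)
hard = gumbel_softmax(logits, temperature=0.01, uniforms=uniforms)

# At low temperature the output approaches the one-hot vector of the perturbed argmax.
perturbed = [l - math.log(-math.log(u)) for l, u in zip(logits, uniforms)]
argmax = perturbed.index(max(perturbed))
```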

apply(msg)

Transformed Distributions

class TransformReparam

Bases: pyro.infer.reparam.reparam.Reparam

Reparameterizer for pyro.distributions.torch.TransformedDistribution latent variables.

This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has a simple shape in the space of base_dist.

This reparameterization works only for latent variables, not likelihoods.
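A LogNormal, i.e. a Normal base distribution pushed through exp, illustrates the idea. Rather than sampling x directly with density p_X(x) = p_Z(log x) / x, a TransformReparam-style model samples z from the base Normal and computes x = exp(z) deterministically, so the inference algorithm only sees the Normal's geometry. This pure-Python sketch (hypothetical helper names) checks the change-of-variables identity that makes the two views describe the same distribution.

```python
import math


def normal_logpdf(z: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) at z."""
    return -0.5 * ((z - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def lognormal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log density of exp(Z) for Z ~ Normal(loc, scale), written out directly."""
    return (-((math.log(x) - loc) ** 2) / (2 * scale ** 2)
            - math.log(x * scale * math.sqrt(2 * math.pi)))


loc, scale = 0.3, 1.7
z = -0.4          # a value of the base (reparametrized) site
x = math.exp(z)   # the original site's value, computed deterministically

# log p_X(x) = log p_Z(z) - log|dx/dz|, and dx/dz = exp(z) = x.
via_base = normal_logpdf(z, loc, scale) - math.log(x)
direct = lognormal_logpdf(x, loc, scale)
```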

apply(msg)

Discrete Cosine Transform

class DiscreteCosineReparam(dim=-1, smooth=0.0, *, experimental_allow_batch=False)

Bases: pyro.infer.reparam.unit_jacobian.UnitJacobianReparam

Discrete Cosine reparameterizer, using a DiscreteCosineTransform.

This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC.

When reparameterizing variables that are approximately continuous along the time dimension, set smooth=1. For variables that are approximately continuously differentiable along the time axis, set smooth=2.

This reparameterization works only for latent variables, not likelihoods.

Parameters
  • dim (int) – Dimension along which to transform. Must be negative. This is an absolute dim counting from the right.

  • smooth (float) – Smoothing parameter. When 0, this transforms white noise to white noise; when 1, this transforms Brownian noise to white noise; when -1, this transforms violet noise to white noise; etc. Any real number is allowed. https://en.wikipedia.org/wiki/Colors_of_noise.

  • experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.
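The unit-Jacobian property can be checked for the orthonormal DCT-II. The pure-Python sketch below assumes that scaling (it does not model smooth or batching, and dct_matrix and matvec are hypothetical helpers) and verifies that the matrix is orthogonal, so the transform preserves volume and its inverse is its transpose. It also shows a smooth linear trend putting almost all of its energy into the first two frequency coefficients, which illustrates why a frequency-domain representation can be easier to work with.

```python
import math


def dct_matrix(n: int):
    """Orthonormal DCT-II matrix: row k holds the k-th cosine basis vector."""
    rows = []
    for k in range(n):
        weight = math.sqrt((1 if k == 0 else 2) / n)
        rows.append([weight * math.cos(math.pi * (2 * i + 1) * k / (2 * n)) for i in range(n)])
    return rows


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


n = 8
C = dct_matrix(n)

# Orthogonality: C @ C^T is the identity, hence |det C| = 1 (unit Jacobian).
gram = [[sum(a * b for a, b in zip(C[i], C[j])) for j in range(n)] for i in range(n)]

signal = [float(i) for i in range(n)]  # a smooth (linear) trend
coeffs = matvec(C, signal)
energy = [c * c for c in coeffs]
low_freq_fraction = sum(energy[:2]) / sum(energy)
```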

Haar Transform

class HaarReparam(dim=-1, flip=False, *, experimental_allow_batch=False)

Bases: pyro.infer.reparam.unit_jacobian.UnitJacobianReparam

Haar wavelet reparameterizer, using a HaarTransform.

This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC.

This reparameterization works only for latent variables, not likelihoods.

Parameters
  • dim (int) – Dimension along which to transform. Must be negative. This is an absolute dim counting from the right.

  • flip (bool) – Whether to flip the time axis before applying the Haar transform. Defaults to False.

  • experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.
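The volume-preserving property can be illustrated with a textbook orthonormal Haar transform written in pure Python. This sketch is not Pyro's HaarTransform: it handles only lengths that are a power of two, ignores flip, and its coefficient ordering (coarsest average first, then progressively finer details) is a choice made here. It checks that the transform preserves the Euclidean norm and is exactly undone by the hypothetical companion inverse_haar().

```python
import math

SQRT_HALF = math.sqrt(0.5)


def haar(x):
    """Orthonormal Haar transform of a list whose length is a power of two."""
    x = list(x)
    out = []
    while len(x) > 1:
        averages = [(a + b) * SQRT_HALF for a, b in zip(x[0::2], x[1::2])]
        details = [(a - b) * SQRT_HALF for a, b in zip(x[0::2], x[1::2])]
        out = details + out  # finer-scale details end up further to the right
        x = averages
    return x + out           # coarsest average first


def inverse_haar(y):
    """Invert haar() one level at a time, starting from the coarsest average."""
    x = [y[0]]
    pos = 1
    while pos < len(y):
        details = y[pos:pos + len(x)]
        x = [v for a, d in zip(x, details) for v in ((a + d) * SQRT_HALF, (a - d) * SQRT_HALF)]
        pos += len(details)
    return x


signal = [1.0, 2.0, 4.0, 4.5, 5.0, 4.0, 2.0, 1.0]
coeffs = haar(signal)
restored = inverse_haar(coeffs)
```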

Unit Jacobian Transforms

class UnitJacobianReparam(transform, suffix='transformed', *, experimental_allow_batch=False)

Bases: pyro.infer.reparam.reparam.Reparam

Reparameterizer for Transform objects whose Jacobian determinant is one.

Parameters
  • transform (Transform) – A transform whose Jacobian has determinant 1.

  • suffix (str) – A suffix to append to the transformed site.

  • experimental_allow_batch (bool) – EXPERIMENTAL allow coupling across a batch dimension. The targeted batch dimension and all batch dimensions to the right will be converted to event dimensions. Defaults to False.

apply(msg)

StudentT Distributions

class StudentTReparam

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for StudentT random variables.

This is useful in combination with LinearHMMReparam because it allows StudentT processes to be treated as conditionally Gaussian processes, permitting cheap inference via GaussianHMM.

This reparameterizes a StudentT by introducing an auxiliary Gamma variable conditioned on which the result is Normal.
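The identity behind this reparametrizer can be checked numerically: if g ~ Gamma(df/2, rate=df/2) and x | g ~ Normal(loc, scale / sqrt(g)), then marginally x ~ StudentT(df, loc, scale). The sketch assumes this Gamma-on-precision form is the auxiliary variable used (Pyro's exact code is not shown here), uses hypothetical helper names, and integrates the mixture over g with Simpson's rule.

```python
import math


def student_t_pdf(x, df, loc, scale):
    """Density of StudentT(df, loc, scale) at x."""
    z = (x - loc) / scale
    log_norm = (math.lgamma((df + 1) / 2) - math.lgamma(df / 2)
                - 0.5 * math.log(df * math.pi) - math.log(scale))
    return math.exp(log_norm - (df + 1) / 2 * math.log1p(z * z / df))


def gamma_pdf(t, shape, rate):
    """Density of Gamma(shape, rate) at t > 0."""
    return math.exp(shape * math.log(rate) - math.lgamma(shape)
                    + (shape - 1) * math.log(t) - rate * t)


def normal_pdf(x, loc, scale):
    """Density of Normal(loc, scale) at x."""
    return math.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * math.sqrt(2 * math.pi))


def mixture_pdf(x, df, loc, scale, upper=60.0, steps=20000):
    """Integrate Normal(x; loc, scale / sqrt(g)) * Gamma(g; df/2, df/2) over g (Simpson's rule)."""
    h = upper / steps
    total = 0.0
    for i in range(1, steps + 1):  # the integrand is 0 at g = 0 when df > 2, so i = 0 is skipped
        g = i * h
        weight = 1 if i == steps else (4 if i % 2 else 2)
        total += weight * normal_pdf(x, loc, scale / math.sqrt(g)) * gamma_pdf(g, df / 2, df / 2)
    return total * h / 3
```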

apply(msg)

Stable Distributions

class LatentStableReparam

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for Stable latent variables.

This is useful for inference over latent Stable variables, whose log_prob() is not implemented.

This uses the Chambers-Mallows-Stuck method [1], creating a pair of parameter-free auxiliary distributions (Uniform(-pi/2,pi/2) and Exponential(1)) with well-defined .log_prob() methods, thereby permitting use of reparameterized stable distributions in likelihood-based inference algorithms like SVI and MCMC.

This reparameterization works only for latent variables, not likelihoods. For likelihood-compatible reparameterization see SymmetricStableReparam or StableReparam.

[1] J.P. Nolan (2017). “Stable Distributions: Models for Heavy Tailed Data.” https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf
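For the symmetric case (skew=0) the Chambers-Mallows-Stuck map is short enough to write out. The sketch below uses the textbook formula and is an illustration under that assumption, not Pyro's implementation, which also handles skewed variables in the S0 parameterization: with V ~ Uniform(-pi/2, pi/2) and W ~ Exponential(1), X = sin(aV) / cos(V)^(1/a) * (cos((1 - a)V) / W)^((1 - a)/a) is a standard symmetric Stable variable with stability a. Two special cases serve as checks: a = 1 reduces to the Cauchy variable tan(V), and a = 2 reduces to 2 sin(V) sqrt(W), a Normal with variance 2.

```python
import math
import random


def symmetric_stable_cms(alpha: float, v: float, w: float) -> float:
    """Chambers-Mallows-Stuck map for skew = 0 and 0 < alpha <= 2.

    v ~ Uniform(-pi/2, pi/2) and w ~ Exponential(1) are the parameter-free auxiliaries.
    """
    return (math.sin(alpha * v) / math.cos(v) ** (1 / alpha)
            * (math.cos((1 - alpha) * v) / w) ** ((1 - alpha) / alpha))


rng = random.Random(0)
v = rng.uniform(-math.pi / 2, math.pi / 2)
w = rng.expovariate(1.0)

cauchy_case = symmetric_stable_cms(1.0, v, w)    # expected to equal tan(v)
gaussian_case = symmetric_stable_cms(2.0, v, w)  # expected to equal 2 * sin(v) * sqrt(w)

# Monte Carlo check that stability 2 gives a zero-mean variable with variance 2.
samples = [
    symmetric_stable_cms(2.0, rng.uniform(-math.pi / 2, math.pi / 2), rng.expovariate(1.0))
    for _ in range(100_000)
]
variance = sum(s * s for s in samples) / len(samples)
```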

apply(msg)
class SymmetricStableReparam

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for symmetric Stable random variables (i.e. those for which skew=0).

This is useful for inference over symmetric Stable variables, whose log_prob() is not implemented.

This reparameterizes a symmetric Stable random variable as a totally-skewed (skew=1) Stable scale mixture of Normal random variables. See Proposition 3 of [1] (but note that we differ, since Stable uses Nolan’s continuous S0 parameterization).

[1] Alvaro Cartea and Sam Howison (2009). “Option Pricing with Levy-Stable Processes.” https://pdfs.semanticscholar.org/4d66/c91b136b2a38117dd16c2693679f5341c616.pdf
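The scale-mixture idea can be checked in the special case of stability 1, where the totally skewed mixing variable has stability 1/2 (a Lévy variable) and the symmetric variable is Cauchy. The Monte Carlo sketch below uses the classic Lévy and Cauchy parameterizations rather than Pyro's S0 conventions: if A ~ Lévy(0, c), sampled as c / Y**2 with Y standard Normal, and Z ~ Normal(0, 1), then sqrt(A) * Z should be Cauchy with scale sqrt(c). It compares empirical and exact CDFs at a few points; cauchy_cdf is a hypothetical helper.

```python
import math
import random

rng = random.Random(1)
c = 2.0
n = 200_000

samples = []
for _ in range(n):
    y = 0.0
    while y == 0.0:  # guard against an (astronomically unlikely) zero draw
        y = rng.gauss(0.0, 1.0)
    a = c / (y * y)                     # A ~ Levy(0, c): totally skewed, stability 1/2
    z = rng.gauss(0.0, 1.0)
    samples.append(math.sqrt(a) * z)    # Normal variable with random variance A


def cauchy_cdf(x: float, scale: float) -> float:
    """CDF of a zero-centered Cauchy distribution."""
    return 0.5 + math.atan(x / scale) / math.pi


points = (-3.0, -0.5, 0.0, 1.0, 4.0)
empirical = {x: sum(s <= x for s in samples) / n for x in points}
exact = {x: cauchy_cdf(x, math.sqrt(c)) for x in points}
```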

apply(msg)
class StableReparam

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for arbitrary Stable random variables.

This is useful for inference over non-symmetric Stable variables, whose log_prob() is not implemented.

This reparameterizes a Stable random variable as the sum of two other Stable random variables, one symmetric and the other totally skewed (applying Property 2.3.a of [1]). The totally skewed variable is sampled as in LatentStableReparam, and the symmetric variable is decomposed as in SymmetricStableReparam.

[1] V. M. Zolotarev (1986). “One-dimensional stable distributions.”

apply(msg)

Projected Normal Distributions

class ProjectedNormalReparam

Bases: pyro.infer.reparam.reparam.Reparam

Reparametrizer for ProjectedNormal latent variables.

This reparameterization works only for latent variables, not likelihoods.
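A ProjectedNormal value is the direction of a Normal vector. The sketch below assumes the reparametrization draws an auxiliary x ~ Normal(concentration, I) in R^d and sets value = x / |x|, so the auxiliary site has an ordinary density in Euclidean space; it is pure Python with hypothetical names, not Pyro's code. It checks that values lie on the unit sphere and that their average direction follows the concentration vector.

```python
import math
import random


def project(x):
    """Normalize a nonzero vector to unit length."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]


rng = random.Random(0)
concentration = [3.0, 0.0, 4.0]  # direction (0.6, 0, 0.8); a larger norm means tighter spread
x = [rng.gauss(mu, 1.0) for mu in concentration]  # auxiliary Normal latent in R^3
value = project(x)

# Averaging many projected draws points along the concentration direction.
draws = [project([rng.gauss(mu, 1.0) for mu in concentration]) for _ in range(20_000)]
mean_dir = [sum(d[i] for d in draws) / len(draws) for i in range(3)]
cosine = sum(a * b for a, b in zip(project(mean_dir), project(concentration)))
```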

apply(msg)

Hidden Markov Models

class LinearHMMReparam(init=None, trans=None, obs=None)

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for LinearHMM random variables.

This defers to component reparameterizers to create auxiliary random variables conditioned on which the process becomes a GaussianHMM. If the observation_dist is a TransformedDistribution, this reorders those transforms so that the result is a TransformedDistribution of GaussianHMM.

This is useful for training the parameters of a LinearHMM distribution, whose log_prob() method is undefined. To perform inference in the presence of non-Gaussian factors such as Stable(), StudentT() or LogNormal(), configure with StudentTReparam, StableReparam, SymmetricStableReparam, etc. component reparameterizers for init, trans, and obs. For example:

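import torch
import pyro
from pyro import poutine
from pyro.distributions import (LinearHMM, MultivariateNormal, Stable,
                                TransformedDistribution)
from pyro.infer.reparam import (LinearHMMReparam, StableReparam,
                                SymmetricStableReparam)
from torch.distributions.transforms import ExpTransform

hmm = LinearHMM(
    initial_dist=Stable(1, 0, 1, 0).expand([2]).to_event(1),
    transition_matrix=torch.eye(2),
    transition_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
    observation_matrix=torch.eye(2),
    observation_dist=TransformedDistribution(
        Stable(1.5, -0.5, 1.0).expand([2]).to_event(1),
        ExpTransform()))

rep = LinearHMMReparam(init=SymmetricStableReparam(),
                       obs=StableReparam())

with poutine.reparam(config={"hmm": rep}):
    pyro.sample("hmm", hmm, obs=data)  # data: the observed series (not defined here)

Parameters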
  • init (Reparam) – Optional reparameterizer for the initial distribution.

  • trans (Reparam) – Optional reparameterizer for the transition distribution.

  • obs (Reparam) – Optional reparameterizer for the observation distribution.

apply(msg)

Site Splitting

same_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args)

Returns support of the fn distribution. Used in SplitReparam in order to determine the support of the split value.

Parameters

fn – distribution class

Returns

distribution support

real_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args)

Returns real support with same event dimension as that of the fn distribution. Used in SplitReparam in order to determine the support of the split value.

Parameters

fn – distribution class

Returns

distribution support

default_support(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, slice, dim)

Returns support of the fn distribution, corrected for split stacking and concatenation transforms. Used in SplitReparam in order to determine the support of the split value.

Parameters
  • fn – distribution class

  • slice – slice for which to return support

  • dim – dimension for which to return support

Returns

distribution support

class SplitReparam(sections, dim, support_fn=default_support)

Bases: pyro.infer.reparam.reparam.Reparam

Reparameterizer to split a random variable along a dimension, similar to torch.split().

This is useful for treating different parts of a tensor with different reparameterizers or inference methods. For example when performing HMC inference on a time series, you can first apply DiscreteCosineReparam or HaarReparam, then apply SplitReparam to split into low-frequency and high-frequency components, and finally add the low-frequency components to the full_mass matrix together with globals.

Parameters
  • sections (list(int)) – Size of a single chunk or list of sizes for each chunk.

  • dim (int) – Dimension along which to split, e.g. -1 for the rightmost dimension.

  • support_fn (callable) – Function which derives the split support from the site’s sampling function, split size, and split dimension. Default is default_support() which correctly handles stacking and concatenation transforms. Other options are same_support() which returns the same support as that of the sampling function, and real_support() which returns a real support.
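SplitReparam is described as similar to torch.split() along one dimension. The pure-Python sketch below (1-D for brevity; split_sections and the coefficient values are illustrative, not output of any real model) splits a vector into a low-frequency chunk and a high-frequency chunk with sections [2, 6] and checks that concatenating the chunks recovers the original value, which is what allows the pieces to be treated as separate sites.

```python
from itertools import accumulate


def split_sections(values, sections):
    """Split a list into consecutive chunks with the given sizes (like torch.split with a list)."""
    if sum(sections) != len(values):
        raise ValueError("sections must sum to the length of the split dimension")
    ends = list(accumulate(sections))
    starts = [0] + ends[:-1]
    return [values[s:e] for s, e in zip(starts, ends)]


coeffs = [9.9, -6.4, 0.0, -0.7, 0.0, -0.2, 0.0, -0.05]  # illustrative frequency-domain values
low, high = split_sections(coeffs, [2, 6])
recombined = low + high
```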

apply(msg)

Neural Transport

class NeuTraReparam(guide)

Bases: pyro.infer.reparam.reparam.Reparam

Neural Transport reparameterizer [1] of multiple latent variables.

This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:

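from pyro import poutine
from pyro.infer import NUTS, SVI
from pyro.infer.autoguide import AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = poutine.reparam(model, config=lambda _: neutra)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...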

This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.

[1] Hoffman, M. et al. (2019). “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport.” https://arxiv.org/abs/1903.03704

Parameters

guide (AutoContinuous) – A trained guide.

reparam(fn=None)
apply(msg)
transform_sample(latent)

Given latent samples from the warped posterior (with a possible batch dimension), return a dict of samples from the latent sites in the model.

Parameters

latent – sample from the warped posterior (possibly batched). Note that the batch dimension must not collide with plate dimensions in the model, i.e. any batch dims d < -max_plate_nesting.

Returns

a dict of samples keyed by latent sites in the model.

Return type

dict

Structured Preconditioning

class StructuredReparam(guide: pyro.infer.autoguide.structured.AutoStructured)

Bases: pyro.infer.reparam.reparam.Reparam

Preconditioning reparameterizer of multiple latent variables.

This uses a trained AutoStructured guide to alter the geometry of a model, typically for use in MCMC. Example usage:

# Step 1. Train a guide
guide = AutoStructured(model, ...)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in preconditioned MCMC
model = StructuredReparam(guide).reparam(model)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...

This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common StructuredReparam instance, and that the model must have static structure.

Note

This can be seen as a restricted structured version of NeuTraReparam [1] combined with poutine.condition on MAP-estimated sites (the NeuTra transform is an exact reparameterizer, but the conditioning to point estimates introduces model approximation).

[1] Hoffman, M. et al. (2019). “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport.” https://arxiv.org/abs/1903.03704

Parameters

guide (AutoStructured) – A trained guide.

reparam(fn=None)
apply(msg)
transform_samples(aux_samples, save_params=None)

Given latent samples from the warped posterior (with a possible batch dimension), return a dict of samples from the latent sites in the model.

Parameters
  • aux_samples (dict) – Dict mapping site name to tensor value for each latent auxiliary site (or, if save_params is specified, for only those latent auxiliary sites needed to compute the requested params).

  • save_params (list) – An optional list of site names to save. This is useful in models with large nuisance variables. Defaults to None, saving all params.

Returns

a dict of samples keyed by latent sites in the model.

Return type

dict