Reparameterizers

The pyro.infer.reparam module contains reparameterization strategies for the pyro.poutine.handlers.reparam() effect. These are useful for altering the geometry of a poorly conditioned parameter space to make the posterior better shaped. They can be used with a variety of inference algorithms, e.g. Auto*Normal guides and MCMC.
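
For example, a single poorly conditioned site can be reparameterized by wrapping the model in the poutine.reparam handler. The following is a minimal sketch; the model and the site name "x" are illustrative, not part of the library API:

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def model():
    # A funnel-like geometry: the scale of "x" depends on another latent.
    log_scale = pyro.sample("log_scale", dist.Normal(0., 3.))
    pyro.sample("x", dist.Normal(0., log_scale.exp()))

# Reparameterize only the "x" site; all other sites are left unchanged.
reparam_model = poutine.reparam(model, config={"x": LocScaleReparam()})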

class Reparam[source]

Base class for reparameterizers.

__call__(name, fn, obs)[source]
Parameters:
  • name (str) – A sample site name.
  • fn (Distribution) – A distribution.
  • obs (Tensor) – Observed value or None.
Returns:
A pair (new_fn, value).

Loc-Scale Decentering

class LocScaleReparam(centered=None, shape_params=())[source]

Bases: pyro.infer.reparam.reparam.Reparam

Generic decentering reparameterizer [1] for latent variables parameterized by loc and scale (and possibly additional shape_params).

This reparameterization works only for latent variables, not likelihoods.

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf
Parameters:
  • centered (float) – optional centered parameter. If None (default), learn a per-site, per-element centering parameter in [0, 1]. If 0, fully decenter the distribution; if 1, preserve the centered distribution unchanged.
  • shape_params (tuple or list) – list of additional parameter names to copy unchanged from the centered to decentered distribution.
__call__(name, fn, obs)[source]
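
For instance, a hierarchical model with a centered group-level effect can be decentered as follows. This is a sketch; the model, data, and the choice centered=0. are illustrative:

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def model(data):
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    tau = pyro.sample("tau", dist.HalfCauchy(5.))
    with pyro.plate("groups", data.shape[0]):
        # Centered parameterization; poorly conditioned when tau is small.
        theta = pyro.sample("theta", dist.Normal(mu, tau))
        pyro.sample("obs", dist.Normal(theta, 1.), obs=data)

# centered=0. fully decenters "theta"; centered=None (the default) would
# instead learn a per-element centering parameter in [0, 1].
decentered_model = poutine.reparam(
    model, config={"theta": LocScaleReparam(centered=0.)})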

Transformed Distributions

class TransformReparam[source]

Bases: pyro.infer.reparam.reparam.Reparam

Reparameterizer for pyro.distributions.torch.TransformedDistribution latent variables.

This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior has a simple shape in the space of base_dist.

This reparameterization works only for latent variables, not likelihoods.

__call__(name, fn, obs)[source]
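
For example, a log-normal-like latent written explicitly as a TransformedDistribution can be sampled in its base Normal space. This is a sketch; the site names, shapes, and transform are illustrative:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import TransformReparam

def model():
    base = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
    fn = dist.TransformedDistribution(base, [dist.transforms.ExpTransform()])
    x = pyro.sample("x", fn)
    pyro.sample("obs", dist.Normal(x, 0.1).to_event(1), obs=torch.ones(3))

# Under the reparameterization, "x" is sampled from base and the ExpTransform
# is applied deterministically afterwards.
reparam_model = poutine.reparam(model, config={"x": TransformReparam()})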

Discrete Cosine Transform

class DiscreteCosineReparam(dim=-1)[source]

Bases: pyro.infer.reparam.reparam.Reparam

Discrete Cosine reparameterizer, using a DiscreteCosineTransform.

This is useful for sequential models where coupling along a time-like axis (e.g. a banded precision matrix) introduces long-range correlation. This reparameterizes to a frequency-domain representation where posterior covariance should be closer to diagonal, thereby improving the accuracy of diagonal guides in SVI and improving the effectiveness of a diagonal mass matrix in HMC.

This reparameterization works only for latent variables, not likelihoods.

Parameters: dim (int) – Dimension along which to transform. Must be negative. This is an absolute dim counting from the right.
__call__(name, fn, obs)[source]
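
For example, the increments of a Gaussian random walk can be sampled in the frequency domain. This is a sketch; the model and dimensions are illustrative:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import DiscreteCosineReparam

def model(data):
    T = data.shape[-1]
    # Random-walk increments; the cumulative sum below induces long-range
    # correlation along the time axis.
    drift = pyro.sample("drift", dist.Normal(torch.zeros(T), 1.).to_event(1))
    motion = drift.cumsum(-1)
    pyro.sample("obs", dist.Normal(motion, 0.1).to_event(1), obs=data)

# Sample "drift" in the frequency domain along the rightmost dimension.
reparam_model = poutine.reparam(
    model, config={"drift": DiscreteCosineReparam(dim=-1)})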

StudentT Distributions

class StudentTReparam[source]

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for StudentT random variables.

This is useful in combination with LinearHMMReparam because it allows StudentT processes to be treated as conditionally Gaussian processes, permitting cheap inference via GaussianHMM.

This reparameterizes a StudentT by introducing an auxiliary Gamma variable, conditioned on which the result is Normal.

__call__(name, fn, obs)[source]
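
A sketch of standalone usage (the model is illustrative; see LinearHMMReparam below for the HMM use case):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import StudentTReparam

def model(data):
    # Heavy-tailed latent; given the auxiliary Gamma variable introduced by the
    # reparameterizer, this site becomes conditionally Normal.
    x = pyro.sample("x", dist.StudentT(1.5, 0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(x, 0.1), obs=data)

reparam_model = poutine.reparam(model, config={"x": StudentTReparam()})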

Stable Distributions

class LatentStableReparam[source]

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for Stable latent variables.

This is useful in inference of latent Stable variables because Stable.log_prob() is not implemented.

This uses the Chambers-Mallows-Stuck method [1], creating a pair of parameter-free auxiliary distributions (Uniform(-pi/2,pi/2) and Exponential(1)) with well-defined .log_prob() methods, thereby permitting use of reparameterized stable distributions in likelihood-based inference algorithms like SVI and MCMC.

This reparameterization works only for latent variables, not likelihoods. For likelihood-compatible reparameterization, see SymmetricStableReparam or StableReparam.

[1] J.P. Nolan (2017).
Stable Distributions: Models for Heavy Tailed Data. http://fs2.american.edu/jpnolan/www/stable/chap1.pdf
__call__(name, fn, obs)[source]
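
A sketch of usage on a latent site (the model and Stable parameters are illustrative):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LatentStableReparam

def model(data):
    # Stable(stability, skew, scale, loc); its log_prob() is unavailable, so
    # the site must be reparameterized before SVI or MCMC can run.
    x = pyro.sample("x", dist.Stable(1.8, 0.5, 1.0, 0.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(x, 1.0), obs=data)

reparam_model = poutine.reparam(model, config={"x": LatentStableReparam()})
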
class SymmetricStableReparam[source]

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for symmetric Stable random variables (i.e. those for which skew=0).

This is useful in inference of symmetric Stable variables because Stable.log_prob() is not implemented.

This reparameterizes a symmetric Stable random variable as a totally-skewed (skew=1) Stable scale mixture of Normal random variables. See Proposition 3 of [1] (but note we differ since Stable uses Nolan’s continuous S0 parameterization).

[1] Alvaro Cartea and Sam Howison (2009)
“Option Pricing with Levy-Stable Processes” https://pdfs.semanticscholar.org/4d66/c91b136b2a38117dd16c2693679f5341c616.pdf
__call__(name, fn, obs)[source]
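
Because the auxiliary construction has a well-defined density, this reparameterizer can also be applied to likelihoods. A sketch (the model and Stable parameters are illustrative):

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import SymmetricStableReparam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        # Symmetric (skew=0) Stable observation noise.
        pyro.sample("obs", dist.Stable(1.7, 0.0, 1.0, loc), obs=data)

reparam_model = poutine.reparam(model, config={"obs": SymmetricStableReparam()})
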
class StableReparam[source]

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for arbitrary Stable random variables.

This is useful in inference of non-symmetric Stable variables because Stable.log_prob() is not implemented.

This reparameterizes a Stable random variable as the sum of two other Stable random variables, one symmetric and the other totally skewed (applying Property 2.3.a of [1]). The totally skewed variable is sampled as in LatentStableReparam, and the symmetric variable is decomposed as in SymmetricStableReparam.

[1] V. M. Zolotarev (1986)
“One-dimensional stable distributions”
__call__(name, fn, obs)[source]
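
A sketch of usage with skewed Stable observation noise, running HMC on the reparameterized model (the model, parameters, and placeholder data are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import MCMC, NUTS
from pyro.infer.reparam import StableReparam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        # Skewed Stable noise, handled via the symmetric + totally-skewed
        # decomposition described above.
        pyro.sample("obs", dist.Stable(1.5, -0.5, 1.0, loc), obs=data)

reparam_model = poutine.reparam(model, config={"obs": StableReparam()})
nuts = NUTS(reparam_model)
mcmc = MCMC(nuts, num_samples=500, warmup_steps=500)
mcmc.run(torch.randn(100))  # placeholder data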

Hidden Markov Models

class LinearHMMReparam(init=None, trans=None, obs=None)[source]

Bases: pyro.infer.reparam.reparam.Reparam

Auxiliary variable reparameterizer for LinearHMM random variables.

This defers to component reparameterizers to create auxiliary random variables conditioned on which the process becomes a GaussianHMM. If the observation_dist is a TransformedDistribution, this reorders those transforms so that the result is a TransformedDistribution of GaussianHMM.

This is useful for training the parameters of a LinearHMM distribution, whose log_prob() method is undefined. To perform inference in the presence of non-Gaussian factors such as Stable(), StudentT(), or LogNormal(), configure with StudentTReparam, StableReparam, SymmetricStableReparam, etc. component reparameterizers for init, trans, and obs. For example:

# A symmetric (skew=0) Stable initial distribution, Gaussian transition noise,
# and skewed Stable observation noise under an exponential transform.
hmm = LinearHMM(
    init_dist=Stable(1,0,1,0).expand([2]).to_event(1),
    trans_matrix=torch.eye(2),
    trans_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
    obs_matrix=torch.eye(2),
    obs_dist=TransformedDistribution(
        Stable(1.5,-0.5,1.0).expand([2]).to_event(1),
        ExpTransform()))

# Match each non-Gaussian component with a reparameterizer; the Gaussian
# transition noise needs none.
rep = LinearHMMReparam(init=SymmetricStableReparam(),
                       obs=StableReparam())

with poutine.reparam(config={"hmm": rep}):
    pyro.sample("hmm", hmm, obs=data)
Parameters:
  • init (Reparam) – Optional reparameterizer for the initial distribution.
  • trans (Reparam) – Optional reparameterizer for the transition distribution.
  • obs (Reparam) – Optional reparameterizer for the observation distribution.
__call__(name, fn, obs)[source]

Neural Transport

class NeuTraReparam(guide)[source]

Bases: pyro.infer.reparam.reparam.Reparam

Neural Transport reparameterizer [1] of multiple latent variables.

This uses a trained AutoContinuous guide to alter the geometry of a model, typically for use in MCMC. Example usage:

# Step 1. Train a guide
guide = AutoIAFNormal(model)
svi = SVI(model, guide, ...)
# ...train the guide...

# Step 2. Use trained guide in NeuTra MCMC
neutra = NeuTraReparam(guide)
model = poutine.reparam(model, config=lambda _: neutra)
nuts = NUTS(model)
# ...now use the model in HMC or NUTS...

This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single common NeuTraReparam instance, and that the model must have static structure.

[1] Hoffman, M. et al. (2019)
“NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport” https://arxiv.org/abs/1903.03704
Parameters: guide (AutoContinuous) – A trained guide.
reparam(fn=None)[source]
__call__(name, fn, obs)[source]
transform_sample(latent)[source]

Given latent samples from the warped posterior (with a possible batch dimension), return a dict of samples from the latent sites in the model.

Parameters: latent – sample from the warped posterior (possibly batched). Note that the batch dimension must not collide with plate dimensions in the model, i.e. any batch dims d < -max_plate_nesting.
Returns: a dict of samples keyed by latent sites in the model.
Return type: dict
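
A sketch of post-processing MCMC output, continuing the example above (the sample counts and the way the warped site is looked up are illustrative):

from pyro.infer import MCMC

mcmc = MCMC(nuts, num_samples=500, warmup_steps=500)
mcmc.run()  # pass any model arguments here
# NeuTra collapses the model's latents into a single warped sample site whose
# name is an implementation detail; look it up in the returned dict.
posterior = mcmc.get_samples()
latent = next(iter(posterior.values()))
# Map warped draws back to a dict keyed by the original latent site names.
samples = neutra.transform_sample(latent)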