# Source code for pyro.infer.reparam.haar

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyro.distributions.transforms.haar import HaarTransform

from .unit_jacobian import UnitJacobianReparam


class HaarReparam(UnitJacobianReparam):
    """
    Haar wavelet reparameterizer, using a
    :class:`~pyro.distributions.transforms.HaarTransform`.

    This is useful for sequential models where coupling along a time-like
    axis (e.g. a banded precision matrix) introduces long-range correlation.
    This reparameterizes to a frequency-domain representation where posterior
    covariance should be closer to diagonal, thereby improving the accuracy of
    diagonal guides in SVI and improving the effectiveness of a diagonal mass
    matrix in HMC.

    This reparameterization works only for latent variables, not likelihoods.

    :param int dim: Dimension along which to transform. Must be negative.
        This is an absolute dim counting from the right.
    :param bool flip: Whether to flip the time axis before applying the
        Haar transform. Defaults to False.
    :param bool experimental_allow_batch: EXPERIMENTAL allow coupling across a
        batch dimension. The targeted batch dimension and all batch dimensions
        to the right will be converted to event dimensions. Defaults to False.
    """

    def __init__(self, dim=-1, flip=False, *, experimental_allow_batch=False):
        """
        Build a :class:`HaarTransform` and delegate all reparameterization
        logic to :class:`UnitJacobianReparam`.

        :param int dim: Negative dimension along which to apply the transform.
        :param bool flip: Whether to flip the time axis before transforming.
        :param bool experimental_allow_batch: Forwarded unchanged to
            :class:`UnitJacobianReparam`.
        """
        # cache_size=1 enables the transform's single-entry cache of its most
        # recent input/output pair. Presumably this avoids recomputing the
        # inverse of a value the reparameterizer just produced (confirm).
        transform = HaarTransform(dim=dim, flip=flip, cache_size=1)
        # NOTE(review): "haar" is passed as the base class's ``suffix``;
        # presumably it names the auxiliary latent site. Verify in
        # UnitJacobianReparam.
        super().__init__(
            transform,
            suffix="haar",
            experimental_allow_batch=experimental_allow_batch,
        )