# Source code for pyro.distributions.transforms.haar
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from torch.distributions.transforms import Transform
from pyro.ops.tensor_utils import haar_transform, inverse_haar_transform
from .. import constraints
class HaarTransform(Transform):
    """
    Discrete Haar transform.

    This uses :func:`~pyro.ops.tensor_utils.haar_transform` and
    :func:`~pyro.ops.tensor_utils.inverse_haar_transform` to compute
    (orthonormal) Haar and inverse Haar transforms. The jacobian is 1.
    For sequences with length `T` not a power of two, this implementation
    is equivalent to a block-structured Haar transform in which block
    sizes decrease by factors of one half from left to right.

    :param int dim: Dimension along which to transform. Must be negative.
        This is an absolute dim counting from the right.
    :param bool flip: Whether to flip the time axis before applying the
        Haar transform. Defaults to ``False``.
    :param int cache_size: Cache size forwarded to the base
        :class:`~torch.distributions.transforms.Transform` (torch supports
        only 0 and 1). Defaults to 0 (no caching).
    """

    bijective = True

    def __init__(self, dim=-1, flip=False, cache_size=0):
        assert isinstance(dim, int) and dim < 0
        self.dim = dim
        self.flip = flip
        # The base class initializes the caching state (e.g. ``_cache_size``,
        # read by ``with_cache`` below); without this call the transform is
        # only partially constructed.
        super().__init__(cache_size=cache_size)

    def __hash__(self):
        # Hash exactly the fields compared in __eq__ so equal transforms
        # always hash equally.
        return hash((type(self), self.dim, self.flip))

    def __eq__(self, other):
        return (
            type(self) == type(other)
            and self.dim == other.dim
            and self.flip == other.flip
        )

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        # The rightmost ``-dim`` dims form the event, so event_dim == -dim.
        return constraints.independent(constraints.real, -self.dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(constraints.real, -self.dim)

    def _call(self, x):
        """Apply the (optionally flipped) Haar transform along ``self.dim``."""
        dim = self.dim
        # haar_transform operates on the last axis, so move ``dim`` there.
        if dim != -1:
            x = x.transpose(dim, -1)
        if self.flip:
            x = x.flip(-1)
        y = haar_transform(x)
        if dim != -1:
            y = y.transpose(dim, -1)
        return y

    def _inverse(self, y):
        """Invert :meth:`_call`: inverse Haar, then undo the flip."""
        dim = self.dim
        if dim != -1:
            y = y.transpose(dim, -1)
        x = inverse_haar_transform(y)
        # The flip was applied before the forward transform, so undo it after
        # the inverse transform.
        if self.flip:
            x = x.flip(-1)
        if dim != -1:
            x = x.transpose(dim, -1)
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Return ``log|det J|``, which is identically zero because the transform
        is orthonormal (jacobian determinant of absolute value 1).

        The result has the batch shape of ``x``, i.e. ``x.shape`` with the
        rightmost ``event_dim`` dims removed.
        """
        return x.new_zeros(x.shape[: x.dim() - self.event_dim])

    def with_cache(self, cache_size=1):
        """Return a copy of this transform with the given ``cache_size``."""
        if self._cache_size == cache_size:
            return self
        return HaarTransform(self.dim, flip=self.flip, cache_size=cache_size)

    def forward_shape(self, shape):
        """Shape is preserved; raises ``ValueError`` if rank < event_dim."""
        if len(shape) < self.event_dim:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        """Shape is preserved; raises ``ValueError`` if rank < event_dim."""
        if len(shape) < self.event_dim:
            raise ValueError("Too few dimensions on input")
        return shape