# Source code for pyro.distributions.transforms.haar

```python
# Copyright Contributors to the Pyro project.

from torch.distributions.transforms import Transform

from pyro.ops.tensor_utils import haar_transform, inverse_haar_transform

from .. import constraints

class HaarTransform(Transform):
    """
    Discrete Haar transform.

    This uses :func:`~pyro.ops.tensor_utils.haar_transform` and
    :func:`~pyro.ops.tensor_utils.inverse_haar_transform` to compute
    (orthonormal) Haar and inverse Haar transforms. The jacobian is 1.
    For sequences with length `T` not a power of two, this implementation
    is equivalent to a block-structured Haar transform in which block
    sizes decrease by factors of one half from left to right.

    :param int dim: Dimension along which to transform. Must be negative.
        This is an absolute dim counting from the right.
    :param bool flip: Whether to flip the time axis before applying the
        Haar transform. Defaults to false.
    :param int cache_size: Cache size forwarded to the base
        :class:`~torch.distributions.transforms.Transform`. Defaults to 0.
    """

    bijective = True

    def __init__(self, dim=-1, flip=False, cache_size=0):
        # Kept as an assert (not a raise) so the exception type seen by
        # existing callers is unchanged.
        assert isinstance(dim, int) and dim < 0, "dim must be a negative int"
        self.dim = dim
        self.flip = flip
        super().__init__(cache_size=cache_size)

    def __hash__(self):
        # Hash exactly the fields that __eq__ compares, so equal transforms
        # hash equally without going through the domain/codomain constraints.
        return hash((type(self), self.dim, self.flip))

    def __eq__(self, other):
        """Two transforms are equal iff they share type, ``dim`` and ``flip``."""
        return (
            type(self) is type(other)
            and self.dim == other.dim
            and self.flip == other.flip
        )

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        """Real-valued inputs whose rightmost ``-dim`` dims form the event."""
        return constraints.independent(constraints.real, -self.dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        """Real-valued outputs whose rightmost ``-dim`` dims form the event."""
        return constraints.independent(constraints.real, -self.dim)

    def _call(self, x):
        """
        Forward pass: move ``dim`` to the rightmost position, optionally
        flip that axis, apply ``haar_transform``, then move the axis back.
        """
        dim = self.dim
        if dim != -1:
            x = x.transpose(dim, -1)
        if self.flip:
            x = x.flip(-1)
        y = haar_transform(x)
        if dim != -1:
            y = y.transpose(dim, -1)
        return y

    def _inverse(self, y):
        """
        Inverse pass: undo :meth:`_call` step by step in reverse order
        (inverse transform first, then un-flip, then restore the axis).
        """
        dim = self.dim
        if dim != -1:
            y = y.transpose(dim, -1)
        x = inverse_haar_transform(y)
        if self.flip:
            x = x.flip(-1)
        if dim != -1:
            x = x.transpose(dim, -1)
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Return zeros shaped like ``x`` with the rightmost ``-dim`` dims
        removed; ``y`` is unused.
        """
        return x.new_zeros(x.shape[: self.dim])

    def with_cache(self, cache_size=1):
        """
        Return a transform with the requested cache size: ``self`` when the
        size already matches, otherwise a new :class:`HaarTransform` with the
        same ``dim`` and ``flip``.
        """
        if self._cache_size == cache_size:
            return self
        return HaarTransform(self.dim, flip=self.flip, cache_size=cache_size)

    def forward_shape(self, shape):
        """
        Return ``shape`` unchanged.

        :raises ValueError: if ``shape`` has fewer than ``event_dim`` dims.
        """
        if len(shape) < self.event_dim:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        """
        Return ``shape`` unchanged.

        :raises ValueError: if ``shape`` has fewer than ``event_dim`` dims.
        """
        if len(shape) < self.event_dim:
            raise ValueError("Too few dimensions on input")
        return shape
```