# Source code for pyro.distributions.transforms.haar

# Copyright Contributors to the Pyro project.

from torch.distributions import constraints
from torch.distributions.transforms import Transform

from pyro.ops.tensor_utils import haar_transform, inverse_haar_transform

class HaarTransform(Transform):
    """
    Discrete Haar transform.

    This uses :func:`~pyro.ops.tensor_utils.haar_transform` and
    :func:`~pyro.ops.tensor_utils.inverse_haar_transform` to compute
    (orthonormal) Haar and inverse Haar transforms. The jacobian is 1.
    For sequences with length T not a power of two, this implementation
    is equivalent to a block-structured Haar transform in which block
    sizes decrease by factors of one half from left to right.

    :param int dim: Dimension along which to transform. Must be negative.
        This is an absolute dim counting from the right.
    :param bool flip: Whether to flip the time axis before applying the
        Haar transform. Defaults to false.
    """

    bijective = True

    def __init__(self, dim=-1, flip=False, cache_size=0):
        assert isinstance(dim, int) and dim < 0
        self.dim = dim
        self.flip = flip
        super().__init__(cache_size=cache_size)

    # domain/codomain are dependent properties rather than fixed
    # ``constraints.real_vector`` class attributes so that their event_dim
    # correctly tracks ``dim`` (real_vector was wrong for dim < -1).  This
    # also avoids assigning ``self.event_dim`` in __init__, which is a
    # read-only property on modern torch Transforms and would raise.
    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        return constraints.independent(constraints.real, -self.dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(constraints.real, -self.dim)

    def __eq__(self, other):
        # Two transforms are interchangeable iff they act on the same dim
        # with the same time-axis orientation.
        return (
            type(self) == type(other)
            and self.dim == other.dim
            and self.flip == other.flip
        )

    def _call(self, x):
        # Move the target dim to the rightmost position, transform, and
        # move it back, so haar_transform only ever sees dim=-1.
        dim = self.dim
        if dim != -1:
            x = x.transpose(dim, -1)
        if self.flip:
            x = x.flip(-1)
        y = haar_transform(x)
        if dim != -1:
            y = y.transpose(dim, -1)
        return y

    def _inverse(self, y):
        # Mirror of _call: inverse-transform along -1, undoing the
        # optional flip afterwards (flip is applied before the forward
        # transform, so it is undone after the inverse).
        dim = self.dim
        if dim != -1:
            y = y.transpose(dim, -1)
        x = inverse_haar_transform(y)
        if self.flip:
            x = x.flip(-1)
        if dim != -1:
            x = x.transpose(dim, -1)
        return x

    def log_abs_det_jacobian(self, x, y):
        # The Haar transform is orthonormal, so |det J| = 1 and
        # log|det J| = 0, one value per batch element (event dims stripped).
        return x.new_zeros(x.shape[: self.dim])

    def with_cache(self, cache_size=1):
        """Return an equivalent transform with the requested cache size."""
        if self._cache_size == cache_size:
            return self
        return HaarTransform(self.dim, flip=self.flip, cache_size=cache_size)