# Source code for pyro.distributions.transforms.householder

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math
import warnings
from functools import partial

import torch
import torch.nn as nn
from torch.distributions import Transform, constraints

from pyro.nn import DenseNN

from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from

class ConditionedHouseholder(Transform):
    """
    A chain of Householder reflections applied to the last dimension of the input.

    :param u_unnormed: the un-normalized reflection vectors, stacked along the
        second-to-last dimension (one row per reflection), or a zero-argument
        callable returning such a tensor. The callable form is evaluated each
        time the normalized vectors are needed.
    :type u_unnormed: torch.Tensor or callable
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True
    volume_preserving = True

    def __init__(self, u_unnormed=None):
        super().__init__(cache_size=1)
        self.u_unnormed = u_unnormed

    # Construct normalized vectors for Householder transform
    def u(self):
        """
        Returns the reflection vectors scaled to unit L2 norm along the last
        dimension.
        """
        u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed
        norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)
        return torch.div(u_unnormed, norm)

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
        the base distribution (or the output of a previous transform)
        """

        y = x
        u = self.u()
        # Apply the reflections one after another, in row order.
        for idx in range(u.size(-2)):
            projection = (u[..., idx, :] * y).sum(dim=-1, keepdim=True) * u[..., idx, :]
            y = y - 2.0 * projection
        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. The Householder transformation, H, is "involutory," i.e.
        H^2 = I. If you reflect a point around a plane, then the same operation will
        reflect it back
        """

        x = y
        u = self.u()
        for jdx in reversed(range(u.size(-2))):
            # NOTE: Need to apply transforms in reverse order from forward operation!
            projection = (u[..., jdx, :] * x).sum(dim=-1, keepdim=True) * u[..., jdx, :]
            x = x - 2.0 * projection
        return x

    def log_abs_det_jacobian(self, x, y):
        r"""
        Calculates the elementwise determinant of the log jacobian. Householder flow
        is measure preserving, so :math:`\log(|detJ|) = 0`
        """

        # One zero per batch element: the last (event) dimension of x is dropped.
        return torch.zeros(
            x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device
        )

@copy_docs_from(TransformModule)
class Householder(ConditionedHouseholder, TransformModule):
    r"""
    Represents multiple applications of the Householder bijective transformation. A
    single Householder transformation takes the form,

        :math:`\mathbf{y} = (I - 2*\frac{\mathbf{u}\mathbf{u}^T}{||\mathbf{u}||^2})\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    and the learnable parameters are :math:`\mathbf{u}\in\mathbb{R}^D` for input
    dimension :math:`D`.

    The transformation represents the reflection of :math:`\mathbf{x}` through the
    plane passing through the origin with normal :math:`\mathbf{u}`.

    :math:`D` applications of this transformation are able to transform
    i.i.d. standard Gaussian noise into a Gaussian variable with an arbitrary
    covariance matrix. With :math:`K<D` transformations, one is able to approximate
    a full-rank Gaussian distribution using a linear transformation of rank
    :math:`K`.

    Together with :class:`~pyro.distributions.TransformedDistribution` this provides
    a way to create richer variational approximations.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> transform = Householder(10, count_transforms=5)
    >>> pyro.module("my_transform", transform) # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param count_transforms: number of applications of Householder transformation to
        apply.
    :type count_transforms: int
    :raises ValueError: if ``count_transforms`` is less than 1.

    References:

    [1] Jakub M. Tomczak, Max Welling. Improving Variational Auto-Encoders using
    Householder Flow. [arXiv:1611.09630]

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True
    volume_preserving = True

    def __init__(self, input_dim, count_transforms=1):
        super().__init__()

        self.input_dim = input_dim
        if count_transforms < 1:
            raise ValueError(
                "Number of Householder transforms, {}, is less than 1!".format(
                    count_transforms
                )
            )
        elif count_transforms > input_dim:
            # Adjacent literals (rather than a backslash continuation inside the
            # string) keep source indentation out of the emitted message.
            warnings.warn(
                "Number of Householder transforms, {}, is greater than input "
                "dimension {}, which is an "
                "over-parametrization!".format(count_transforms, input_dim)
            )
        # One learnable, un-normalized reflection vector per row; the values are
        # filled in by reset_parameters() immediately below.
        self.u_unnormed = nn.Parameter(torch.Tensor(count_transforms, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initializes the reflection vectors in place, uniformly on
        ``[-1/sqrt(input_dim), 1/sqrt(input_dim)]``.
        """
        stdv = 1.0 / math.sqrt(self.u_unnormed.size(-1))
        self.u_unnormed.data.uniform_(-stdv, stdv)

@copy_docs_from(ConditionalTransformModule)
class ConditionalHouseholder(ConditionalTransformModule):
    r"""
    Represents multiple applications of the Householder bijective transformation
    conditioning on an additional context. A single Householder transformation takes
    the form,

        :math:`\mathbf{y} = (I - 2*\frac{\mathbf{u}\mathbf{u}^T}{||\mathbf{u}||^2})\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs with dimension :math:`D`,
    :math:`\mathbf{y}` are the outputs, and :math:`\mathbf{u}\in\mathbb{R}^D`
    is the output of a function, e.g. a NN, with input :math:`z\in\mathbb{R}^{M}`
    representing the context variable to condition on.

    The transformation represents the reflection of :math:`\mathbf{x}` through the
    plane passing through the origin with normal :math:`\mathbf{u}`.

    :math:`D` applications of this transformation are able to transform
    i.i.d. standard Gaussian noise into a Gaussian variable with an arbitrary
    covariance matrix. With :math:`K<D` transformations, one is able to approximate
    a full-rank Gaussian distribution using a linear transformation of rank
    :math:`K`.

    Together with :class:`~pyro.distributions.ConditionalTransformedDistribution`
    this provides a way to create richer variational approximations.

    Example usage:

    >>> from pyro.nn.dense_nn import DenseNN
    >>> input_dim = 10
    >>> context_dim = 5
    >>> batch_size = 3
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> param_dims = [input_dim]
    >>> hypernet = DenseNN(context_dim, [50, 50], param_dims)
    >>> transform = ConditionalHouseholder(input_dim, hypernet)
    >>> z = torch.rand(batch_size, context_dim)
    >>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
    ... [transform]).condition(z)
    >>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param nn: a function inputting the context variable and outputting the
        un-normalized reflection vectors: a single tensor when
        ``count_transforms == 1``, otherwise a sequence of ``count_transforms``
        tensors that are stacked along a new second-to-last dimension.
    :type nn: callable
    :param count_transforms: number of applications of Householder transformation to
        apply.
    :type count_transforms: int
    :raises ValueError: if ``count_transforms`` is less than 1.

    References:

    [1] Jakub M. Tomczak, Max Welling. Improving Variational Auto-Encoders using
    Householder Flow. [arXiv:1611.09630]

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, input_dim, nn, count_transforms=1):
        super().__init__()
        self.nn = nn
        self.input_dim = input_dim
        if count_transforms < 1:
            raise ValueError(
                "Number of Householder transforms, {}, is less than 1!".format(
                    count_transforms
                )
            )
        elif count_transforms > input_dim:
            # Adjacent literals (rather than a backslash continuation inside the
            # string) keep source indentation out of the emitted message.
            warnings.warn(
                "Number of Householder transforms, {}, is greater than input "
                "dimension {}, which is an "
                "over-parametrization!".format(count_transforms, input_dim)
            )
        self.count_transforms = count_transforms

    def _u_unnormed(self, context):
        """
        Evaluates the network on ``context`` and arranges its output so that the
        reflection vectors are indexed along the second-to-last dimension.
        """
        # Target layout: (..., count_transforms, input_dim).
        u_unnormed = self.nn(context)
        if self.count_transforms == 1:
            # A single output tensor: add the reflection axis explicitly.
            u_unnormed = u_unnormed.unsqueeze(-2)
        else:
            # A sequence of output tensors: stack them along the reflection axis.
            u_unnormed = torch.stack(u_unnormed, dim=-2)
        return u_unnormed

    def condition(self, context):
        """
        Returns a :class:`ConditionedHouseholder` whose reflection vectors are
        computed from ``context``. The network call is wrapped in a partial, so it
        is deferred until the returned transform requests its vectors.
        """
        u_unnormed = partial(self._u_unnormed, context)
        return ConditionedHouseholder(u_unnormed)

def householder(input_dim, count_transforms=None):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.Householder` object for consistency with
    other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param count_transforms: number of applications of Householder transformation to
        apply. Defaults to ``input_dim // 2 + 1`` when left as ``None``.
    :type count_transforms: int

    """

    if count_transforms is None:
        count_transforms = input_dim // 2 + 1
    return Householder(input_dim, count_transforms=count_transforms)

def conditional_householder(
    input_dim, context_dim, hidden_dims=None, count_transforms=1
):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.ConditionalHouseholder` object that takes
    care of constructing a dense network with the correct input/output dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param context_dim: Dimension of context variable
    :type context_dim: int
    :param hidden_dims: The desired hidden dimensions of the dense network. Defaults
        to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param count_transforms: number of applications of Householder transformation to
        apply. The dense network is given one output of size ``input_dim`` per
        application. Defaults to 1.
    :type count_transforms: int

    """

    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]
    # One network output head of size input_dim per reflection vector.
    nn = DenseNN(context_dim, hidden_dims, param_dims=[input_dim] * count_transforms)
    return ConditionalHouseholder(input_dim, nn, count_transforms)