# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import warnings
from functools import partial
import torch
import torch.nn as nn
from torch.distributions import Transform, constraints
from pyro.distributions.conditional import ConditionalTransformModule
from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import copy_docs_from
from pyro.nn import DenseNN
class ConditionedHouseholder(Transform):
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
event_dim = 1
volume_preserving = True
def __init__(self, u_unnormed=None):
super().__init__(cache_size=1)
self.u_unnormed = u_unnormed
# Construct normalized vectors for Householder transform
def u(self):
u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed
norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)
return torch.div(u_unnormed, norm)
def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x=>y; in the prototypical context of a
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
y = x
u = self.u()
for idx in range(u.size(-2)):
projection = (u[..., idx, :] * y).sum(dim=-1, keepdim=True) * u[..., idx, :]
y = y - 2. * projection
return y
def _inverse(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor
Inverts y => x. The Householder transformation, H, is "involutory," i.e.
H^2 = I. If you reflect a point around a plane, then the same operation will
reflect it back
"""
x = y
u = self.u()
for jdx in reversed(range(u.size(-2))):
# NOTE: Need to apply transforms in reverse order from forward operation!
projection = (u[..., jdx, :] * x).sum(dim=-1, keepdim=True) * u[..., jdx, :]
x = x - 2. * projection
return x
def log_abs_det_jacobian(self, x, y):
r"""
Calculates the elementwise determinant of the log jacobian. Householder flow
is measure preserving, so :math:`\log(|detJ|) = 0`
"""
return torch.zeros(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device)
[docs]@copy_docs_from(TransformModule)
class Householder(ConditionedHouseholder, TransformModule):
r"""
Represents multiple applications of the Householder bijective transformation. A
single Householder transformation takes the form,
:math:`\mathbf{y} = (I - 2*\frac{\mathbf{u}\mathbf{u}^T}{||\mathbf{u}||^2})\mathbf{x}`
where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
and the learnable parameters are :math:`\mathbf{u}\in\mathbb{R}^D` for input
dimension :math:`D`.
The transformation represents the reflection of :math:`\mathbf{x}` through the
plane passing through the origin with normal :math:`\mathbf{u}`.
:math:`D` applications of this transformation are able to transform standard
i.i.d. standard Gaussian noise into a Gaussian variable with an arbitrary
covariance matrix. With :math:`K<D` transformations, one is able to approximate
a full-rank Gaussian distribution using a linear transformation of rank
:math:`K`.
Together with :class:`~pyro.distributions.TransformedDistribution` this provides
a way to create richer variational approximations.
Example usage:
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> transform = Householder(10, count_transforms=5)
>>> pyro.module("my_transform", p) # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample() # doctest: +SKIP
:param input_dim: the dimension of the input (and output) variable.
:type input_dim: int
:param count_transforms: number of applications of Householder transformation to
apply.
:type count_transforms: int
References:
[1] Jakub M. Tomczak, Max Welling. Improving Variational Auto-Encoders using
Householder Flow. [arXiv:1611.09630]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
event_dim = 1
volume_preserving = True
def __init__(self, input_dim, count_transforms=1):
super().__init__()
self.input_dim = input_dim
if count_transforms < 1:
raise ValueError('Number of Householder transforms, {}, is less than 1!'.format(count_transforms))
elif count_transforms > input_dim:
warnings.warn(
"Number of Householder transforms, {}, is greater than input dimension {}, which is an \
over-parametrization!".format(count_transforms, input_dim))
self.u_unnormed = nn.Parameter(torch.Tensor(count_transforms, input_dim))
self.reset_parameters()
[docs] def reset_parameters(self):
stdv = 1. / math.sqrt(self.u_unnormed.size(-1))
self.u_unnormed.data.uniform_(-stdv, stdv)
[docs]@copy_docs_from(ConditionalTransformModule)
class ConditionalHouseholder(ConditionalTransformModule):
r"""
Represents multiple applications of the Householder bijective transformation
conditioning on an additional context. A single Householder transformation takes
the form,
:math:`\mathbf{y} = (I - 2*\frac{\mathbf{u}\mathbf{u}^T}{||\mathbf{u}||^2})\mathbf{x}`
where :math:`\mathbf{x}` are the inputs with dimension :math:`D`,
:math:`\mathbf{y}` are the outputs, and :math:`\mathbf{u}\in\mathbb{R}^D`
is the output of a function, e.g. a NN, with input :math:`z\in\mathbb{R}^{M}`
representing the context variable to condition on.
The transformation represents the reflection of :math:`\mathbf{x}` through the
plane passing through the origin with normal :math:`\mathbf{u}`.
:math:`D` applications of this transformation are able to transform standard
i.i.d. standard Gaussian noise into a Gaussian variable with an arbitrary
covariance matrix. With :math:`K<D` transformations, one is able to approximate
a full-rank Gaussian distribution using a linear transformation of rank
:math:`K`.
Together with :class:`~pyro.distributions.ConditionalTransformedDistribution`
this provides a way to create richer variational approximations.
Example usage:
>>> from pyro.nn.dense_nn import DenseNN
>>> input_dim = 10
>>> context_dim = 5
>>> batch_size = 3
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> param_dims = [input_dim]
>>> hypernet = DenseNN(context_dim, [50, 50], param_dims)
>>> transform = ConditionalHouseholder(input_dim, hypernet)
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
... [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP
:param input_dim: the dimension of the input (and output) variable.
:type input_dim: int
:param nn: a function inputting the context variable and outputting a triplet of
real-valued parameters of dimensions :math:`(1, D, D)`.
:type nn: callable
:param count_transforms: number of applications of Householder transformation to
apply.
:type count_transforms: int
References:
[1] Jakub M. Tomczak, Max Welling. Improving Variational Auto-Encoders using
Householder Flow. [arXiv:1611.09630]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
event_dim = 1
def __init__(self, input_dim, nn, count_transforms=1):
super().__init__()
self.nn = nn
self.input_dim = input_dim
if count_transforms < 1:
raise ValueError('Number of Householder transforms, {}, is less than 1!'.format(count_transforms))
elif count_transforms > input_dim:
warnings.warn(
"Number of Householder transforms, {}, is greater than input dimension {}, which is an \
over-parametrization!".format(count_transforms, input_dim))
self.count_transforms = count_transforms
def _u_unnormed(self, context):
# u_unnormed ~ (count_transforms, input_dim)
# Hence, input_dim must divide
u_unnormed = self.nn(context)
if self.count_transforms == 1:
u_unnormed = u_unnormed.unsqueeze(-2)
else:
u_unnormed = torch.stack(u_unnormed, dim=-2)
return u_unnormed
[docs] def condition(self, context):
u_unnormed = partial(self._u_unnormed, context)
return ConditionedHouseholder(u_unnormed)
[docs]def householder(input_dim, count_transforms=None):
"""
A helper function to create a
:class:`~pyro.distributions.transforms.Householder` object for consistency with
other helpers.
:param input_dim: Dimension of input variable
:type input_dim: int
:param count_transforms: number of applications of Householder transformation to
apply.
:type count_transforms: int
"""
if count_transforms is None:
count_transforms = input_dim // 2 + 1
return Householder(input_dim, count_transforms=count_transforms)
[docs]def conditional_householder(input_dim, context_dim, hidden_dims=None, count_transforms=1):
"""
A helper function to create a
:class:`~pyro.distributions.transforms.ConditionalHouseholder` object that takes
care of constructing a dense network with the correct input/output dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param context_dim: Dimension of context variable
:type context_dim: int
:param hidden_dims: The desired hidden dimensions of the dense network. Defaults
to using [input_dim * 10, input_dim * 10]
:type hidden_dims: list[int]
"""
if hidden_dims is None:
hidden_dims = [input_dim * 10, input_dim * 10]
nn = DenseNN(context_dim, hidden_dims, param_dims=[input_dim] * count_transforms)
return ConditionalHouseholder(input_dim, nn, count_transforms)