# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
from functools import partial
import torch
import torch.nn as nn
from torch.distributions import Transform, constraints
from pyro.nn import DenseNN
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
@copy_docs_from(Transform)
class ConditionedMatrixExponential(Transform):
    r"""
    Matrix exponential bijection :math:`\mathbf{y} = \exp(M)\mathbf{x}` for a
    given weight matrix :math:`M`.

    This is the shared core of :class:`MatrixExponential` (where ``weights`` is a
    learnable ``nn.Parameter``) and :class:`ConditionalMatrixExponential` (where
    ``weights`` is a zero-argument callable evaluating a hypernetwork on a fixed
    context).

    :param weights: the matrix :math:`M`, or a zero-argument callable returning
        it. It is resolved anew on every forward, inverse and log-det call. It may
        carry leading batch dimensions, which are broadcast against the input.
    :type weights: torch.Tensor or callable
    :param iterations: number of power-series terms added after the identity term
        when approximating :math:`\exp(M)\mathbf{x}`. Must be positive.
    :type iterations: int
    :param normalization: only ``'none'`` is supported; ``'weight'`` and
        ``'spectral'`` raise :class:`NotImplementedError`.
    :type normalization: str
    :param bound: stored for the (unimplemented) normalization schemes; currently
        not read anywhere in this class.
    :type bound: float
    :raises ValueError: if ``iterations`` is not positive or ``normalization`` is
        unrecognised.
    :raises NotImplementedError: if ``normalization`` is ``'weight'`` or
        ``'spectral'``.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, weights=None, iterations=8, normalization="none", bound=None):
        super().__init__(cache_size=1)
        # Raise instead of ``assert`` so the check survives ``python -O`` and is
        # consistent with the ValueError raised for a bad ``normalization``.
        if iterations <= 0:
            raise ValueError(
                "iterations must be a positive integer, got {}".format(iterations)
            )
        self.weights = weights
        self.iterations = iterations
        self.normalization = normalization
        self.bound = bound

        # Weight and spectral normalization are not implemented yet. This doesn't
        # affect the validity of the bijection, although applying these norms
        # should improve the numerical conditioning of the approximation.
        if normalization == "weight" or normalization == "spectral":
            raise NotImplementedError("Normalization is currently not implemented.")
        elif normalization != "none":
            raise ValueError("Unknown normalization method: {}".format(normalization))

    def _resolve_weights(self):
        """
        Return the weight matrix, calling ``self.weights`` first when it is a
        callable (as in the conditional transform).
        """
        return self.weights() if callable(self.weights) else self.weights

    def _exp(self, x, M):
        """
        Approximate the product of the matrix exponential of M with x by the
        truncated power series ``sum_{k=0}^{iterations} M^k x / k!``.

        Each term is derived from the previous one with a single matrix-vector
        product, so matrix powers of M are never formed explicitly. Leading batch
        dimensions of ``x`` and ``M`` broadcast through ``torch.matmul``.
        """
        power_term = x.unsqueeze(-1)
        y = x.unsqueeze(-1)
        for idx in range(self.iterations):
            # M^{k+1} x / (k+1)!  ==  M @ (M^k x / k!) / (k+1)
            power_term = torch.matmul(M, power_term) / (idx + 1)
            y = y + power_term
        return y.squeeze(-1)

    def _trace(self, M):
        """
        Calculate the trace of a matrix, broadcasting over leading batch
        dimensions (which `torch.trace` does not support).

        Broadcasting is necessary for the conditional version of the transform,
        where `self.weights` may have batch dimensions corresponding to the batch
        dimensions of the context variable that was conditioned upon.
        """
        return M.diagonal(dim1=-2, dim2=-1).sum(-1)

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x => y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
        the base distribution (or the output of a previous transform)
        """
        return self._exp(x, self._resolve_weights())

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x using exp(M)^{-1} = exp(-M), approximated with the same
        truncated power series as the forward direction.
        """
        return self._exp(y, -self._resolve_weights())

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log absolute determinant of the Jacobian.

        Since det(exp(M)) = exp(tr(M)) > 0, this is exactly tr(M). This is the
        exact value for exp(M), not for the truncated-series approximation used
        in the forward and inverse passes. The result has shape
        ``M.shape[:-2]`` and is not expanded to the batch shape of ``x``; callers
        are presumably relied upon to broadcast it.
        """
        return self._trace(self._resolve_weights())
@copy_docs_from(ConditionedMatrixExponential)
class MatrixExponential(ConditionedMatrixExponential, TransformModule):
    r"""
    A dense matrix exponential bijective transform (Hoogeboom et al., 2020) with
    equation,

        :math:`\mathbf{y} = \exp(M)\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    :math:`\exp(\cdot)` represents the matrix exponential, and the learnable
    parameters are :math:`M\in\mathbb{R}^{D\times D}` for input dimension
    :math:`D`. In general, :math:`M` is not required to be invertible.

    Due to the favourable mathematical properties of the matrix exponential, the
    transform has an exact inverse, :math:`\exp(-M)`, and its log absolute
    Jacobian determinant is simply :math:`\mathrm{tr}(M)`, which scales in
    time-complexity as :math:`O(D)`. Both the forward and reverse operations are
    approximated with a truncated power series. The `normalization` keyword
    argument is intended to restrict the norm of :math:`M` for numerical
    stability, but the `'weight'` and `'spectral'` options are not implemented
    yet and raise :class:`NotImplementedError`.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> transform = MatrixExponential(10)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param iterations: the number of power-series terms, after the identity term,
        used to approximate matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). Only `'none'`
        is currently implemented.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential. It has no effect while those normalizations are unimplemented.
    :type bound: float

    References:

    [1] Emiel Hoogeboom, Victor Garcia Satorras, Jakub M. Tomczak, Max Welling. The
    Convolution Exponential and Generalized Sylvester Flows. [arXiv:2006.01910]

    [2] Tim Salimans, Diederik P. Kingma. Weight Normalization: A Simple
    Reparameterization to Accelerate Training of Deep Neural Networks.
    [arXiv:1602.07868]

    [3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. Spectral
    Normalization for Generative Adversarial Networks. ICLR 2018.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, input_dim, iterations=8, normalization="none", bound=None):
        super().__init__(
            iterations=iterations, normalization=normalization, bound=bound
        )
        # Allocated uninitialised; reset_parameters() fills it immediately below.
        self.weights = nn.Parameter(torch.empty(input_dim, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialise the weight matrix in place with entries drawn uniformly from
        ``[-1/sqrt(D), 1/sqrt(D)]``, where ``D`` is the input dimension.
        """
        stdv = 1.0 / math.sqrt(self.weights.size(0))
        # In-place init of a leaf Parameter must not be recorded by autograd.
        with torch.no_grad():
            self.weights.uniform_(-stdv, stdv)
@copy_docs_from(ConditionalTransformModule)
class ConditionalMatrixExponential(ConditionalTransformModule):
    r"""
    A dense matrix exponential bijective transform (Hoogeboom et al., 2020) that
    conditions on an additional context variable with equation,

        :math:`\mathbf{y} = \exp(M)\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    :math:`\exp(\cdot)` represents the matrix exponential, and
    :math:`M\in\mathbb{R}^{D\times D}` is the output of a neural network
    conditioning on a context variable :math:`\mathbf{z}` for input dimension
    :math:`D`. In general, :math:`M` is not required to be invertible.

    Due to the favourable mathematical properties of the matrix exponential, the
    transform has an exact inverse, :math:`\exp(-M)`, and its log absolute
    Jacobian determinant is simply :math:`\mathrm{tr}(M)`, which scales in
    time-complexity as :math:`O(D)`. Both the forward and reverse operations are
    approximated with a truncated power series. The `normalization` keyword
    argument is intended to restrict the norm of :math:`M` for numerical
    stability, but the `'weight'` and `'spectral'` options are not implemented
    yet and raise :class:`NotImplementedError` when :meth:`condition` is called.

    Example usage:

    >>> from pyro.nn.dense_nn import DenseNN
    >>> input_dim = 10
    >>> context_dim = 5
    >>> batch_size = 3
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> param_dims = [input_dim*input_dim]
    >>> hypernet = DenseNN(context_dim, [50, 50], param_dims)
    >>> transform = ConditionalMatrixExponential(input_dim, hypernet)
    >>> z = torch.rand(batch_size, context_dim)
    >>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
    ... [transform]).condition(z)
    >>> flow_dist.sample(sample_shape=torch.Size([batch_size]))  # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param nn: a hypernetwork mapping the context to a flat output whose last
        dimension has size ``input_dim * input_dim``.
    :type nn: callable
    :param iterations: the number of power-series terms, after the identity term,
        used to approximate matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). Only `'none'`
        is currently implemented.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential. It has no effect while those normalizations are unimplemented.
    :type bound: float

    References:

    [1] Emiel Hoogeboom, Victor Garcia Satorras, Jakub M. Tomczak, Max Welling. The
    Convolution Exponential and Generalized Sylvester Flows. [arXiv:2006.01910]

    [2] Tim Salimans, Diederik P. Kingma. Weight Normalization: A Simple
    Reparameterization to Accelerate Training of Deep Neural Networks.
    [arXiv:1602.07868]

    [3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. Spectral
    Normalization for Generative Adversarial Networks. ICLR 2018.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, input_dim, nn, iterations=8, normalization="none", bound=None):
        super().__init__()
        self.input_dim = input_dim
        self.nn = nn
        self.iterations = iterations
        self.normalization = normalization
        self.bound = bound

    def _params(self, context):
        # Raw hypernetwork output for ``context`` (not used by ``condition``).
        return self.nn(context)

    def condition(self, context):
        """
        Bind a context value to this transform.

        :param context: the context variable fed to the hypernetwork ``self.nn``.
        :type context: torch.Tensor
        :returns: a transform whose weight matrix is ``self.nn(context)`` with its
            last dimension reshaped to ``(input_dim, input_dim)``. The hypernetwork
            is re-evaluated each time the weights are needed.
        :rtype: ConditionedMatrixExponential
        """
        # This hack could be fixed by having a conditioning network that outputs a
        # more general shape: for now the D*D entries of M arrive flattened in the
        # last dimension and are unflattened here.
        cond_nn = partial(self.nn, context)

        def weights():
            w = cond_nn()
            # reshape (rather than view) also accepts non-contiguous outputs.
            return w.reshape(w.shape[:-1] + (self.input_dim, self.input_dim))

        return ConditionedMatrixExponential(
            weights,
            iterations=self.iterations,
            normalization=self.normalization,
            bound=self.bound,
        )
def matrix_exponential(input_dim, iterations=8, normalization="none", bound=None):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.MatrixExponential` object for consistency
    with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param iterations: the number of power-series terms, after the identity term,
        used to approximate matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). Only `'none'`
        is currently implemented.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float
    :returns: a newly constructed transform with arguments passed through unchanged.
    :rtype: MatrixExponential
    """
    return MatrixExponential(
        input_dim, iterations=iterations, normalization=normalization, bound=bound
    )
def conditional_matrix_exponential(
    input_dim,
    context_dim,
    hidden_dims=None,
    iterations=8,
    normalization="none",
    bound=None,
):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.ConditionalMatrixExponential` object for
    consistency with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param context_dim: Dimension of context variable
    :type context_dim: int
    :param hidden_dims: The desired hidden dimensions of the dense network. Defaults
        to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param iterations: the number of power-series terms, after the identity term,
        used to approximate matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). Only `'none'`
        is currently implemented.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float
    :returns: a conditional transform driven by a freshly built ``DenseNN``
        hypernetwork.
    :rtype: ConditionalMatrixExponential
    """
    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]
    # The hypernetwork emits the input_dim * input_dim entries of M as one flat
    # output; the transform unflattens it. Named ``hypernet`` (not ``nn``) so it
    # does not shadow the module-level ``torch.nn`` import.
    hypernet = DenseNN(context_dim, hidden_dims, param_dims=[input_dim * input_dim])
    return ConditionalMatrixExponential(
        input_dim,
        hypernet,
        iterations=iterations,
        normalization=normalization,
        bound=bound,
    )