# Source code for pyro.distributions.transforms.matrix_exponential

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math
from functools import partial

import torch
import torch.nn as nn
from torch.distributions import Transform, constraints

from pyro.nn import DenseNN

from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from

@copy_docs_from(Transform)
class ConditionedMatrixExponential(Transform):
    """
    Bijection ``y = exp(M) x`` for a given weight matrix ``M``, where the
    matrix exponential is approximated by a truncated power series.

    :param weights: the matrix ``M``, or a zero-argument callable returning it.
        A callable is invoked each time ``_call``, ``_inverse`` or
        ``log_abs_det_jacobian`` runs.
    :type weights: torch.Tensor or callable
    :param iterations: the number of power-series terms added on top of the
        leading identity term. Must be positive.
    :type iterations: int
    :param normalization: only ``"none"`` is currently accepted.
    :type normalization: string
    :param bound: stored on the instance; not read by this class.
    :type bound: float
    :raises ValueError: if ``iterations`` is not positive or ``normalization``
        is not a recognised method.
    :raises NotImplementedError: if ``normalization`` is ``"weight"`` or
        ``"spectral"``.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, weights=None, iterations=8, normalization="none", bound=None):
        super().__init__(cache_size=1)
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, and a non-positive count would skip the series
        # loop entirely, silently turning the transform into the identity.
        if not iterations > 0:
            raise ValueError(
                "iterations must be positive, got {}".format(iterations)
            )
        self.weights = weights
        self.iterations = iterations
        self.normalization = normalization
        self.bound = bound

        # Currently, weight and spectral normalization are unimplemented. This doesn't affect the validity of the
        # bijection, although applying these norms should improve the numerical conditioning of the approximation.
        if normalization in ("weight", "spectral"):
            raise NotImplementedError("Normalization is currently not implemented.")
        if normalization != "none":
            raise ValueError("Unknown normalization method: {}".format(normalization))

    def _resolve_weights(self):
        """
        Returns the weight matrix, calling ``self.weights`` first when it is a
        callable (as it is for the conditional version of the transform).
        """
        return self.weights() if callable(self.weights) else self.weights

    def _exp(self, x, M):
        """
        Performs power series approximation to the vector product of x with the
        matrix exponential of M, i.e. the sum of ``M^k x / k!`` for
        ``k = 0, ..., self.iterations``.
        """
        # Treat x as a column vector so that matmul broadcasts over batch dims.
        term = x.unsqueeze(-1)
        y = term
        for k in range(1, self.iterations + 1):
            # Build M^k x / k! incrementally from the previous term.
            term = torch.matmul(M, term) / k
            y = y + term

        return y.squeeze(-1)

    def _trace(self, M):
        """
        Calculates the trace of a matrix and is able to do broadcasting over batch
        dimensions, unlike torch.trace.

        Broadcasting is necessary for the conditional version of the transform,
        where self.weights may have batch dimensions corresponding to the batch
        dimensions of the context variable that was conditioned upon.
        """
        return M.diagonal(dim1=-2, dim2=-1).sum(-1)

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x => y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
        the base distribution (or the output of a previous transform)
        """
        return self._exp(x, self._resolve_weights())

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x by applying the same truncated series to ``-M``.
        """
        return self._exp(y, -self._resolve_weights())

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log of the absolute determinant of the Jacobian dy/dx.

        Since ``det(exp(M)) = exp(trace(M))``, this is ``trace(M)``. The result
        does not depend on ``x`` or ``y`` and carries the batch dimensions of
        ``M`` only.
        """
        return self._trace(self._resolve_weights())

@copy_docs_from(ConditionedMatrixExponential)
class MatrixExponential(ConditionedMatrixExponential, TransformModule):
    r"""
    A dense matrix exponential bijective transform (Hoogeboom et al., 2020) with
    equation,

        :math:`\mathbf{y} = \exp(M)\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    :math:`\exp(\cdot)` represents the matrix exponential, and the learnable
    parameters are :math:`M\in\mathbb{R}^{D\times D}` for input dimension
    :math:`D`. In general, :math:`M` is not required to be invertible.

    Due to the favourable mathematical properties of the matrix exponential, the
    transform has an exact inverse and a log-determinant-Jacobian that scales in
    time-complexity as :math:`O(D)`. Both the forward and reverse operations are
    approximated with a truncated power series. For numerical stability, the
    norm of :math:`M` can be restricted with the `normalization` keyword argument.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> transform = MatrixExponential(10)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param iterations: the number of terms to use in the truncated power series that
        approximates matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). Only `'none'`
        is currently implemented; `'weight'` and `'spectral'` raise
        ``NotImplementedError``.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float

    References:

    [1] Emiel Hoogeboom, Victor Garcia Satorras, Jakub M. Tomczak, Max Welling. The
        Convolution Exponential and Generalized Sylvester Flows. [arXiv:2006.01910]
    [2] Tim Salimans, Diederik P. Kingma. Weight Normalization: A Simple
        Reparameterization to Accelerate Training of Deep Neural Networks.
        [arXiv:1602.07868]
    [3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. Spectral
        Normalization for Generative Adversarial Networks. ICLR 2018.

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, input_dim, iterations=8, normalization="none", bound=None):
        super().__init__(
            iterations=iterations, normalization=normalization, bound=bound
        )

        # Allocate uninitialised storage for the square weight matrix; the
        # values are filled in by reset_parameters() immediately afterwards.
        self.weights = nn.Parameter(torch.empty(input_dim, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialises the weight matrix in place with samples from
        ``Uniform(-1/sqrt(D), 1/sqrt(D))``, where ``D`` is the size of its
        first dimension.
        """
        stdv = 1.0 / math.sqrt(self.weights.size(0))
        # nn.init performs the in-place fill without recording it in autograd.
        nn.init.uniform_(self.weights, -stdv, stdv)

@copy_docs_from(ConditionalTransformModule)
class ConditionalMatrixExponential(ConditionalTransformModule):
    r"""
    A dense matrix exponential bijective transform (Hoogeboom et al., 2020) that
    conditions on an additional context variable with equation,

        :math:`\mathbf{y} = \exp(M)\mathbf{x}`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    :math:`\exp(\cdot)` represents the matrix exponential, and
    :math:`M\in\mathbb{R}^{D\times D}` is the output of a neural network
    conditioning on a context variable :math:`\mathbf{z}` for input dimension
    :math:`D`. In general, :math:`M` is not required to be invertible.

    Due to the favourable mathematical properties of the matrix exponential, the
    transform has an exact inverse and a log-determinant-Jacobian that scales in
    time-complexity as :math:`O(D)`. Both the forward and reverse operations are
    approximated with a truncated power series. For numerical stability, the
    norm of :math:`M` can be restricted with the `normalization` keyword argument.

    Example usage:

    >>> from pyro.nn.dense_nn import DenseNN
    >>> input_dim = 10
    >>> context_dim = 5
    >>> batch_size = 3
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> param_dims = [input_dim*input_dim]
    >>> hypernet = DenseNN(context_dim, [50, 50], param_dims)
    >>> transform = ConditionalMatrixExponential(input_dim, hypernet)
    >>> z = torch.rand(batch_size, context_dim)
    >>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
    ... [transform]).condition(z)
    >>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param nn: a neural network that maps the context variable to a tensor whose
        last dimension has size `input_dim * input_dim`; that dimension is
        reshaped into the `(input_dim, input_dim)` weight matrix.
    :type nn: callable
    :param iterations: the number of terms to use in the truncated power series that
        approximates matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018). The value is
        only validated when :meth:`condition` is called.
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float

    References:

    [1] Emiel Hoogeboom, Victor Garcia Satorras, Jakub M. Tomczak, Max Welling. The
        Convolution Exponential and Generalized Sylvester Flows. [arXiv:2006.01910]
    [2] Tim Salimans, Diederik P. Kingma. Weight Normalization: A Simple
        Reparameterization to Accelerate Training of Deep Neural Networks.
        [arXiv:1602.07868]
    [3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. Spectral
        Normalization for Generative Adversarial Networks. ICLR 2018.

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    # NOTE: the ``nn`` parameter shadows the module-level ``torch.nn`` alias
    # inside __init__; the name is part of the public signature and is kept.
    def __init__(self, input_dim, nn, iterations=8, normalization="none", bound=None):
        super().__init__()
        self.input_dim = input_dim
        self.nn = nn
        self.iterations = iterations
        self.normalization = normalization
        self.bound = bound

    def _params(self, context):
        """Returns the raw output of the conditioning network for ``context``."""
        return self.nn(context)

    def condition(self, context):
        """
        Returns a :class:`ConditionedMatrixExponential` whose weight matrix is
        produced by the conditioning network evaluated on ``context``.

        :param context: the context variable passed to the conditioning network.
        :type context: torch.Tensor
        """
        # This hack could be fixed by having a conditioning network that outputs a more general shape
        cond_nn = partial(self.nn, context)

        def weights():
            # Unflatten the trailing dimension into a square matrix while
            # keeping any leading (batch) dimensions of the network output.
            flat = cond_nn()
            return flat.view(flat.shape[:-1] + (self.input_dim, self.input_dim))

        return ConditionedMatrixExponential(
            weights,
            iterations=self.iterations,
            normalization=self.normalization,
            bound=self.bound,
        )

def matrix_exponential(input_dim, iterations=8, normalization="none", bound=None):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.MatrixExponential` object for consistency
    with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param iterations: the number of terms to use in the truncated power series that
        approximates matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018).
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float

    """

    return MatrixExponential(
        input_dim, iterations=iterations, normalization=normalization, bound=bound
    )

def conditional_matrix_exponential(
    input_dim,
    context_dim,
    hidden_dims=None,
    iterations=8,
    normalization="none",
    bound=None,
):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.ConditionalMatrixExponential` object for
    consistency with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param context_dim: Dimension of context variable
    :type context_dim: int
    :param hidden_dims: The desired hidden dimensions of the dense network. Defaults
        to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param iterations: the number of terms to use in the truncated power series that
        approximates matrix exponentiation.
    :type iterations: int
    :param normalization: One of `['none', 'weight', 'spectral']` normalization that
        selects what type of normalization to apply to the weight matrix. `weight`
        corresponds to weight normalization (Salimans and Kingma, 2016) and
        `spectral` to spectral normalization (Miyato et al, 2018).
    :type normalization: string
    :param bound: a bound on either the weight or spectral norm, when either of
        those two types of regularization are chosen by the `normalization`
        argument. A lower value for this results in fewer required terms of the
        truncated power series to closely approximate the exact value of the matrix
        exponential.
    :type bound: float

    """

    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]
    # Named ``hypernet`` rather than ``nn`` so the module-level ``torch.nn``
    # alias is not shadowed. The network emits one flat vector holding all
    # input_dim * input_dim entries of the weight matrix.
    hypernet = DenseNN(context_dim, hidden_dims, param_dims=[input_dim * input_dim])
    return ConditionalMatrixExponential(
        input_dim,
        hypernet,
        iterations=iterations,
        normalization=normalization,
        bound=bound,
    )