# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.distributions import Transform
from pyro.nn import DenseNN
from .. import constraints
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
@copy_docs_from(Transform)
class ConditionedGeneralizedChannelPermute(Transform):
    """
    A generalized channel permutation (invertible 1x1 convolution) with fixed
    parameters. The channel-mixing matrix is ``W = P @ L @ U`` and is applied
    over the channel dimension of inputs shaped ``[..., C, H, W]``.

    :param permutation: the permutation matrix ``P``. It is transposed with
        ``.T`` in the inverse pass, so it is expected to be 2-D.
    :type permutation: torch.Tensor
    :param LU: dense storage of both triangular factors in a single
        ``[..., C, C]`` tensor: the strictly lower triangle holds the
        off-diagonal entries of ``L`` (whose diagonal is implicitly one) and
        the upper triangle, diagonal included, holds ``U``.
    :type LU: torch.Tensor
    """

    domain = constraints.independent(constraints.real, 3)
    codomain = constraints.independent(constraints.real, 3)
    bijective = True

    def __init__(self, permutation=None, LU=None):
        super(ConditionedGeneralizedChannelPermute, self).__init__(cache_size=1)

        self.permutation = permutation
        self.LU = LU

    @property
    def U_diag(self):
        """Diagonal of ``U``, shape ``[..., C]``."""
        # Tensor.diag() only accepts 1-D/2-D input; diagonal() over the last
        # two dims also works when LU carries leading batch dimensions.
        return self.LU.diagonal(dim1=-2, dim2=-1)

    @property
    def L(self):
        """Lower triangular factor with ones on the diagonal."""
        return self.LU.tril(diagonal=-1) + torch.eye(
            self.LU.size(-1), dtype=self.LU.dtype, device=self.LU.device
        )

    @property
    def U(self):
        """Upper triangular factor (diagonal included)."""
        return self.LU.triu()

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
        the base distribution (or the output of a previous transform)
        """
        # NOTE: As is the case for other conditional transforms, the batch dim of
        # the context variable (reflected in the initial dimensions of filters in
        # this case), if this is a conditional transform, must broadcast over the
        # batch dim of the input variable.
        #
        # Also, the reason the following lines use broadcasting and a sum rather
        # than F.conv2d is so we can perform multiple convolutions when the
        # filters "kernel" has batch dimensions.
        filters = (self.permutation @ self.L @ self.U)[..., None, None]
        y = (filters * x.unsqueeze(-4)).sum(-3)
        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x.
        """
        # NOTE: This method is equivalent to the following lines. Using
        # Tensor.inverse() would be numerically unstable, however.
        #
        #   filters = (self.permutation @ self.L @ self.U).inverse()[..., None, None]
        #   x = F.conv2d(y.view(-1, *y.shape[-3:]), filters)
        #   return x.view_as(y)

        # Do a matrix vector product over the channel dimension
        # in order to apply the inverse permutation matrix, P^-1 = P^T
        y_flat = y.flatten(start_dim=-2)
        LUx = (y_flat.unsqueeze(-3) * self.permutation.T.unsqueeze(-1)).sum(-2)

        # Solve L(Ux) = P^-1 y
        Ux = torch.linalg.solve_triangular(self.L, LUx, upper=False)

        # Solve Ux = (PL)^-1 y
        x = torch.linalg.solve_triangular(self.U, Ux, upper=True)

        # Unflatten x (works when context variable has batch dim)
        return x.reshape(x.shape[:-1] + y.shape[-2:])

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log of the absolute value of the determinant of the
        Jacobian, i.e. log(abs(det(dy/dx))), with one value per batch element.
        """
        h, w = x.shape[-2:]
        # The code multiplies the per-matrix log-determinant by h * w; that
        # per-matrix term is sum(log|diag(U)|). Reduce over the channel axis only
        # so that a batched LU yields one log-determinant per batch element.
        log_det = h * w * self.U_diag.abs().log().sum(-1)
        return log_det * torch.ones(
            x.size()[:-3], dtype=x.dtype, layout=x.layout, device=x.device
        )
@copy_docs_from(ConditionedGeneralizedChannelPermute)
class GeneralizedChannelPermute(ConditionedGeneralizedChannelPermute, TransformModule):
    r"""
    A bijection that generalizes a permutation on the channels of a batch of 2D
    image in :math:`[\ldots,C,H,W]` format. Specifically this transform performs
    the operation,

        :math:`\mathbf{y} = \text{torch.nn.functional.conv2d}(\mathbf{x}, W)`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    and :math:`W\sim C\times C\times 1\times 1` is the filter matrix for a 1x1
    convolution with :math:`C` input and output channels.

    Ignoring the final two dimensions, :math:`W` is restricted to be the matrix
    product,

        :math:`W = PLU`

    where :math:`P\sim C\times C` is a permutation matrix on the channel
    dimensions, :math:`L\sim C\times C` is a lower triangular matrix with ones on
    the diagonal, and :math:`U\sim C\times C` is an upper triangular matrix.
    :math:`W` is initialized to a random orthogonal matrix. Then, :math:`P` is fixed
    and the learnable parameters set to :math:`L,U`.

    The input :math:`\mathbf{x}` and output :math:`\mathbf{y}` both have shape
    `[...,C,H,W]`, where `C` is the number of channels set at initialization.

    This operation was introduced in [1] for Glow normalizing flow, and is also
    known as 1x1 invertible convolution. It appears in other notable work such as
    [2,3], and corresponds to the class `tfp.bijectors.MatvecLU` of TensorFlow
    Probability.

    Example usage:

    >>> channels = 3
    >>> base_dist = dist.Normal(torch.zeros(channels, 32, 32),
    ... torch.ones(channels, 32, 32))
    >>> inv_conv = GeneralizedChannelPermute(channels=channels)
    >>> flow_dist = dist.TransformedDistribution(base_dist, [inv_conv])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param channels: Number of channel dimensions in the input.
    :type channels: int
    :param permutation: Optional sequence of ``channels`` indices; row ``i`` of
        :math:`P` is row ``permutation[i]`` of the identity matrix. When omitted,
        :math:`P` is taken from the pivots of the LU factorization of the random
        initial matrix.
    :type permutation: torch.Tensor

    [1] Diederik P. Kingma, Prafulla Dhariwal. Glow: Generative Flow with Invertible
    1x1 Convolutions. [arXiv:1807.03039]

    [2] Ryan Prenger, Rafael Valle, Bryan Catanzaro. WaveGlow: A Flow-based
    Generative Network for Speech Synthesis. [arXiv:1811.00002]

    [3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline
    Flows. [arXiv:1906.04032]

    """

    domain = constraints.independent(constraints.real, 3)
    codomain = constraints.independent(constraints.real, 3)
    bijective = True

    def __init__(self, channels=3, permutation=None):
        super(GeneralizedChannelPermute, self).__init__()
        # The parent constructor assigns ``permutation`` as an ordinary
        # attribute; drop it so the name is free to be registered as a buffer.
        del self.permutation

        # Fail fast on a wrongly sized permutation, before any random sampling
        if permutation is not None and len(permutation) != channels:
            raise ValueError(
                'Keyword argument "permutation" expected to have {} elements but {} found.'.format(
                    channels, len(permutation)
                )
            )

        # Sample a random orthogonal matrix
        W, _ = torch.linalg.qr(torch.randn(channels, channels))

        # Construct the partially pivoted LU-form and the pivots
        LU, pivots = torch.linalg.lu_factor(W)

        # Convert the pivots into the permutation matrix
        if permutation is None:
            P, _, _ = torch.lu_unpack(LU, pivots)
        else:
            # as_tensor also accepts plain lists/tuples of indices
            index = torch.as_tensor(permutation, dtype=torch.int64)
            P = torch.eye(channels, channels)[index]

        # We register the permutation matrix so that the model can be serialized
        self.register_buffer("permutation", P)

        # NOTE: For this implementation I have chosen to store the parameters
        # densely, rather than storing L, U, and s separately
        self.LU = torch.nn.Parameter(LU)
@copy_docs_from(ConditionalTransformModule)
class ConditionalGeneralizedChannelPermute(ConditionalTransformModule):
    r"""
    A bijection that generalizes a permutation on the channels of a batch of 2D
    image in :math:`[\ldots,C,H,W]` format conditioning on an additional context
    variable. Specifically this transform performs the operation,

        :math:`\mathbf{y} = \text{torch.nn.functional.conv2d}(\mathbf{x}, W)`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
    and :math:`W\sim C\times C\times 1\times 1` is the filter matrix for a 1x1
    convolution with :math:`C` input and output channels.

    Ignoring the final two dimensions, :math:`W` is restricted to be the matrix
    product,

        :math:`W = PLU`

    where :math:`P\sim C\times C` is a permutation matrix on the channel
    dimensions, and :math:`LU\sim C\times C` is an invertible product of a lower
    triangular and an upper triangular matrix that is the output of an NN with
    input :math:`z\in\mathbb{R}^{M}` representing the context variable to
    condition on.

    The input :math:`\mathbf{x}` and output :math:`\mathbf{y}` both have shape
    `[...,C,H,W]`, where `C` is the number of channels set at initialization.

    This operation was introduced in [1] for Glow normalizing flow, and is also
    known as 1x1 invertible convolution. It appears in other notable work such as
    [2,3], and corresponds to the class `tfp.bijectors.MatvecLU` of TensorFlow
    Probability.

    Example usage:

    >>> from pyro.nn.dense_nn import DenseNN
    >>> context_dim = 5
    >>> batch_size = 3
    >>> channels = 3
    >>> base_dist = dist.Normal(torch.zeros(channels, 32, 32),
    ... torch.ones(channels, 32, 32))
    >>> hidden_dims = [context_dim*10, context_dim*10]
    >>> nn = DenseNN(context_dim, hidden_dims, param_dims=[channels*channels])
    >>> transform = ConditionalGeneralizedChannelPermute(nn, channels=channels)
    >>> z = torch.rand(batch_size, context_dim)
    >>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
    ... [transform]).condition(z)
    >>> flow_dist.sample(sample_shape=torch.Size([batch_size]))  # doctest: +SKIP

    :param nn: a function inputting the context variable and outputting
        real-valued parameters of dimension :math:`C^2`.
    :param channels: Number of channel dimensions in the input.
    :type channels: int
    :param permutation: Optional sequence of ``channels`` indices; row ``i`` of
        :math:`P` is row ``permutation[i]`` of the identity matrix. When omitted,
        a random permutation is drawn.
    :type permutation: torch.Tensor

    [1] Diederik P. Kingma, Prafulla Dhariwal. Glow: Generative Flow with Invertible
    1x1 Convolutions. [arXiv:1807.03039]

    [2] Ryan Prenger, Rafael Valle, Bryan Catanzaro. WaveGlow: A Flow-based
    Generative Network for Speech Synthesis. [arXiv:1811.00002]

    [3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline
    Flows. [arXiv:1906.04032]

    """

    domain = constraints.independent(constraints.real, 3)
    codomain = constraints.independent(constraints.real, 3)
    bijective = True

    def __init__(self, nn, channels=3, permutation=None):
        super().__init__()
        self.nn = nn
        self.channels = channels
        if permutation is None:
            # Draw the permutation on the CPU, then move it to the device of a
            # default-constructed tensor.
            permutation = torch.randperm(channels, device="cpu").to(
                torch.Tensor().device
            )
        elif len(permutation) != channels:
            # Same check as GeneralizedChannelPermute: a wrongly sized P would
            # otherwise only surface as a shape error when the transform is used.
            raise ValueError(
                'Keyword argument "permutation" expected to have {} elements but {} found.'.format(
                    channels, len(permutation)
                )
            )
        # as_tensor also accepts plain lists/tuples of indices
        index = torch.as_tensor(permutation, dtype=torch.int64)
        P = torch.eye(channels, channels)[index]
        self.register_buffer("permutation", P)

    def condition(self, context):
        """
        Builds the transform for a given value of the context variable.

        :param context: the context variable, passed to ``self.nn``.
        :return: a :class:`ConditionedGeneralizedChannelPermute` using the stored
            permutation matrix and the network output as its ``LU`` parameters.
        """
        LU = self.nn(context)
        # Unpack the trailing C*C outputs into a C x C matrix, keeping any
        # leading dimensions of the network output.
        LU = LU.view(LU.shape[:-1] + (self.channels, self.channels))
        return ConditionedGeneralizedChannelPermute(self.permutation, LU)
def generalized_channel_permute(**kwargs):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.GeneralizedChannelPermute` object for
    consistency with other helpers.

    All keyword arguments are forwarded unchanged to the class constructor.

    :param channels: Number of channel dimensions in the input.
    :type channels: int
    :param permutation: Optional channel permutation indices, forwarded to
        the constructor.

    """
    return GeneralizedChannelPermute(**kwargs)
def conditional_generalized_channel_permute(context_dim, channels=3, hidden_dims=None):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.ConditionalGeneralizedChannelPermute`
    object for consistency with other helpers.

    :param context_dim: Dimension of the context variable; used as the input
        dimension of the :class:`~pyro.nn.DenseNN` built here.
    :type context_dim: int
    :param channels: Number of channel dimensions in the input.
    :type channels: int
    :param hidden_dims: Hidden layer sizes of the dense network. Defaults to
        ``[channels * 10, channels * 10]``.
    :type hidden_dims: list[int]

    """
    if hidden_dims is None:
        hidden_dims = [channels * 10, channels * 10]
    # The network emits channels * channels values per context, which the
    # transform reshapes into the dense C x C LU parameter matrix.
    nn = DenseNN(context_dim, hidden_dims, param_dims=[channels * channels])
    return ConditionalGeneralizedChannelPermute(nn, channels)