# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.transforms import SigmoidTransform, TanhTransform
from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN
from .. import constraints
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
from .basic import ELUTransform, LeakyReLUTransform
eps = 1e-8
[docs]@copy_docs_from(TransformModule)
class NeuralAutoregressive(TransformModule):
r"""
An implementation of the deep Neural Autoregressive Flow (NAF) bijective
transform of the "IAF flavour" that can be used for sampling and scoring samples
drawn from it (but not arbitrary ones).
Example usage:
>>> from pyro.nn import AutoRegressiveNN
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> arn = AutoRegressiveNN(10, [40], param_dims=[16]*3)
>>> transform = NeuralAutoregressive(arn, hidden_units=16)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample() # doctest: +SKIP
The inverse operation is not implemented. This would require numerical
inversion, e.g., using a root finding method - a possibility for a future
implementation.
:param autoregressive_nn: an autoregressive neural network whose forward call
returns a tuple of three real-valued tensors, whose last dimension is the
input dimension, and whose penultimate dimension is equal to hidden_units.
:type autoregressive_nn: nn.Module
:param hidden_units: the number of hidden units to use in the NAF transformation
(see Eq (8) in reference)
:type hidden_units: int
:param activation: Activation function to use. One of 'ELU', 'LeakyReLU',
'sigmoid', or 'tanh'.
:type activation: string
Reference:
[1] Chin-Wei Huang, David Krueger, Alexandre Lacoste, Aaron Courville. Neural
Autoregressive Flows. [arXiv:1804.00779]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
eps = 1e-8
autoregressive = True
def __init__(self, autoregressive_nn, hidden_units=16, activation="sigmoid"):
super().__init__(cache_size=1)
# Create the intermediate transform used
name_to_mixin = {
"ELU": ELUTransform,
"LeakyReLU": LeakyReLUTransform,
"sigmoid": SigmoidTransform,
"tanh": TanhTransform,
}
if activation not in name_to_mixin:
raise ValueError('Invalid activation function "{}"'.format(activation))
self.T = name_to_mixin[activation]()
self.arn = autoregressive_nn
self.hidden_units = hidden_units
self.logsoftmax = nn.LogSoftmax(dim=-2)
self._cached_log_df_inv_dx = None
self._cached_A = None
self._cached_W_pre = None
self._cached_C = None
self._cached_T_C = None
def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x=>y; in the prototypical context of a
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
# A, W, b ~ batch_shape x hidden_units x event_shape
A, W_pre, b = self.arn(x)
T = self.T
# Divide the autoregressive output into the component activations
A = F.softplus(A)
C = A * x.unsqueeze(-2) + b
W = F.softmax(W_pre, dim=-2)
T_C = T(C)
D = (W * T_C).sum(dim=-2)
y = T.inv(D)
self._cached_log_df_inv_dx = T.inv.log_abs_det_jacobian(D, y)
self._cached_A = A
self._cached_W_pre = W_pre
self._cached_C = C
self._cached_T_C = T_C
return y
# This method returns log(abs(det(dy/dx)), which is equal to -log(abs(det(dx/dy))
[docs] def log_abs_det_jacobian(self, x, y):
"""
Calculates the elementwise determinant of the log Jacobian
"""
A = self._cached_A
W_pre = self._cached_W_pre
C = self._cached_C
T_C = self._cached_T_C
T = self.T
log_dydD = self._cached_log_df_inv_dx
log_dDdx = torch.logsumexp(
torch.log(A + self.eps)
+ self.logsoftmax(W_pre)
+ T.log_abs_det_jacobian(C, T_C),
dim=-2,
)
log_det = log_dydD + log_dDdx
return log_det.sum(-1)
[docs]@copy_docs_from(ConditionalTransformModule)
class ConditionalNeuralAutoregressive(ConditionalTransformModule):
r"""
An implementation of the deep Neural Autoregressive Flow (NAF) bijective
transform of the "IAF flavour" conditioning on an additiona context variable
that can be used for sampling and scoring samples drawn from it (but not
arbitrary ones).
Example usage:
>>> from pyro.nn import ConditionalAutoRegressiveNN
>>> input_dim = 10
>>> context_dim = 5
>>> batch_size = 3
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> arn = ConditionalAutoRegressiveNN(input_dim, context_dim, [40],
... param_dims=[16]*3)
>>> transform = ConditionalNeuralAutoregressive(arn, hidden_units=16)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
... [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP
The inverse operation is not implemented. This would require numerical
inversion, e.g., using a root finding method - a possibility for a future
implementation.
:param autoregressive_nn: an autoregressive neural network whose forward call
returns a tuple of three real-valued tensors, whose last dimension is the
input dimension, and whose penultimate dimension is equal to hidden_units.
:type autoregressive_nn: nn.Module
:param hidden_units: the number of hidden units to use in the NAF transformation
(see Eq (8) in reference)
:type hidden_units: int
:param activation: Activation function to use. One of 'ELU', 'LeakyReLU',
'sigmoid', or 'tanh'.
:type activation: string
Reference:
[1] Chin-Wei Huang, David Krueger, Alexandre Lacoste, Aaron Courville. Neural
Autoregressive Flows. [arXiv:1804.00779]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
def __init__(self, autoregressive_nn, **kwargs):
super().__init__()
self.nn = autoregressive_nn
self.kwargs = kwargs
[docs] def condition(self, context):
"""
Conditions on a context variable, returning a non-conditional transform of
of type :class:`~pyro.distributions.transforms.NeuralAutoregressive`.
"""
# Note that nn.condition doesn't copy the weights of the ConditionalAutoregressiveNN
cond_nn = partial(self.nn, context=context)
cond_nn.permutation = cond_nn.func.permutation
cond_nn.get_permutation = cond_nn.func.get_permutation
return NeuralAutoregressive(cond_nn, **self.kwargs)
[docs]def neural_autoregressive(input_dim, hidden_dims=None, activation="sigmoid", width=16):
"""
A helper function to create a
:class:`~pyro.distributions.transforms.NeuralAutoregressive` object that takes
care of constructing an autoregressive network with the correct input/output
dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param hidden_dims: The desired hidden dimensions of the autoregressive network.
Defaults to using [3*input_dim + 1]
:type hidden_dims: list[int]
:param activation: Activation function to use. One of 'ELU', 'LeakyReLU',
'sigmoid', or 'tanh'.
:type activation: string
:param width: The width of the "multilayer perceptron" in the transform (see
paper). Defaults to 16
:type width: int
"""
if hidden_dims is None:
hidden_dims = [3 * input_dim + 1]
arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=[width] * 3)
return NeuralAutoregressive(arn, hidden_units=width, activation=activation)
[docs]def conditional_neural_autoregressive(
input_dim, context_dim, hidden_dims=None, activation="sigmoid", width=16
):
"""
A helper function to create a
:class:`~pyro.distributions.transforms.ConditionalNeuralAutoregressive` object
that takes care of constructing an autoregressive network with the correct
input/output dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param context_dim: Dimension of context variable
:type context_dim: int
:param hidden_dims: The desired hidden dimensions of the autoregressive network.
Defaults to using [3*input_dim + 1]
:type hidden_dims: list[int]
:param activation: Activation function to use. One of 'ELU', 'LeakyReLU',
'sigmoid', or 'tanh'.
:type activation: string
:param width: The width of the "multilayer perceptron" in the transform (see
paper). Defaults to 16
:type width: int
"""
if hidden_dims is None:
hidden_dims = [3 * input_dim + 1]
arn = ConditionalAutoRegressiveNN(
input_dim, context_dim, hidden_dims, param_dims=[width] * 3
)
return ConditionalNeuralAutoregressive(
arn, hidden_units=width, activation=activation
)