# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import torch
import torch.nn as nn
from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN
from .. import constraints
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
from .utils import clamp_preserve_gradients
@copy_docs_from(TransformModule)
class AffineAutoregressive(TransformModule):
r"""
An implementation of the bijective transform of Inverse Autoregressive Flow
    (IAF), using by default Eq. (10) of Kingma et al., 2016,
:math:`\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}`
where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
:math:`\mu_t,\sigma_t` are calculated from an autoregressive network on
:math:`\mathbf{x}`, and :math:`\sigma_t>0`.
    If the ``stable`` keyword argument is set to ``True`` then the
    transformation used is,
:math:`\mathbf{y} = \sigma_t\odot\mathbf{x} + (1-\sigma_t)\odot\mu_t`
where :math:`\sigma_t` is restricted to :math:`(0,1)`. This variant of IAF is
claimed by the authors to be more numerically stable than one using Eq (10),
although in practice it leads to a restriction on the distributions that can be
represented, presumably since the input is restricted to rescaling by a number
on :math:`(0,1)`.
Together with :class:`~pyro.distributions.TransformedDistribution` this provides
a way to create richer variational approximations.
Example usage:
>>> from pyro.nn import AutoRegressiveNN
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> transform = AffineAutoregressive(AutoRegressiveNN(10, [40]))
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample() # doctest: +SKIP
The inverse of the Bijector is required when, e.g., scoring the log density of a
sample with :class:`~pyro.distributions.TransformedDistribution`. This
implementation caches the inverse of the Bijector when its forward operation is
called, e.g., when sampling from
:class:`~pyro.distributions.TransformedDistribution`. However, if the cached
    value isn't available, either because it was overwritten while sampling a new
    value or because an arbitrary value is being scored, it will calculate it
    manually. Note that this is an operation that scales as O(D) where D is the
    input dimension, and so should be avoided for high-dimensional inputs. In
    general, then, it is cheap to sample from IAF and to score a value that was
    sampled by IAF, but expensive to score an arbitrary value.
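
    For example, continuing the example above, scoring
    ``flow_dist.log_prob(flow_dist.sample())`` reuses the cached inverse, whereas
    ``flow_dist.log_prob(torch.randn(10))`` triggers the sequential O(D)
    inversion.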
    :param autoregressive_nn: an autoregressive neural network whose forward call
        returns a real-valued mean and a log-scale (interpreted as a logit-scale
        when ``stable=True``) as a tuple
:type autoregressive_nn: callable
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
:param log_scale_max_clip: The maximum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_max_clip: float
    :param sigmoid_bias: A term to add to the logit-scale from the autoregressive
        NN when using the stable transform.
:type sigmoid_bias: float
:param stable: When true, uses the alternative "stable" version of the transform
(see above).
:type stable: bool
References:
[1] Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever,
Max Welling. Improving Variational Inference with Inverse Autoregressive Flow.
[arXiv:1606.04934]
[2] Danilo Jimenez Rezende, Shakir Mohamed. Variational Inference with
Normalizing Flows. [arXiv:1505.05770]
[3] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. MADE: Masked
Autoencoder for Distribution Estimation. [arXiv:1502.03509]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
sign = +1
autoregressive = True
def __init__(
self,
autoregressive_nn,
log_scale_min_clip=-5.0,
log_scale_max_clip=3.0,
sigmoid_bias=2.0,
stable=False,
):
super().__init__(cache_size=1)
self.arn = autoregressive_nn
self._cached_log_scale = None
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid = nn.Sigmoid()
self.logsigmoid = nn.LogSigmoid()
self.sigmoid_bias = sigmoid_bias
self.stable = stable
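        # When stable=True, swap in the sigmoid-gated forward/inverse defined
        # below in place of the default affine ones.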
if stable:
self._call = self._call_stable
self._inverse = self._inverse_stable
def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x=>y; in the prototypical context of a
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
mean, log_scale = self.arn(x)
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
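        # Cache log(scale) so that log_abs_det_jacobian() can reuse it without
        # re-running the autoregressive network.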
self._cached_log_scale = log_scale
scale = torch.exp(log_scale)
y = scale * x + mean
return y
def _inverse(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor
Inverts y => x. Uses a previously cached inverse if available, otherwise
performs the inversion afresh.
"""
x_size = y.size()[:-1]
perm = self.arn.permutation
input_dim = y.size(-1)
x = [torch.zeros(x_size, device=y.device)] * input_dim
# NOTE: Inversion is an expensive operation that scales in the dimension of the input
for idx in perm:
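            # By the autoregressive property, the outputs at position `idx`
            # depend only on entries of x recovered in earlier iterations of
            # this loop, so each dimension can be inverted in turn.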
mean, log_scale = self.arn(torch.stack(x, dim=-1))
inverse_scale = torch.exp(
-clamp_preserve_gradients(
log_scale[..., idx],
min=self.log_scale_min_clip,
max=self.log_scale_max_clip,
)
)
mean = mean[..., idx]
x[idx] = (y[..., idx] - mean) * inverse_scale
x = torch.stack(x, dim=-1)
log_scale = clamp_preserve_gradients(
log_scale, min=self.log_scale_min_clip, max=self.log_scale_max_clip
)
self._cached_log_scale = log_scale
return x
    def log_abs_det_jacobian(self, x, y):
"""
        Calculates the log of the absolute value of the determinant of the
        Jacobian, which for this transform is the sum of the log-scale terms.
"""
x_old, y_old = self._cached_x_y
if x is not x_old or y is not y_old:
# This call to the parent class Transform will update the cache
# as well as calling self._call and recalculating y and log_detJ
self(x)
if self._cached_log_scale is not None:
log_scale = self._cached_log_scale
elif not self.stable:
_, log_scale = self.arn(x)
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
else:
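            # Stable variant: scale = sigmoid(logit_scale + sigmoid_bias), so
            # log(scale) = logsigmoid(logit_scale + sigmoid_bias).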
_, logit_scale = self.arn(x)
log_scale = self.logsigmoid(logit_scale + self.sigmoid_bias)
return log_scale.sum(-1)
def _call_stable(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x=>y; in the prototypical context of a
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
mean, logit_scale = self.arn(x)
logit_scale = logit_scale + self.sigmoid_bias
scale = self.sigmoid(logit_scale)
log_scale = self.logsigmoid(logit_scale)
self._cached_log_scale = log_scale
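        # scale lies in (0, 1), so y is a convex combination of x and mean.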
y = scale * x + (1 - scale) * mean
return y
def _inverse_stable(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor
Inverts y => x.
"""
x_size = y.size()[:-1]
perm = self.arn.permutation
input_dim = y.size(-1)
x = [torch.zeros(x_size, device=y.device)] * input_dim
# NOTE: Inversion is an expensive operation that scales in the dimension of the input
for idx in perm:
mean, logit_scale = self.arn(torch.stack(x, dim=-1))
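            # Note that 1 / sigmoid(z) = 1 + exp(-z), so this is 1 / scale at
            # position `idx`.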
inverse_scale = 1 + torch.exp(-logit_scale[..., idx] - self.sigmoid_bias)
x[idx] = inverse_scale * y[..., idx] + (1 - inverse_scale) * mean[..., idx]
self._cached_log_scale = self.logsigmoid(logit_scale + self.sigmoid_bias)
x = torch.stack(x, dim=-1)
return x
@copy_docs_from(ConditionalTransformModule)
class ConditionalAffineAutoregressive(ConditionalTransformModule):
r"""
An implementation of the bijective transform of Inverse Autoregressive Flow
(IAF) that conditions on an additional context variable and uses, by default,
    Eq. (10) of Kingma et al., 2016,
:math:`\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}`
where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the outputs,
:math:`\mu_t,\sigma_t` are calculated from an autoregressive network on
:math:`\mathbf{x}` and context :math:`\mathbf{z}\in\mathbb{R}^M`, and
:math:`\sigma_t>0`.
    If the ``stable`` keyword argument is set to ``True`` then the
    transformation used is,
:math:`\mathbf{y} = \sigma_t\odot\mathbf{x} + (1-\sigma_t)\odot\mu_t`
where :math:`\sigma_t` is restricted to :math:`(0,1)`. This variant of IAF is
claimed by the authors to be more numerically stable than one using Eq (10),
although in practice it leads to a restriction on the distributions that can be
represented, presumably since the input is restricted to rescaling by a number
on :math:`(0,1)`.
Together with :class:`~pyro.distributions.ConditionalTransformedDistribution`
this provides a way to create richer variational approximations.
Example usage:
>>> from pyro.nn import ConditionalAutoRegressiveNN
>>> input_dim = 10
>>> context_dim = 4
>>> batch_size = 3
>>> hidden_dims = [10*input_dim, 10*input_dim]
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> hypernet = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims)
>>> transform = ConditionalAffineAutoregressive(hypernet)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
... [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP
The inverse of the Bijector is required when, e.g., scoring the log density of a
sample with :class:`~pyro.distributions.TransformedDistribution`. This
implementation caches the inverse of the Bijector when its forward operation is
called, e.g., when sampling from
:class:`~pyro.distributions.TransformedDistribution`. However, if the cached
    value isn't available, either because it was overwritten while sampling a new
    value or because an arbitrary value is being scored, it will calculate it
    manually. Note that this is an operation that scales as O(D) where D is the
    input dimension, and so should be avoided for high-dimensional inputs. In
    general, then, it is cheap to sample from IAF and to score a value that was
    sampled by IAF, but expensive to score an arbitrary value.
    :param autoregressive_nn: an autoregressive neural network whose forward call
        returns a real-valued mean and a log-scale (interpreted as a logit-scale
        when ``stable=True``) as a tuple
:type autoregressive_nn: nn.Module
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
:param log_scale_max_clip: The maximum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_max_clip: float
    :param sigmoid_bias: A term to add to the logit-scale from the autoregressive
        NN when using the stable transform.
:type sigmoid_bias: float
:param stable: When true, uses the alternative "stable" version of the transform
(see above).
:type stable: bool
References:
[1] Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever,
Max Welling. Improving Variational Inference with Inverse Autoregressive Flow.
[arXiv:1606.04934]
[2] Danilo Jimenez Rezende, Shakir Mohamed. Variational Inference with
Normalizing Flows. [arXiv:1505.05770]
[3] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. MADE: Masked
Autoencoder for Distribution Estimation. [arXiv:1502.03509]
"""
domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
def __init__(self, autoregressive_nn, **kwargs):
super().__init__()
self.nn = autoregressive_nn
self.kwargs = kwargs
    def condition(self, context):
"""
        Conditions on a context variable, returning a non-conditional transform
        of type :class:`~pyro.distributions.transforms.AffineAutoregressive`.
"""
cond_nn = partial(self.nn, context=context)
cond_nn.permutation = cond_nn.func.permutation
cond_nn.get_permutation = cond_nn.func.get_permutation
return AffineAutoregressive(cond_nn, **self.kwargs)
def affine_autoregressive(input_dim, hidden_dims=None, **kwargs):
"""
A helper function to create an
:class:`~pyro.distributions.transforms.AffineAutoregressive` object that takes
care of constructing an autoregressive network with the correct input/output
dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param hidden_dims: The desired hidden dimensions of the autoregressive network.
Defaults to using [3*input_dim + 1]
:type hidden_dims: list[int]
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
:param log_scale_max_clip: The maximum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_max_clip: float
    :param sigmoid_bias: A term to add to the logit-scale from the autoregressive
        NN when using the stable transform.
:type sigmoid_bias: float
:param stable: When true, uses the alternative "stable" version of the transform
(see above).
:type stable: bool
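
    Example usage (a minimal sketch; assumes ``torch`` and ``dist`` are imported
    as in the class docstring):

    >>> transform = affine_autoregressive(10, hidden_dims=[40])
    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP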
"""
if hidden_dims is None:
hidden_dims = [3 * input_dim + 1]
arn = AutoRegressiveNN(input_dim, hidden_dims)
return AffineAutoregressive(arn, **kwargs)
def conditional_affine_autoregressive(
input_dim, context_dim, hidden_dims=None, **kwargs
):
"""
A helper function to create an
:class:`~pyro.distributions.transforms.ConditionalAffineAutoregressive` object
that takes care of constructing a dense network with the correct input/output
dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param context_dim: Dimension of context variable
:type context_dim: int
:param hidden_dims: The desired hidden dimensions of the dense network. Defaults
to using [10*input_dim]
:type hidden_dims: list[int]
:param log_scale_min_clip: The minimum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_min_clip: float
:param log_scale_max_clip: The maximum value for clipping the log(scale) from
the autoregressive NN
:type log_scale_max_clip: float
    :param sigmoid_bias: A term to add to the logit-scale from the autoregressive
        NN when using the stable transform.
:type sigmoid_bias: float
:param stable: When true, uses the alternative "stable" version of the transform
(see above).
:type stable: bool
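
    Example usage (a minimal sketch; assumes ``torch`` and ``dist`` are imported
    as in the class docstring):

    >>> transform = conditional_affine_autoregressive(10, context_dim=4)
    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> z = torch.rand(4)
    >>> flow_dist = dist.ConditionalTransformedDistribution(
    ...     base_dist, [transform]).condition(z)
    >>> flow_dist.sample()  # doctest: +SKIP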
"""
if hidden_dims is None:
hidden_dims = [10 * input_dim]
nn = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims)
return ConditionalAffineAutoregressive(nn, **kwargs)