# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
import torch
from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN
from .. import constraints
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
from .spline import ConditionalSpline


@copy_docs_from(TransformModule)
class SplineAutoregressive(TransformModule):
r"""
An implementation of the autoregressive layer with rational spline bijections of
linear and quadratic order (Durkan et al., 2019; Dolatabadi et al., 2020).
Rational splines are functions composed of segments that are the ratio of two
polynomials (see :class:`~pyro.distributions.transforms.Spline`).

The autoregressive layer uses the transformation,

:math:`y_d = g_{\theta_d}(x_d)\ \ \ d=1,2,\ldots,D`

where :math:`\mathbf{x}=(x_1,x_2,\ldots,x_D)` are the inputs,
:math:`\mathbf{y}=(y_1,y_2,\ldots,y_D)` are the outputs, :math:`g_{\theta_d}` is
an elementwise rational monotonic spline with parameters :math:`\theta_d`, and
:math:`\theta=(\theta_1,\theta_2,\ldots,\theta_D)` is the output of an
autoregressive NN inputting :math:`\mathbf{x}`.

Example usage:

>>> from pyro.nn import AutoRegressiveNN
>>> input_dim = 10
>>> count_bins = 8
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> hidden_dims = [input_dim * 10, input_dim * 10]
>>> param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
>>> hypernet = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims)
>>> transform = SplineAutoregressive(input_dim, hypernet, count_bins=count_bins)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample() # doctest: +SKIP

:param input_dim: Dimension of the input vector. Despite operating element-wise,
this is required so we know how many parameters to store.
:type input_dim: int
:param autoregressive_nn: an autoregressive neural network whose forward call
returns a tuple of the spline parameters
:type autoregressive_nn: callable
:param count_bins: The number of segments comprising the spline.
:type count_bins: int
:param bound: The quantity :math:`K` determining the bounding box,
:math:`[-K,K]\times[-K,K]`, of the spline.
:type bound: float
:param order: One of ['linear', 'quadratic'] specifying the order of the spline.
:type order: string

References:

[1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural
Spline Flows. NeurIPS 2019.

[2] Hadi M. Dolatabadi, Sarah Erfani, Christopher Leckie. Invertible Generative
Modeling using Linear Rational Splines. AISTATS 2020.
"""

domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True
autoregressive = True

def __init__(
self, input_dim, autoregressive_nn, count_bins=8, bound=3.0, order="linear"
):
super().__init__(cache_size=1)
self.arn = autoregressive_nn
self.spline = ConditionalSpline(
autoregressive_nn, input_dim, count_bins, bound, order
)

def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x=>y; in the prototypical context of a
:class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
the base distribution (or the output of a previous transform)
"""
spline = self.spline.condition(x)
y = spline(x)
self._cache_log_detJ = spline._cache_log_detJ
return y

def _inverse(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor
Inverts y => x. Uses a previously cached inverse if available, otherwise
performs the inversion afresh.
"""
input_dim = y.size(-1)
x = torch.zeros_like(y)
# NOTE: Inversion is an expensive operation that scales linearly with the
# dimension of the input
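# Each pass conditions the spline on the current estimate of x. Since theta_d depends
# only on x_{<d}, pass d recovers the d-th variable in the autoregressive ordering
# exactly, so input_dim passes yield the exact inverse.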
for _ in range(input_dim):
spline = self.spline.condition(x)
x = spline._inverse(y)
self._cache_log_detJ = spline._cache_log_detJ
return x

def log_abs_det_jacobian(self, x, y):
"""
Calculates the log of the absolute value of the determinant of the Jacobian,
summed over the event dimension
"""
x_old, y_old = self._cached_x_y
if x is not x_old or y is not y_old:
# This call to the parent class Transform will update the cache
# as well as calling self._call and recalculating y and log_detJ
self(x)
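# The transform is autoregressive, so its Jacobian is triangular and the log
# absolute determinant is the sum of the elementwise spline log-derivatives.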
return self._cache_log_detJ.sum(-1)


@copy_docs_from(ConditionalTransformModule)
class ConditionalSplineAutoregressive(ConditionalTransformModule):
r"""
An implementation of the autoregressive layer with rational spline bijections of
linear and quadratic order (Durkan et al., 2019; Dolatabadi et al., 2020) that
conditions on an additional context variable. Rational splines are functions
composed of segments that are the ratio of two polynomials (see
:class:`~pyro.distributions.transforms.Spline`).

The autoregressive layer uses the transformation,

:math:`y_d = g_{\theta_d}(x_d)\ \ \ d=1,2,\ldots,D`

where :math:`\mathbf{x}=(x_1,x_2,\ldots,x_D)` are the inputs,
:math:`\mathbf{y}=(y_1,y_2,\ldots,y_D)` are the outputs, :math:`g_{\theta_d}` is
an elementwise rational monotonic spline with parameters :math:`\theta_d`, and
:math:`\theta=(\theta_1,\theta_2,\ldots,\theta_D)` is the output of a
conditional autoregressive NN inputting :math:`\mathbf{x}` and conditioning on
the context variable :math:`\mathbf{z}`.

Example usage:

>>> from pyro.nn import ConditionalAutoRegressiveNN
>>> input_dim = 10
>>> count_bins = 8
>>> context_dim = 5
>>> batch_size = 3
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> hidden_dims = [input_dim * 10, input_dim * 10]
>>> param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
>>> hypernet = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims,
... param_dims=param_dims)
>>> transform = ConditionalSplineAutoregressive(input_dim, hypernet,
... count_bins=count_bins)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
... [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP

:param input_dim: Dimension of the input vector. Despite operating element-wise,
this is required so we know how many parameters to store.
:type input_dim: int
:param autoregressive_nn: an autoregressive neural network whose forward call
returns a tuple of the spline parameters
:type autoregressive_nn: callable
:param count_bins: The number of segments comprising the spline.
:type count_bins: int
:param bound: The quantity :math:`K` determining the bounding box,
:math:`[-K,K]\times[-K,K]`, of the spline.
:type bound: float
:param order: One of ['linear', 'quadratic'] specifying the order of the spline.
:type order: string

References:

[1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural
Spline Flows. NeurIPS 2019.

[2] Hadi M. Dolatabadi, Sarah Erfani, Christopher Leckie. Invertible Generative
Modeling using Linear Rational Splines. AISTATS 2020.
"""

domain = constraints.real_vector
codomain = constraints.real_vector
bijective = True

def __init__(self, input_dim, autoregressive_nn, **kwargs):
super().__init__()
self.input_dim = input_dim
self.nn = autoregressive_nn
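# The remaining kwargs (count_bins, bound, order) are stored and forwarded to the
# SplineAutoregressive constructed in .condition() below.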
self.kwargs = kwargs

def condition(self, context):
"""
Conditions on a context variable, returning a non-conditional transform of
type :class:`~pyro.distributions.transforms.SplineAutoregressive`.
"""
# Note that binding the context with functools.partial does not copy the weights of
# the ConditionalAutoRegressiveNN; the returned transform shares parameters with self.nn
cond_nn = partial(self.nn, context=context)
cond_nn.permutation = cond_nn.func.permutation
cond_nn.get_permutation = cond_nn.func.get_permutation
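# Copying these attributes lets cond_nn expose the same interface as a plain
# AutoRegressiveNN for downstream code that inspects the variable ordering.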
return SplineAutoregressive(self.input_dim, cond_nn, **self.kwargs)


def spline_autoregressive(
input_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear"
):
r"""
A helper function to create an
:class:`~pyro.distributions.transforms.SplineAutoregressive` object that takes
care of constructing an autoregressive network with the correct input/output
dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param hidden_dims: The desired hidden dimensions of the autoregressive network.
Defaults to using [input_dim * 10, input_dim * 10]
:type hidden_dims: list[int]
:param count_bins: The number of segments comprising the spline.
:type count_bins: int
:param bound: The quantity :math:`K` determining the bounding box,
:math:`[-K,K]\times[-K,K]`, of the spline.
:type bound: float
:param order: One of ['linear', 'quadratic'] specifying the order of the spline.
:type order: string
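
Example usage (a minimal sketch mirroring the
:class:`~pyro.distributions.transforms.SplineAutoregressive` example; it assumes
``torch``, ``pyro`` and ``dist`` (``pyro.distributions``) are available):

>>> input_dim = 10
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> transform = spline_autoregressive(input_dim, count_bins=8)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample() # doctest: +SKIP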
"""
if hidden_dims is None:
hidden_dims = [input_dim * 10, input_dim * 10]
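# Per-element spline parameters: count_bins bin widths, count_bins bin heights,
# count_bins - 1 internal derivatives, and count_bins lambdas (the lambdas are
# used by the 'linear' order).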
param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims)
return SplineAutoregressive(
input_dim, arn, count_bins=count_bins, bound=bound, order=order
)


def conditional_spline_autoregressive(
input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear"
):
r"""
A helper function to create a
:class:`~pyro.distributions.transforms.ConditionalSplineAutoregressive` object
that takes care of constructing an autoregressive network with the correct
input/output dimensions.
:param input_dim: Dimension of input variable
:type input_dim: int
:param context_dim: Dimension of context variable
:type context_dim: int
:param hidden_dims: The desired hidden dimensions of the autoregressive network.
Defaults to using [input_dim * 10, input_dim * 10]
:type hidden_dims: list[int]
:param count_bins: The number of segments comprising the spline.
:type count_bins: int
:param bound: The quantity :math:`K` determining the bounding box,
:math:`[-K,K]\times[-K,K]`, of the spline.
:type bound: float
:param order: One of ['linear', 'quadratic'] specifying the order of the spline.
:type order: string
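
Example usage (a minimal sketch mirroring the
:class:`~pyro.distributions.transforms.ConditionalSplineAutoregressive` example;
it assumes ``torch``, ``pyro`` and ``dist`` (``pyro.distributions``) are available):

>>> input_dim = 10
>>> context_dim = 5
>>> batch_size = 3
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> transform = conditional_spline_autoregressive(input_dim, context_dim,
... count_bins=8)
>>> pyro.module("my_transform", transform) # doctest: +SKIP
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
... [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP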
"""
if hidden_dims is None:
hidden_dims = [input_dim * 10, input_dim * 10]
param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
arn = ConditionalAutoRegressiveNN(
input_dim, context_dim, hidden_dims, param_dims=param_dims
)
return ConditionalSplineAutoregressive(
input_dim, arn, count_bins=count_bins, bound=bound, order=order
)