from functools import partial

import torch

from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN

from .. import constraints
from ..conditional import ConditionalTransformModule
from ..torch_transform import TransformModule
from ..util import copy_docs_from
from .spline import ConditionalSpline

@copy_docs_from(TransformModule)
class SplineAutoregressive(TransformModule):
    r"""
    An implementation of the autoregressive layer with rational spline bijections of
    linear and quadratic order (Durkan et al., 2019; Dolatabadi et al., 2020).
    Rational splines are functions built from segments, each of which is the ratio
    of two polynomials (see :class:`~pyro.distributions.transforms.Spline`).

    The autoregressive layer uses the transformation,

        :math:`y_d = g_{\theta_d}(x_d)\ \ \ d=1,2,\ldots,D`

    where :math:`\mathbf{x}=(x_1,x_2,\ldots,x_D)` are the inputs,
    :math:`\mathbf{y}=(y_1,y_2,\ldots,y_D)` are the outputs, :math:`g_{\theta_d}` is
    an elementwise rational monotonic spline with parameters :math:`\theta_d`, and
    :math:`\theta=(\theta_1,\theta_2,\ldots,\theta_D)` is the output of an
    autoregressive NN inputting :math:`\mathbf{x}`.

    Example usage:

    >>> from pyro.nn import AutoRegressiveNN
    >>> input_dim = 10
    >>> count_bins = 8
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> hidden_dims = [input_dim * 10, input_dim * 10]
    >>> param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    >>> hypernet = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims)
    >>> transform = SplineAutoregressive(input_dim, hypernet, count_bins=count_bins)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: Dimension of the input vector. Despite operating element-wise,
        this is required so we know how many parameters to store.
    :type input_dim: int
    :param autoregressive_nn: an autoregressive neural network whose forward call
        returns a tuple of the spline parameters
    :type autoregressive_nn: callable
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the spline.
    :type order: string

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True
    autoregressive = True

    def __init__(
        self, input_dim, autoregressive_nn, count_bins=8, bound=3.0, order="linear"
    ):
        # cache_size=1 lets the parent Transform remember the most recent (x, y) pair.
        super().__init__(cache_size=1)
        self.arn = autoregressive_nn
        # The autoregressive network acts as the spline's hypernetwork: conditioning
        # the spline on the input x yields a spline parameterised by theta(x).
        self.spline = ConditionalSpline(
            autoregressive_nn, input_dim, count_bins, bound, order
        )

    def _call(self, x):
        """
        Apply the forward bijection ``x => y``.

        In the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution`, ``x`` is a sample
        from the base distribution (or the output of a previous transform).

        :param x: the input into the bijection
        :type x: torch.Tensor
        :returns: the transformed tensor ``y``
        """
        conditioned = self.spline.condition(x)
        y = conditioned(x)
        # Keep the log-Jacobian from this pass for log_abs_det_jacobian to reuse.
        self._cache_log_detJ = conditioned._cache_log_detJ
        return y

    def _inverse(self, y):
        """
        Invert the bijection ``y => x``.

        Starting from an all-zeros estimate of ``x``, each pass conditions the
        spline on the current estimate and inverts ``y``. The pass is repeated
        once per input dimension so that the autoregressive dependencies can
        propagate through every coordinate. Reuse of a previously computed
        inverse is handled by the parent Transform's cache, not by this method.

        :param y: the output of the bijection
        :type y: torch.Tensor
        :returns: the recovered input ``x``
        """
        x = torch.zeros_like(y)

        # NOTE: Inversion is an expensive operation that scales in the dimension of the input
        for _ in range(y.size(-1)):
            conditioned = self.spline.condition(x)
            x = conditioned._inverse(y)

        self._cache_log_detJ = conditioned._cache_log_detJ
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the elementwise determinant of the log Jacobian, summed over
        the last dimension.

        If ``(x, y)`` is not the pair most recently cached by the parent
        Transform, the forward pass is rerun on ``x`` first.
        """
        cached_x, cached_y = self._cached_x_y
        if not (x is cached_x and y is cached_y):
            # Calling the parent class Transform updates its cache and invokes
            # self._call, recomputing y and the stored log-Jacobian.
            self(x)

        return self._cache_log_detJ.sum(-1)

@copy_docs_from(ConditionalTransformModule)
class ConditionalSplineAutoregressive(ConditionalTransformModule):
    r"""
    An implementation of the autoregressive layer with rational spline bijections of
    linear and quadratic order (Durkan et al., 2019; Dolatabadi et al., 2020) that
    conditions on an additional context variable. Rational splines are functions
    that are comprised of segments that are the ratio of two polynomials (see
    :class:`~pyro.distributions.transforms.Spline`).

    The autoregressive layer uses the transformation,

        :math:`y_d = g_{\theta_d}(x_d)\ \ \ d=1,2,\ldots,D`

    where :math:`\mathbf{x}=(x_1,x_2,\ldots,x_D)` are the inputs,
    :math:`\mathbf{y}=(y_1,y_2,\ldots,y_D)` are the outputs, :math:`g_{\theta_d}` is
    an elementwise rational monotonic spline with parameters :math:`\theta_d`, and
    :math:`\theta=(\theta_1,\theta_2,\ldots,\theta_D)` is the output of a
    conditional autoregressive NN inputting :math:`\mathbf{x}` and conditioning on
    the context variable :math:`\mathbf{z}`.

    Example usage:

    >>> from pyro.nn import ConditionalAutoRegressiveNN
    >>> input_dim = 10
    >>> count_bins = 8
    >>> context_dim = 5
    >>> batch_size = 3
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> hidden_dims = [input_dim * 10, input_dim * 10]
    >>> param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    >>> hypernet = ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims,
    ... param_dims=param_dims)
    >>> transform = ConditionalSplineAutoregressive(input_dim, hypernet,
    ... count_bins=count_bins)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> z = torch.rand(batch_size, context_dim)
    >>> flow_dist = dist.ConditionalTransformedDistribution(base_dist,
    ... [transform]).condition(z)
    >>> flow_dist.sample(sample_shape=torch.Size([batch_size]))  # doctest: +SKIP

    :param input_dim: Dimension of the input vector. Despite operating element-wise,
        this is required so we know how many parameters to store.
    :type input_dim: int
    :param autoregressive_nn: a conditional autoregressive neural network whose
        forward call accepts a ``context`` keyword argument and returns a tuple of
        the spline parameters. It must also expose ``permutation`` and
        ``get_permutation`` attributes.
    :type autoregressive_nn: callable
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the spline.
    :type order: string

    ``count_bins``, ``bound`` and ``order`` are passed as keyword arguments and
    forwarded unchanged to
    :class:`~pyro.distributions.transforms.SplineAutoregressive` each time
    :meth:`condition` is called.

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(self, input_dim, autoregressive_nn, **kwargs):
        super().__init__()
        self.input_dim = input_dim
        self.nn = autoregressive_nn
        # Spline settings (count_bins, bound, order) reused for every conditioned
        # transform built by condition().
        self.kwargs = kwargs

    def condition(self, context):
        """
        Conditions on a context variable, returning a non-conditional transform
        of type :class:`~pyro.distributions.transforms.SplineAutoregressive`.

        :param context: the context variable to condition on
        :type context: torch.Tensor
        :returns: a spline autoregressive transform whose network is bound to
            ``context``
        :rtype: SplineAutoregressive
        """
        # ``partial`` wraps the network rather than copying it, so the conditioned
        # transform shares the weights of the ConditionalAutoRegressiveNN.
        cond_nn = partial(self.nn, context=context)
        # Mirror the wrapped network's permutation attributes onto the partial so
        # it can stand in wherever the network's permutation is looked up.
        cond_nn.permutation = cond_nn.func.permutation
        cond_nn.get_permutation = cond_nn.func.get_permutation
        return SplineAutoregressive(self.input_dim, cond_nn, **self.kwargs)

def spline_autoregressive(
    input_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear"
):
    r"""
    A helper function to create an
    :class:`~pyro.distributions.transforms.SplineAutoregressive` object that takes
    care of constructing an autoregressive network with the correct input/output
    dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive network.
        Defaults to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the spline.
    :type order: string
    :raises ValueError: if ``order`` is not ``'linear'`` or ``'quadratic'``.

    """

    # The network must emit exactly the parameter tensors the spline unpacks.
    # Widths, heights and derivatives are always needed. Rational linear splines
    # additionally take ``count_bins`` lambda parameters; quadratic splines do not.
    if order == "linear":
        param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    elif order == "quadratic":
        param_dims = [count_bins, count_bins, count_bins - 1]
    else:
        raise ValueError(
            "Keyword argument 'order' must be one of ['linear', 'quadratic'], "
            "but '{}' was found!".format(order)
        )

    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]

    arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims)
    return SplineAutoregressive(
        input_dim, arn, count_bins=count_bins, bound=bound, order=order
    )

def conditional_spline_autoregressive(
    input_dim, context_dim, hidden_dims=None, count_bins=8, bound=3.0, order="linear"
):
    r"""
    A helper function to create a
    :class:`~pyro.distributions.transforms.ConditionalSplineAutoregressive` object
    that takes care of constructing an autoregressive network with the correct
    input/output dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param context_dim: Dimension of context variable
    :type context_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive network.
        Defaults to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the spline.
    :type order: string
    :raises ValueError: if ``order`` is not ``'linear'`` or ``'quadratic'``.

    """

    # The network must emit exactly the parameter tensors the spline unpacks.
    # Widths, heights and derivatives are always needed. Rational linear splines
    # additionally take ``count_bins`` lambda parameters; quadratic splines do not.
    if order == "linear":
        param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    elif order == "quadratic":
        param_dims = [count_bins, count_bins, count_bins - 1]
    else:
        raise ValueError(
            "Keyword argument 'order' must be one of ['linear', 'quadratic'], "
            "but '{}' was found!".format(order)
        )

    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]

    arn = ConditionalAutoRegressiveNN(
        input_dim, context_dim, hidden_dims, param_dims=param_dims
    )
    return ConditionalSplineAutoregressive(
        input_dim, arn, count_bins=count_bins, bound=bound, order=order
    )