Source code for pyro.distributions.transforms.spline_coupling

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from pyro.nn import DenseNN

from .. import constraints
from ..torch_transform import TransformModule
from ..util import copy_docs_from
from .spline import ConditionalSpline, Spline


@copy_docs_from(TransformModule)
class SplineCoupling(TransformModule):
    r"""
    An implementation of the coupling layer with rational spline bijections of
    linear and quadratic order (Durkan et al., 2019; Dolatabadi et al., 2020).
    Rational splines are functions that are comprised of segments that are the
    ratio of two polynomials (see
    :class:`~pyro.distributions.transforms.Spline`).

    The spline coupling layer uses the transformation,

        :math:`\mathbf{y}_{1:d} = g_\theta(\mathbf{x}_{1:d})`
        :math:`\mathbf{y}_{(d+1):D} = h_\phi(\mathbf{x}_{(d+1):D};\mathbf{x}_{1:d})`

    where :math:`\mathbf{x}` are the inputs, :math:`\mathbf{y}` are the
    outputs, e.g. :math:`\mathbf{x}_{1:d}` represents the first :math:`d`
    elements of the inputs, :math:`g_\theta` is either the identity function or
    an elementwise rational monotonic spline with parameters :math:`\theta`,
    and :math:`h_\phi` is a conditional elementwise spline, conditioning on the
    first :math:`d` elements.

    Example usage:

    >>> from pyro.nn import DenseNN
    >>> input_dim = 10
    >>> split_dim = 6
    >>> count_bins = 8
    >>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    >>> param_dims = [(input_dim - split_dim) * count_bins,
    ... (input_dim - split_dim) * count_bins,
    ... (input_dim - split_dim) * (count_bins - 1),
    ... (input_dim - split_dim) * count_bins]
    >>> hypernet = DenseNN(split_dim, [10*input_dim], param_dims)
    >>> transform = SplineCoupling(input_dim, split_dim, hypernet)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: Dimension of the input vector. Despite operating
        element-wise, this is required so we know how many parameters to store.
    :type input_dim: int
    :param split_dim: Zero-indexed dimension :math:`d` upon which to perform
        input/output split for transformation.
    :type split_dim: int
    :param hypernet: a neural network whose forward call returns a tuple of
        spline parameters (see
        :class:`~pyro.distributions.transforms.ConditionalSpline`).
    :type hypernet: callable
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the
        spline.
    :type order: string
    :param identity: When true, the first ``split_dim`` elements are passed
        through unchanged (:math:`g_\theta` is the identity); otherwise they
        are transformed by an unconditional elementwise spline.
    :type identity: bool

    References:

    Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural
    Spline Flows. NeurIPS 2019.

    Hadi M. Dolatabadi, Sarah Erfani, Christopher Leckie. Invertible Generative
    Modeling using Linear Rational Splines. AISTATS 2020.

    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(
        self,
        input_dim,
        split_dim,
        hypernet,
        count_bins=8,
        bound=3.0,
        order="linear",
        identity=False,
    ):
        super(SplineCoupling, self).__init__(cache_size=1)

        self.split_dim = split_dim
        self.identity = identity

        # Unconditional spline for the first `split_dim` elements. It is built
        # regardless of `identity`; `_call`/`_inverse` simply skip it when
        # `identity` is true.
        self.lower_spline = Spline(split_dim, count_bins, bound, order)

        # Spline for the remaining elements; its parameters are produced by
        # `hypernet` from the first `split_dim` elements.
        self.upper_spline = ConditionalSpline(
            hypernet, input_dim - split_dim, count_bins, bound, order
        )

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample
        from the base distribution (or the output of a previous transform).

        As a side effect, stores the elementwise log-det-Jacobian terms in
        ``self._cache_log_detJ`` for use by :meth:`log_abs_det_jacobian`.
        """
        lower_in = x[..., : self.split_dim]
        upper_in = x[..., self.split_dim :]

        # Transform (or pass through) the leading block first.
        if self.identity:
            lower_out = lower_in
        else:
            lower_out = self.lower_spline(lower_in)
            lower_terms = self.lower_spline._cache_log_detJ

        # The trailing block is transformed by a spline conditioned on the
        # *untransformed* leading block.
        conditioned = self.upper_spline.condition(lower_in)
        upper_out = conditioned(upper_in)
        upper_terms = conditioned._cache_log_detJ

        # The terms are only ever summed over the last dimension (see
        # log_abs_det_jacobian), so the [upper, lower] ordering is harmless.
        if self.identity:
            self._cache_log_detJ = upper_terms
        else:
            self._cache_log_detJ = torch.cat([upper_terms, lower_terms], dim=-1)

        return torch.cat([lower_out, upper_out], dim=-1)

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. This method itself always performs the inversion; any
        reuse of previously cached values presumably happens in the parent
        transform machinery, which is not shown in this module.

        As a side effect, stores the elementwise log-det-Jacobian terms
        reported by the sub-splines in ``self._cache_log_detJ``.
        """
        lower_out = y[..., : self.split_dim]
        upper_out = y[..., self.split_dim :]

        # Recover the leading block first: it is needed to condition the
        # spline that inverts the trailing block.
        if self.identity:
            lower_in = lower_out
        else:
            lower_in = self.lower_spline._inv_call(lower_out)
            lower_terms = self.lower_spline._cache_log_detJ

        conditioned = self.upper_spline.condition(lower_in)
        upper_in = conditioned._inv_call(upper_out)
        upper_terms = conditioned._cache_log_detJ

        if self.identity:
            self._cache_log_detJ = upper_terms
        else:
            self._cache_log_detJ = torch.cat([upper_terms, lower_terms], dim=-1)

        return torch.cat([lower_in, upper_in], dim=-1)

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log of the absolute determinant of the Jacobian by
        summing the cached elementwise terms over the last dimension.

        :param x: the input into the bijection
        :type x: torch.Tensor
        :param y: the output of the bijection
        :type y: torch.Tensor
        """
        cached_x, cached_y = self._cached_x_y
        cache_is_current = x is cached_x and y is cached_y
        if not cache_is_current:
            # This call to the parent class Transform will update the cache
            # as well as calling self._call and recalculating y and log_detJ
            self(x)

        return self._cache_log_detJ.sum(-1)
def spline_coupling(
    input_dim, split_dim=None, hidden_dims=None, count_bins=8, bound=3.0
):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.SplineCoupling` object for
    consistency with other helpers.

    The hypernetwork is a :class:`~pyro.nn.DenseNN` that takes the first
    ``split_dim`` elements as input. The spline order and the ``identity``
    flag are left at the defaults of ``SplineCoupling``.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param split_dim: Dimension at which the input is split. Defaults to
        ``input_dim // 2`` when ``None``.
    :type split_dim: int
    :param hidden_dims: Hidden layer sizes of the dense hypernetwork. Defaults
        to two layers of size ``input_dim * 10`` when ``None``.
    :type hidden_dims: list[int]
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The bounding-box quantity passed through to
        ``SplineCoupling``.
    :type bound: float
    """
    if split_dim is None:
        split_dim = input_dim // 2

    if hidden_dims is None:
        default_width = input_dim * 10
        hidden_dims = [default_width, default_width]

    # Number of elements handled by the conditional (upper) spline.
    upper_dim = input_dim - split_dim
    per_bin = upper_dim * count_bins

    # Four output blocks, sized as in the SplineCoupling usage example.
    hypernet = DenseNN(
        split_dim,
        hidden_dims,
        param_dims=[
            per_bin,
            per_bin,
            upper_dim * (count_bins - 1),
            per_bin,
        ],
    )

    return SplineCoupling(input_dim, split_dim, hypernet, count_bins, bound)