# Source code for pyro.distributions.transforms.spline

# Copyright Contributors to the Pyro project.
# Copyright (c) 2020 Hadi M. Dolatabadi, Sarah Erfani, Christopher Leckie
# Copyright (c) 2019 Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios
# Copyright (c) 2019 Tony Duan
# SPDX-License-Identifier: MIT

# This implementation is adapted in part from:
# * https://github.com/tonyduan/normalizing-flows/blob/master/nf/flows.py; and,
# * https://github.com/hmdolatabadi/LRS_NF/blob/master/nde/transforms/nonlinearities.py,
# under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import constraints

from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import copy_docs_from


def _search_sorted(bin_locations, inputs, eps=1e-6):
    """
    Searches for which bin an input belongs to (in a way that is parallelizable and amenable to autodiff).

    For every input, the number of knots that are less than or equal to it is counted and one is subtracted.
    The last knot is shifted outward by ``eps`` for this comparison, so that an input lying exactly on the
    last knot is assigned to the final bin rather than to an index one past it.

    :param bin_locations: knot positions along the last dimension. The counting approach presumes they are
        sorted in ascending order. This tensor is left unmodified.
    :param inputs: values to locate; a trailing dimension is added before comparing against the knots
    :param eps: amount by which the last knot is shifted for the comparison
    :returns: integer tensor of bin indices (-1 for inputs smaller than the first knot)
    """
    # Apply the shift to a copy: updating `bin_locations` in place would leak the shift into the caller's
    # tensor and would accumulate if the same tensor were searched more than once
    shifted_locations = bin_locations.clone()
    shifted_locations[..., -1] += eps
    return torch.sum(
        inputs[..., None] >= shifted_locations,
        dim=-1
    ) - 1


def _select_bins(x, idx):
    """
    Gathers, for every index in ``idx``, the matching entry along the last dimension of ``x``, expanding
    ``x`` over any extra leading (batch) dimensions that ``idx`` has. The trailing dimension of the gathered
    result is squeezed away.
    """
    # Indices falling outside the last dimension of x are pinned to its first/last valid position
    last_valid = x.size(-1) - 1
    safe_idx = idx.clamp(min=0, max=last_valid)

    # Prepend the leading dimensions of idx that x lacks; -1 leaves each existing dimension of x as it is
    x_rank = len(x.shape)
    num_leading = len(safe_idx.shape) - x_rank
    target_shape = safe_idx.shape[:num_leading] + (-1,) * x_rank
    broadcast_x = x.expand(target_shape)

    picked = torch.gather(broadcast_x, -1, safe_idx)
    return picked.squeeze(-1)


def _calculate_knots(lengths, lower, upper):
    """
    Turns unscaled bin lengths (expected to sum to 1 along the last dimension) into knot positions on
    [lower, upper], returning the rescaled bin lengths together with those knot positions.
    """
    span = upper - lower

    # A running total of the lengths gives the knots on [0, 1] (x positions, or y positions for the inverse);
    # one zero is padded on the left for the first knot, which the cumulative sum does not produce
    unit_knots = F.pad(torch.cumsum(lengths, dim=-1), pad=(1, 0), mode='constant', value=0.0)

    # Shift and scale the unit interval onto [lower, upper]
    knots = span * unit_knots + lower

    # Pin both end knots exactly to the limits
    # NOTE: the original author's note asks whether this is needed to counter accumulated round-off error
    knots[..., 0] = lower
    knots[..., -1] = upper

    # Differences between consecutive knots are the rescaled bin lengths
    left_knots, right_knots = knots[..., :-1], knots[..., 1:]
    return right_knots - left_knots, knots


def _monotonic_rational_spline(inputs, widths, heights, derivatives, lambdas,
                               inverse=False,
                               bound=3.,
                               min_bin_width=1e-3,
                               min_bin_height=1e-3,
                               min_derivative=1e-3,
                               min_lambda=0.025):
    """
    Calculating a monotonic rational spline (linear for now) or its inverse, plus the log(abs(detJ)) required
    for normalizing flows.

    The spline acts on the box [-bound, bound] x [-bound, bound]; inputs outside [-bound, bound] are passed
    through unchanged with a log(abs(detJ)) of zero.

    NOTE: This method is not considered "public" yet, so the parameter descriptions below are a reviewer's
    reading of the code rather than a stable contract.

    :param inputs: values to transform (commented in the code as ``(batch_dim, input_dim)``)
    :param widths: unscaled bin widths; the last dimension is the number of bins and is expected to sum to 1
    :param heights: unscaled bin heights, same layout as ``widths``
    :param derivatives: derivatives at the interior knots; one value is padded on each side for the two
        boundary knots, so the last dimension is presumably ``num_bins - 1``
    :param lambdas: per-bin position of the point splitting each bin in two, as a fraction of the bin;
        rescaled here into ``[min_lambda, 1 - min_lambda]`` (assumes the incoming values lie in [0, 1])
    :param inverse: if True the inverse map (y => x) is evaluated instead of the forward one
    :param bound: half-width of the bounding box; must be positive
    :param min_bin_width: lower limit given to every bin width (before scaling to the bounding box)
    :param min_bin_height: lower limit given to every bin height (before scaling to the bounding box)
    :param min_derivative: constant added to every derivative
    :param min_lambda: lower limit for every lambda (and ``1 - min_lambda`` the upper limit)
    :returns: tuple ``(outputs, logabsdet)``, both shaped like ``inputs``; ``logabsdet`` is computed for the
        direction that was requested via ``inverse``
    :raises ValueError: if ``min_bin_width`` or ``min_bin_height`` times the number of bins exceeds 1
    :raises AssertionError: if ``bound`` is not positive
    """

    # Ensure bound is positive
    # NOTE: For simplicity, we apply the identity function outside [-B, B] X [-B, B] rather than allowing arbitrary
    # corners to the bounding box. If you want a different bounding box you can apply an affine transform before and
    # after the input
    # NOTE: Deliberately left as an assert so that the exception type seen by existing callers does not change
    assert bound > 0.0

    num_bins = widths.shape[-1]
    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    # inputs, inside_interval_mask, outside_interval_mask ~ (batch_dim, input_dim)
    left, right = -bound, bound
    bottom, top = -bound, bound
    inside_interval_mask = (inputs >= left) & (inputs <= right)
    outside_interval_mask = ~inside_interval_mask

    # For numerical stability, put lower/upper limits on parameters. E.g .give every bin min_bin_width,
    # then add width fraction of remaining length
    # NOTE: Do this here rather than higher up because we want everything to ensure numerical
    # stability within this function
    widths = min_bin_width + (1.0 - min_bin_width * num_bins) * widths
    heights = min_bin_height + (1.0 - min_bin_height * num_bins) * heights
    derivatives = min_derivative + derivatives
    lambdas = (1 - 2 * min_lambda) * lambdas + min_lambda

    # Cumulative widths are x (y for inverse) position of knots
    # Similarly, cumulative heights are y (x for inverse) position of knots
    widths, cumwidths = _calculate_knots(widths, left, right)
    heights, cumheights = _calculate_knots(heights, bottom, top)

    # Pad left and right derivatives with fixed values at first and last knots
    # These are 1 since the function is the identity outside the bounding box and the derivative is continuous
    # NOTE: Not sure why this is 1.0 - min_derivative rather than 1.0. I've copied this from original implementation
    derivatives = F.pad(derivatives, pad=(1, 1), mode='constant', value=1.0 - min_derivative)

    # Get the index of the bin that each input is in
    # bin_idx ~ (batch_dim, input_dim, 1)
    bin_idx = _search_sorted(cumheights if inverse else cumwidths, inputs)[..., None]

    # Select the value for the relevant bin for the variables used in the main calculation
    input_widths = _select_bins(widths, bin_idx)
    input_cumwidths = _select_bins(cumwidths, bin_idx)
    input_cumheights = _select_bins(cumheights, bin_idx)
    input_delta = _select_bins(heights / widths, bin_idx)
    input_derivatives = _select_bins(derivatives, bin_idx)
    input_derivatives_plus_one = _select_bins(derivatives[..., 1:], bin_idx)
    input_heights = _select_bins(heights, bin_idx)
    input_lambdas = _select_bins(lambdas, bin_idx)

    # The weight, w_a, at the left-hand-side of each bin
    # We are free to choose w_a, so set it to 1
    wa = 1.0

    # The weight, w_b, at the right-hand-side of each bin
    # This turns out to be a multiple of the w_a
    # TODO: Should this be done in log space for numerical stability?
    wb = torch.sqrt(input_derivatives / input_derivatives_plus_one) * wa

    # The weight, w_c, at the division point of each bin
    # Recall that each bin is divided into two parts so we have enough d.o.f. to fit spline
    wc = (input_lambdas * wa * input_derivatives + (1 - input_lambdas) * wb * input_derivatives_plus_one) / input_delta

    # Calculate y coords of bins
    ya = input_cumheights
    yb = input_heights + input_cumheights
    yc = ((1.0 - input_lambdas) * wa * ya + input_lambdas * wb * yb) / ((1.0 - input_lambdas) * wa + input_lambdas * wb)

    # The core monotonic rational spline equation
    # In both directions each bin has two segments, split at the division point; the 0/1 float masks below
    # select which segment's formula contributes for each input
    if inverse:
        # Inputs are y values here, so the segment is chosen by comparing against the y coord of the split
        below_split = (inputs <= yc).float()
        above_split = (inputs > yc).float()

        numerator = (input_lambdas * wa * (ya - inputs)) * below_split \
            + ((wc - input_lambdas * wb) * inputs + input_lambdas * wb * yb - wc * yc) * above_split

        denominator = ((wc - wa) * inputs + wa * ya - wc * yc) * below_split\
            + ((wc - wb) * inputs + wb * yb - wc * yc) * above_split

        # theta is the relative position within the bin, which is then mapped back to an x value
        theta = numerator / denominator

        outputs = theta * input_widths + input_cumwidths

        derivative_numerator = (wa * wc * input_lambdas * (yc - ya) * below_split
                                + wb * wc * (1 - input_lambdas) * (yb - yc) * above_split) * input_widths

        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator))

    else:
        # Relative position of each input within its bin
        theta = (inputs - input_cumwidths) / input_widths

        # Inputs are x values here, so the segment is chosen by comparing theta against lambda
        left_of_split = (theta <= input_lambdas).float()
        right_of_split = (theta > input_lambdas).float()

        numerator = (wa * ya * (input_lambdas - theta) + wc * yc * theta) * left_of_split\
            + (wc * yc * (1 - theta) + wb * yb * (theta - input_lambdas)) * right_of_split

        denominator = (wa * (input_lambdas - theta) + wc * theta) * left_of_split\
            + (wc * (1 - theta) + wb * (theta - input_lambdas)) * right_of_split

        outputs = numerator / denominator

        derivative_numerator = (wa * wc * input_lambdas * (yc - ya) * left_of_split +
                                wb * wc * (1 - input_lambdas) * (yb - yc) * right_of_split) \
            / input_widths

        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(torch.abs(denominator))

    # Apply the identity function outside the bounding box
    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    logabsdet[outside_interval_mask] = 0.0
    return outputs, logabsdet


class SplineLayer(nn.Module):
    """
    Helper class to manage learnable spline. One could imagine this as a standard layer in PyTorch...

    The layer stores unconstrained parameters per input dimension and exposes them, through properties, in
    the constrained form expected by the spline function: widths and heights via a softmax over the bins,
    derivatives via a softplus, and lambdas via a sigmoid.

    :param input_dim: number of input dimensions; one set of spline parameters is stored for each
    :param count_bins: number of bins of the spline
    :param bound: half-width B of the box [-B, B] x [-B, B] on which the spline is defined
    :param order: 'linear' is the only implemented order
    :raises ValueError: if ``order`` is 'quadratic' (not yet implemented) or any other unrecognised value
    """

    def __init__(self, input_dim, count_bins=8, bound=3., order='linear'):
        super().__init__()

        # Reject unsupported orders up front, before any parameters are allocated
        if order == "quadratic":
            raise ValueError("Monotonic rational quadratic splines not yet implemented!")
        if order != "linear":
            raise ValueError(
                "Keyword argument 'order' must be one of ['linear', 'quadratic'], but '{}' was found!".format(
                    order))

        self.input_dim = input_dim
        self.order = order

        # K rational quadratic segments or 2K rational linear segments...
        self.count_bins = count_bins

        # ...on [-B, B] x [-B, B]
        self.bound = bound

        # Parameters for each dimension
        # TODO: What should the initialization scheme be?
        self.unnormalized_widths = nn.Parameter(torch.randn(self.input_dim, self.count_bins))
        self.unnormalized_heights = nn.Parameter(torch.randn(self.input_dim, self.count_bins))
        # One derivative per interior knot, hence one fewer than the number of bins
        self.unnormalized_derivatives = nn.Parameter(torch.randn(self.input_dim, self.count_bins - 1))

        # Rational linear splines have additional lambda parameters (uniformly initialized)
        self.unnormalized_lambdas = nn.Parameter(torch.rand(self.input_dim, self.count_bins))

    @property
    def widths(self):
        """Bin widths normalized to sum to 1 over the bins."""
        # widths, unnormalized_widths ~ (input_dim, num_bins)
        return F.softmax(self.unnormalized_widths, dim=-1)

    @property
    def heights(self):
        """Bin heights normalized to sum to 1 over the bins."""
        # heights, unnormalized_heights ~ (input_dim, num_bins)
        return F.softmax(self.unnormalized_heights, dim=-1)

    @property
    def derivatives(self):
        """Positive derivatives at the interior knots."""
        # derivatives, unnormalized_derivatives ~ (input_dim, num_bins-1)
        return F.softplus(self.unnormalized_derivatives)

    @property
    def lambdas(self):
        """Lambda parameters squashed into (0, 1)."""
        # lambdas, unnormalized_lambdas ~ (input_dim, num_bins)
        return torch.sigmoid(self.unnormalized_lambdas)

    def forward(self, x, jacobian=False, **kwargs):
        """
        Applies the spline to ``x``.

        Defined as ``forward`` (rather than overriding ``__call__``) so that calling the layer goes through
        ``nn.Module.__call__`` and its hook machinery; ``layer(x, ...)`` is used exactly as before.

        :param x: values to transform
        :param jacobian: if True, the log(abs(detJ)) term is returned alongside the transformed values
        :param kwargs: forwarded to the spline function (e.g. ``inverse=True``)
        :returns: the transformed values, or a tuple ``(values, log_detJ)`` when ``jacobian`` is True
        """
        y, log_detJ = _monotonic_rational_spline(
            x,
            self.widths,
            self.heights,
            self.derivatives,
            self.lambdas,
            bound=self.bound,
            **kwargs)

        if jacobian:
            return y, log_detJ
        return y


@copy_docs_from(TransformModule)
class Spline(TransformModule):
    """
    An implementation of the element-wise rational spline bijections of linear and quadratic order (Durkan et al.,
    2019; Dolatabadi et al., 2020). Rational splines are functions that are comprised of segments that are the ratio
    of two polynomials. For instance, for the :math:`d`-th dimension and the :math:`k`-th segment on the spline, the
    function will take the form,

        :math:`y_d = \\frac{\\alpha^{(k)}(x_d)}{\\beta^{(k)}(x_d)},`

    where :math:`\\alpha^{(k)}` and :math:`\\beta^{(k)}` are two polynomials of the same order. When that order is 1,
    we say that the spline is linear, and when it is 2, quadratic. The spline is constructed on the specified
    bounding box, :math:`[-K,K]\\times[-K,K]`, with the identity function used elsewhere.

    Rational splines offer an excellent combination of functional flexibility whilst maintaining a numerically stable
    inverse that is of the same computational and space complexities as the forward operation. This element-wise
    transform permits the accurate representation of complex univariate distributions.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> transform = Spline(10, count_bins=4, bound=3.)
    >>> pyro.module("my_transform", transform)  # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP

    :param input_dim: Dimension of the input vector. Despite operating element-wise, this is required so we know how
        many parameters to store.
    :type input_dim: int
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box, :math:`[-K,K]\\times[-K,K]`, of the spline.
    :type bound: float
    :param order: One of ['linear', 'quadratic'] specifying the order of the spline.
    :type order: string

    References:

    Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline Flows. NeurIPS 2019.

    Hadi M. Dolatabadi, Sarah Erfani, Christopher Leckie. Invertible Generative Modeling using Linear Rational Splines.
    AISTATS 2020.

    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    event_dim = 0

    def __init__(self, input_dim, count_bins=8, bound=3., order='linear'):
        super().__init__(cache_size=1)

        self.layer = SplineLayer(input_dim, count_bins=count_bins, bound=bound, order=order)

        # Log-determinant belonging to the most recent forward/inverse evaluation
        self._cache_log_detJ = None

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Evaluates the bijection x => y and stores the accompanying log(abs(detJ)) for later retrieval by
        :meth:`log_abs_det_jacobian`.
        """
        y, log_detJ = self.layer(x, jacobian=True)
        self._cache_log_detJ = log_detJ
        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. Uses a previously cached inverse if available, otherwise performs the inversion afresh.
        """
        x, log_detJ = self.layer(y, jacobian=True, inverse=True)

        # The layer reports the log-determinant of the map it applied (y => x); negate it so that the cached
        # value always refers to the forward direction x => y
        self._cache_log_detJ = -log_detJ
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the elementwise log of the absolute determinant of the Jacobian
        """
        # NOTE(review): `_cached_x_y` is not defined in this file; presumably it is the (x, y) pair kept by
        # the parent Transform's cache - confirm against torch.distributions.transforms.Transform
        x_old, y_old = self._cached_x_y
        if x is not x_old or y is not y_old:
            # This call to the parent class Transform will update the cache
            # as well as calling self._call and recalculating y and log_detJ
            self(x)

        return self._cache_log_detJ
def spline(input_dim, **kwargs):
    """
    A helper function to create a :class:`~pyro.distributions.transforms.Spline` object for consistency with other
    helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param kwargs: Keyword arguments passed through unchanged to the
        :class:`~pyro.distributions.transforms.Spline` constructor

    """

    # TODO: A useful heuristic for choosing number of bins from input
    # dimension like: count_bins=min(5, math.log(input_dim))?
    return Spline(input_dim, **kwargs)