# Source code for pyro.distributions.stable

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.stable_log_prob import _stable_log_prob
from pyro.distributions.torch_distribution import TorchDistribution


def _unsafe_standard_stable(alpha, beta, V, W, coords):
    """
    Transform noise ``(V, W)`` into standard stable samples, with no special
    handling of ``alpha`` near 1.

    This is a noisily reparametrized Chambers-Mallows-Stuck sampler, as
    corrected by Weron [1,3] and simplified by Nolan [4]. It breaks down when
    ``alpha`` is close to 1; ``_standard_stable`` guards against that case.

    :param Tensor alpha: Stability parameter.
    :param Tensor beta: Skewness parameter.
    :param Tensor V: Uniform(-pi/2, pi/2) noise.
    :param Tensor W: Exponential(1) noise; must have the same shape as ``V``.
    :param str coords: "S0" for Nolan's parametrization, "S" for the
        standard one.
    :returns: Samples in the requested coordinates, with any NaN set to 0.
    :raises ValueError: If ``coords`` is neither "S0" nor "S".
    """
    assert V.shape == W.shape
    quarter_turn = math.pi / 2
    tiny = torch.finfo(V.dtype).eps

    # Pull V strictly inside the open interval (-pi/2, pi/2).
    margin = 2 * tiny
    V = V.clamp(min=margin - quarter_turn, max=quarter_turn - margin)

    scaled_alpha = quarter_turn * alpha
    shift = beta * scaled_alpha.tan()
    # Subtracting and re-adding ``scaled_alpha`` keeps alpha * (V + pi/2)
    # precise when V is close to -pi/2.
    angle = shift.atan() - scaled_alpha + alpha * (V + quarter_turn)

    exponent = alpha.reciprocal()
    numerator = angle.sin()
    denominator = ((1 + shift * shift).rsqrt() * V.cos()).pow(exponent)
    tail = ((angle - V).cos().clamp(min=tiny) / W).pow(exponent - 1)
    Z = numerator / denominator * tail
    # Zero the occasional NaN, writing through .data so autograd does not
    # record the edit.
    Z.data[torch.isnan(Z.data)] = 0

    if coords == "S0":
        # Nolan's S0 shift makes samples depend continuously on
        # (alpha, beta), which allows interpolating across alpha = 1.
        return Z - shift
    if coords == "S":
        return Z
    raise ValueError("Unknown coords: {}".format(coords))


# Half-width of the window around alpha = 1 inside which ``_standard_stable``
# replaces direct sampling with interpolation between alpha = 1 - RADIUS and
# alpha = 1 + RADIUS.
RADIUS = 1e-2


def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords):
    """
    Differentiably transform two random variables::

        aux_uniform ~ Uniform(-pi/2, pi/2)
        aux_exponential ~ Exponential(1)

    to a standard ``Stable(alpha, beta)`` random variable.

    This wraps ``_unsafe_standard_stable``, which fails when ``alpha`` is
    close to 1. Entries with ``|alpha - 1| <= RADIUS`` are instead evaluated
    at ``1 - RADIUS`` and ``1 + RADIUS``. The two results are then linearly
    interpolated according to the distance of ``alpha`` from each endpoint.

    :param Tensor alpha: Stability parameter.
    :param Tensor beta: Skewness parameter.
    :param Tensor aux_uniform: Uniform noise on (-pi/2, pi/2).
    :param Tensor aux_exponential: Exponential(1) noise with the same shape as
        ``aux_uniform``.
    :param str coords: Either "S0" (Nolan's continuous parametrization) or
        "S"; anything else makes the underlying sampler raise ``ValueError``.
    :returns: Standard stable samples.
    """
    # Determine whether a hole workaround is needed.
    with torch.no_grad():
        hole = 1.0
        near_hole = (alpha - hole).abs() <= RADIUS
    # Fast path: no entry is near the hole. While tracing, the fast path is
    # never taken, presumably so the traced graph does not depend on the
    # data-dependent ``near_hole.any()`` result.
    if not torch._C._get_tracing_state() and not near_hole.any():
        return _unsafe_standard_stable(
            alpha, beta, aux_uniform, aux_exponential, coords=coords
        )
    if coords == "S":
        # S coords are discontinuous, so interpolate instead in S0 coords.
        # Then shift back by beta * tan(pi * alpha / 2). Where alpha == 1
        # exactly, the S0 value is returned without that shift, since
        # tan(pi/2) is undefined.
        Z = _standard_stable(alpha, beta, aux_uniform, aux_exponential, "S0")
        return torch.where(alpha == 1, Z, Z + beta * (math.pi / 2 * alpha).tan())

    # Avoid the hole at alpha=1 by interpolating between pairs
    # of points at hole-RADIUS and hole+RADIUS.
    # A new trailing axis of size 2 holds the (lower, upper) evaluation points.
    aux_uniform_ = aux_uniform.unsqueeze(-1)
    aux_exponential_ = aux_exponential.unsqueeze(-1)
    beta_ = beta.unsqueeze(-1)
    # .contiguous() materializes the expanded view as a real copy, so the
    # in-place edits below write to alpha_ without touching alpha.
    alpha_ = alpha.unsqueeze(-1).expand(alpha.shape + (2,)).contiguous()
    with torch.no_grad():
        # lower/upper are views into alpha_; only entries near the hole are
        # moved to the endpoints, the rest keep both copies equal to alpha.
        lower, upper = alpha_.unbind(-1)
        lower.data[near_hole] = hole - RADIUS
        upper.data[near_hole] = hole + RADIUS
        # We don't need to backprop through weights, since we've pretended
        # alpha_ is reparametrized, even though we've clamped some values.
        #               |a - a'|
        # weight = 1 - ----------
        #              2 * RADIUS
        # For |a - 1| <= RADIUS the two weights sum to 1.
        weights = (alpha_ - alpha.unsqueeze(-1)).abs_().mul_(-1 / (2 * RADIUS)).add_(1)
        # Entries away from the hole evaluate the same alpha twice; equal
        # weights of 0.5 reproduce that single value.
        weights[~near_hole] = 0.5
    pairs = _unsafe_standard_stable(
        alpha_, beta_, aux_uniform_, aux_exponential_, coords=coords
    )
    return (pairs * weights).sum(-1)


class Stable(TorchDistribution):
    r"""
    Levy :math:`\alpha`-stable distribution. See [1] for a review.

    This uses Nolan's parametrization [2] of the ``loc`` parameter, which is
    required for continuity and differentiability. This corresponds to the
    notation :math:`S^0_\alpha(\beta,\sigma,\mu_0)` of [1], where
    :math:`\alpha` = stability, :math:`\beta` = skew, :math:`\sigma` = scale,
    and :math:`\mu_0` = loc. To instead use the S parameterization as in scipy,
    pass ``coords="S"``, but BEWARE this is discontinuous at ``stability=1``
    and has poor geometry for inference.

    This implements a reparametrized sampler :meth:`rsample`, and a relatively
    expensive :meth:`log_prob` calculation by numerical integration which makes
    inference slow (compared to other distributions), but with better
    convergence properties especially for :math:`\alpha`-stable distributions
    that are skewed (see the ``skew`` parameter below). Faster inference can be
    performed using either likelihood-free algorithms such as
    :class:`~pyro.infer.energy_distance.EnergyDistance`, or reparameterization
    via the :func:`~pyro.poutine.handlers.reparam` handler with one of the
    reparameterizers :class:`~pyro.infer.reparam.stable.LatentStableReparam`,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam`, or
    :class:`~pyro.infer.reparam.stable.StableReparam`, e.g.::

        with poutine.reparam(config={"x": StableReparam()}):
            pyro.sample("x", Stable(stability, skew, scale, loc))

    or simply wrap in :class:`~pyro.infer.reparam.strategies.MinimalReparam` or
    :class:`~pyro.infer.reparam.strategies.AutoReparam`, e.g.::

        @MinimalReparam()
        def model():
            ...

    [1] S. Borak, W. Hardle, R. Weron (2005).
        Stable distributions.
        https://edoc.hu-berlin.de/bitstream/handle/18452/4526/8.pdf
    [2] J.P. Nolan (1997).
        Numerical calculation of stable densities and distribution functions.
    [3] Rafal Weron (1996).
        On the Chambers-Mallows-Stuck Method for
        Simulating Skewed Stable Random Variables.
    [4] J.P. Nolan (2017).
        Stable Distributions: Models for Heavy Tailed Data.
        https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf

    :param Tensor stability: Levy stability parameter :math:`\alpha\in(0,2]`.
    :param Tensor skew: Skewness :math:`\beta\in[-1,1]`.
    :param Tensor scale: Scale :math:`\sigma > 0`. Defaults to 1.
    :param Tensor loc: Location :math:`\mu_0` when using Nolan's S0
        parametrization [2], or :math:`\mu` when using the S parameterization.
        Defaults to 0.
    :param str coords: Either "S0" (default) to use Nolan's continuous S0
        parametrization, or "S" to use the discontinuous parameterization.
    :raises ValueError: If ``coords`` is neither "S0" nor "S".
    """

    has_rsample = True
    arg_constraints = {
        # NOTE: constraints.interval is closed, so argument validation accepts
        # stability=0 even though the documented domain is (0, 2].
        "stability": constraints.interval(0, 2),  # half-open (0, 2]
        "skew": constraints.interval(-1, 1),  # closed [-1, 1]
        "scale": constraints.positive,
        "loc": constraints.real,
    }
    support = constraints.real

    def __init__(
        self, stability, skew, scale=1.0, loc=0.0, coords="S0", validate_args=None
    ):
        # Raise rather than assert so the check survives ``python -O``. This
        # matches the ValueError that ``_unsafe_standard_stable`` raises for
        # unknown coords.
        if coords not in ("S", "S0"):
            raise ValueError("Unknown coords: {}".format(coords))
        self.stability, self.skew, self.scale, self.loc = broadcast_all(
            stability, skew, scale, loc
        )
        self.coords = coords
        super().__init__(self.loc.shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution with every parameter expanded to
        ``batch_shape``, keeping ``coords`` and the validation setting."""
        new = self._get_checked_instance(Stable, _instance)
        batch_shape = torch.Size(batch_shape)
        for name in self.arg_constraints:
            setattr(new, name, getattr(self, name).expand(batch_shape))
        new.coords = self.coords
        # Skip re-validating the expanded parameters, then copy this
        # instance's validation flag over.
        super(Stable, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        r"""Implemented by numerical integration that is based on the algorithm
        proposed by Chambers, Mallows and Stuck (CMS) for simulating the
        Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
        nonlinear transformation of two independent random variables into
        one stable random variable. The first random variable is uniformly
        distributed while the second is exponentially distributed. The
        numerical integration is performed over the first uniformly
        distributed random variable.
        """
        if self._validate_args:
            self._validate_sample(value)

        # Undo shift and scale
        value = (value - self.loc) / self.scale
        value_dtype = value.dtype

        # Use double precision math
        alpha = self.stability.double()
        beta = self.skew.double()
        value = value.double()

        alpha, beta, value = broadcast_all(alpha, beta, value)

        log_prob = _stable_log_prob(alpha, beta, value, self.coords)

        # Cast back to the caller's dtype. Subtracting log(scale) accounts for
        # the Jacobian of the standardization above.
        return log_prob.to(dtype=value_dtype) - self.scale.log()

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparametrized samples.

        Parameter-free noise (Uniform(-pi/2, pi/2) and Exponential(1)) is drawn
        without gradients. ``_standard_stable`` then transforms it
        differentiably in ``stability`` and ``skew``, and the result is
        shifted and scaled by ``loc`` and ``scale``.
        """
        # Draw parameter-free noise.
        with torch.no_grad():
            shape = self._extended_shape(sample_shape)
            new_empty = self.stability.new_empty
            aux_uniform = new_empty(shape).uniform_(-math.pi / 2, math.pi / 2)
            aux_exponential = new_empty(shape).exponential_()

        # Differentiably transform.
        x = _standard_stable(
            self.stability, self.skew, aux_uniform, aux_exponential, coords=self.coords
        )
        return self.loc + self.scale * x

    @property
    def mean(self):
        """
        Mean of the distribution, or NaN where ``stability <= 1``.

        In S coords this is ``loc``. In S0 coords ``loc`` is shifted by
        ``-scale * skew * tan(pi * stability / 2)``.
        """
        result = self.loc
        if self.coords == "S0":
            result = (
                result - self.scale * self.skew * (math.pi / 2 * self.stability).tan()
            )
        return result.masked_fill(self.stability <= 1, math.nan)

    @property
    def variance(self):
        """``2 * scale**2``, or infinity where ``stability < 2``."""
        var = self.scale * self.scale
        return var.mul(2).masked_fill(self.stability < 2, math.inf)
class StableWithLogProb(Stable):
    r"""
    Same as :class:`Stable` but will not undergo reparameterization by
    :class:`~pyro.infer.reparam.strategies.MinimalReparam` and will fail
    reparametrization by
    :class:`~pyro.infer.reparam.stable.LatentStableReparam`,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam`, or
    :class:`~pyro.infer.reparam.stable.StableReparam`.
    """

    # No overrides: all behavior is inherited from Stable. The distinct type
    # is presumably what the reparam strategies check for; confirm in
    # pyro.infer.reparam.