Source code for pyro.distributions.stable

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.torch_distribution import TorchDistribution


def _unsafe_standard_stable(alpha, beta, V, W, coords):
    # Implements a noisily reparametrized version of the sampler
    # Chambers-Mallows-Stuck method as corrected by Weron [1,3] and simplified
    # by Nolan [4]. This will fail if alpha is close to 1.

    # Differentiably transform noise via parameters.
    assert V.shape == W.shape
    inv_alpha = alpha.reciprocal()
    half_pi = math.pi / 2
    eps = torch.finfo(V.dtype).eps
    # make V belong to the open interval (-pi/2, pi/2)
    V = V.clamp(min=2 * eps - half_pi, max=half_pi - 2 * eps)
    ha = half_pi * alpha
    b = beta * ha.tan()
    # +/- `ha` term to keep the precision of alpha * (V + half_pi) when V ~ -half_pi
    v = b.atan() - ha + alpha * (V + half_pi)
    Z = (
        v.sin()
        / ((1 + b * b).rsqrt() * V.cos()).pow(inv_alpha)
        * ((v - V).cos().clamp(min=eps) / W).pow(inv_alpha - 1)
    )
    Z.data[Z.data != Z.data] = 0  # drop occasional NANs

    # Optionally convert to Nolan's parametrization S^0 where samples depend
    # continuously on (alpha,beta), allowing interpolation around the hole at
    # alpha=1.
    if coords == "S0":
        return Z - b
    elif coords == "S":
        return Z
    else:
        raise ValueError("Unknown coords: {}".format(coords))


RADIUS = 0.01


def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords):
    """
    Differentiably transform two random variables::

        aux_uniform ~ Uniform(-pi/2, pi/2)
        aux_exponential ~ Exponential(1)

    to a standard ``Stable(alpha, beta)`` random variable.
    """
    # Determine whether a hole workaround is needed.
    with torch.no_grad():
        hole = 1.0
        near_hole = (alpha - hole).abs() <= RADIUS
    if not torch._C._get_tracing_state() and not near_hole.any():
        return _unsafe_standard_stable(
            alpha, beta, aux_uniform, aux_exponential, coords=coords
        )
    if coords == "S":
        # S coords are discontinuous, so interpolate instead in S0 coords.
        Z = _standard_stable(alpha, beta, aux_uniform, aux_exponential, "S0")
        return torch.where(alpha == 1, Z, Z + beta * (math.pi / 2 * alpha).tan())

    # Avoid the hole at alpha=1 by interpolating between pairs
    # of points at hole-RADIUS and hole+RADIUS.
    aux_uniform_ = aux_uniform.unsqueeze(-1)
    aux_exponential_ = aux_exponential.unsqueeze(-1)
    beta_ = beta.unsqueeze(-1)
    alpha_ = alpha.unsqueeze(-1).expand(alpha.shape + (2,)).contiguous()
    with torch.no_grad():
        lower, upper = alpha_.unbind(-1)
        lower.data[near_hole] = hole - RADIUS
        upper.data[near_hole] = hole + RADIUS
        # We don't need to backprop through weights, since we've pretended
        # alpha_ is reparametrized, even though we've clamped some values.
        #               |a - a'|
        # weight = 1 - ----------
        #              2 * RADIUS
        weights = (alpha_ - alpha.unsqueeze(-1)).abs_().mul_(-1 / (2 * RADIUS)).add_(1)
        weights[~near_hole] = 0.5
    pairs = _unsafe_standard_stable(
        alpha_, beta_, aux_uniform_, aux_exponential_, coords=coords
    )
    return (pairs * weights).sum(-1)


[docs]class Stable(TorchDistribution): r""" Levy :math:`\alpha`-stable distribution. See [1] for a review. This uses Nolan's parametrization [2] of the ``loc`` parameter, which is required for continuity and differentiability. This corresponds to the notation :math:`S^0_\alpha(\beta,\sigma,\mu_0)` of [1], where :math:`\alpha` = stability, :math:`\beta` = skew, :math:`\sigma` = scale, and :math:`\mu_0` = loc. To instead use the S parameterization as in scipy, pass ``coords="S"``, but BEWARE this is discontinuous at ``stability=1`` and has poor geometry for inference. This implements a reparametrized sampler :meth:`rsample` , but does not implement :meth:`log_prob` . Inference can be performed using either likelihood-free algorithms such as :class:`~pyro.infer.energy_distance.EnergyDistance`, or reparameterization via the :func:`~pyro.poutine.handlers.reparam` handler with one of the reparameterizers :class:`~pyro.infer.reparam.stable.LatentStableReparam` , :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , or :class:`~pyro.infer.reparam.stable.StableReparam` e.g.:: with poutine.reparam(config={"x": StableReparam()}): pyro.sample("x", Stable(stability, skew, scale, loc)) or simply wrap in :class:`~pyro.infer.reparam.strategies.MinimalReparam` or :class:`~pyro.infer.reparam.strategies.AutoReparam` , e.g.:: @MinimalReparam() def model(): ... [1] S. Borak, W. Hardle, R. Weron (2005). Stable distributions. https://edoc.hu-berlin.de/bitstream/handle/18452/4526/8.pdf [2] J.P. Nolan (1997). Numerical calculation of stable densities and distribution functions. [3] Rafal Weron (1996). On the Chambers-Mallows-Stuck Method for Simulating Skewed Stable Random Variables. [4] J.P. Nolan (2017). Stable Distributions: Models for Heavy Tailed Data. https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf :param Tensor stability: Levy stability parameter :math:`\alpha\in(0,2]` . :param Tensor skew: Skewness :math:`\beta\in[-1,1]` . :param Tensor scale: Scale :math:`\sigma > 0` . Defaults to 1. :param Tensor loc: Location :math:`\mu_0` when using Nolan's S0 parametrization [2], or :math:`\mu` when using the S parameterization. Defaults to 0. :param str coords: Either "S0" (default) to use Nolan's continuous S0 parametrization, or "S" to use the discontinuous parameterization. """ has_rsample = True arg_constraints = { "stability": constraints.interval(0, 2), # half-open (0, 2] "skew": constraints.interval(-1, 1), # closed [-1, 1] "scale": constraints.positive, "loc": constraints.real, } support = constraints.real def __init__( self, stability, skew, scale=1.0, loc=0.0, coords="S0", validate_args=None ): assert coords in ("S", "S0"), coords self.stability, self.skew, self.scale, self.loc = broadcast_all( stability, skew, scale, loc ) self.coords = coords super().__init__(self.loc.shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Stable, _instance) batch_shape = torch.Size(batch_shape) for name in self.arg_constraints: setattr(new, name, getattr(self, name).expand(batch_shape)) new.coords = self.coords super(Stable, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args return new
[docs] def log_prob(self, value): raise NotImplementedError("Stable.log_prob() is not implemented")
[docs] def rsample(self, sample_shape=torch.Size()): # Draw parameter-free noise. with torch.no_grad(): shape = self._extended_shape(sample_shape) new_empty = self.stability.new_empty aux_uniform = new_empty(shape).uniform_(-math.pi / 2, math.pi / 2) aux_exponential = new_empty(shape).exponential_() # Differentiably transform. x = _standard_stable( self.stability, self.skew, aux_uniform, aux_exponential, coords=self.coords ) return self.loc + self.scale * x
@property def mean(self): result = self.loc if self.coords == "S0": result = ( result - self.scale * self.skew * (math.pi / 2 * self.stability).tan() ) return result.masked_fill(self.stability <= 1, math.nan) @property def variance(self): var = self.scale * self.scale return var.mul(2).masked_fill(self.stability < 2, math.inf)