# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all
from pyro.distributions.stable_log_prob import _stable_log_prob
from pyro.distributions.torch_distribution import TorchDistribution
def _unsafe_standard_stable(alpha, beta, V, W, coords):
    """
    Differentiably transform noise ``(V, W)`` into a standard stable sample.

    Implements a reparametrized version of the Chambers-Mallows-Stuck sampler
    as corrected by Weron [1,3] and simplified by Nolan [4]. This will fail if
    ``alpha`` is close to 1; use :func:`_standard_stable` for a version that
    works around that hole.

    :param Tensor alpha: Stability parameter.
    :param Tensor beta: Skew parameter.
    :param Tensor V: Noise; callers in this module draw it from
        ``Uniform(-pi/2, pi/2)``. It is clamped into the open interval here.
    :param Tensor W: Noise with the same shape as ``V``; callers in this
        module draw it from ``Exponential(1)``.
    :param str coords: Either "S0" or "S".
    :returns: ``Z`` when ``coords == "S"``, or
        ``Z - beta * tan(pi / 2 * alpha)`` when ``coords == "S0"``.
    :rtype: Tensor
    :raises ValueError: If ``coords`` is neither "S0" nor "S".
    """
    assert V.shape == W.shape
    # Reject bad coords before doing any tensor work.
    if coords not in ("S0", "S"):
        raise ValueError(f"Unknown coords: {coords}")

    # Differentiably transform noise via parameters.
    inv_alpha = alpha.reciprocal()
    half_pi = math.pi / 2
    eps = torch.finfo(V.dtype).eps
    # Make V belong to the open interval (-pi/2, pi/2).
    V = V.clamp(min=2 * eps - half_pi, max=half_pi - 2 * eps)
    ha = half_pi * alpha
    b = beta * ha.tan()
    # +/- `ha` term to keep the precision of alpha * (V + half_pi) when V ~ -half_pi
    v = b.atan() - ha + alpha * (V + half_pi)
    Z = (
        v.sin()
        / ((1 + b * b).rsqrt() * V.cos()).pow(inv_alpha)
        * ((v - V).cos().clamp(min=eps) / W).pow(inv_alpha - 1)
    )
    # Drop occasional NANs in place, bypassing autograd via ``.data``.
    Z.data[torch.isnan(Z.data)] = 0

    # Optionally convert to Nolan's parametrization S^0 where samples depend
    # continuously on (alpha,beta), allowing interpolation around the hole at
    # alpha=1.
    if coords == "S0":
        return Z - b
    return Z
# Half-width of the band around the hole at alpha=1. _standard_stable treats
# |alpha - 1| <= RADIUS as "near the hole" and, for those entries, interpolates
# between samples evaluated at 1 - RADIUS and 1 + RADIUS.
RADIUS = 0.01
def _standard_stable(alpha, beta, aux_uniform, aux_exponential, coords):
    """
    Differentiably transform two random variables::

        aux_uniform ~ Uniform(-pi/2, pi/2)
        aux_exponential ~ Exponential(1)

    to a standard ``Stable(alpha, beta)`` random variable.

    When no entry of ``alpha`` lies within ``RADIUS`` of 1 and no trace is
    active, this simply defers to :func:`_unsafe_standard_stable`. Otherwise
    each sample is computed as a weighted sum of two samples; for entries near
    the hole these are evaluated at ``1 - RADIUS`` and ``1 + RADIUS``.

    :param Tensor alpha: Stability parameter.
    :param Tensor beta: Skew parameter.
    :param Tensor aux_uniform: Uniform noise as described above.
    :param Tensor aux_exponential: Exponential noise as described above.
    :param str coords: Either "S0" or "S".
    :returns: Standard stable samples in the requested coordinates.
    :rtype: Tensor
    """
    # Determine whether a hole workaround is needed.
    hole = 1.0
    with torch.no_grad():
        near_hole = (alpha - hole).abs() <= RADIUS
    if not torch._C._get_tracing_state() and not near_hole.any():
        return _unsafe_standard_stable(
            alpha, beta, aux_uniform, aux_exponential, coords=coords
        )

    if coords == "S":
        # S coords are discontinuous, so interpolate instead in S0 coords.
        Z = _standard_stable(alpha, beta, aux_uniform, aux_exponential, "S0")
        return torch.where(alpha == 1, Z, Z + beta * (math.pi / 2 * alpha).tan())

    # Avoid the hole at alpha=1 by interpolating between pairs
    # of points at hole-RADIUS and hole+RADIUS.
    aux_uniform_ = aux_uniform.unsqueeze(-1)
    aux_exponential_ = aux_exponential.unsqueeze(-1)
    beta_ = beta.unsqueeze(-1)
    # Duplicate alpha along a new trailing axis of size 2: (lower, upper).
    alpha_ = alpha.unsqueeze(-1).expand(alpha.shape + (2,)).contiguous()
    with torch.no_grad():
        lower, upper = alpha_.unbind(-1)
        lower.data[near_hole] = hole - RADIUS
        upper.data[near_hole] = hole + RADIUS
        # We don't need to backprop through weights, since we've pretended
        # alpha_ is reparametrized, even though we've clamped some values.
        #
        #               |a - a'|
        # weight = 1 - ----------
        #              2 * RADIUS
        weights = (alpha_ - alpha.unsqueeze(-1)).abs_().mul_(-1 / (2 * RADIUS)).add_(1)
        # Away from the hole both pair members equal alpha; average them.
        weights[~near_hole] = 0.5
    pairs = _unsafe_standard_stable(
        alpha_, beta_, aux_uniform_, aux_exponential_, coords=coords
    )
    return (pairs * weights).sum(-1)
class Stable(TorchDistribution):
    r"""
    Levy :math:`\alpha`-stable distribution. See [1] for a review.

    This uses Nolan's parametrization [2] of the ``loc`` parameter, which is
    required for continuity and differentiability. This corresponds to the
    notation :math:`S^0_\alpha(\beta,\sigma,\mu_0)` of [1], where
    :math:`\alpha` = stability, :math:`\beta` = skew, :math:`\sigma` = scale,
    and :math:`\mu_0` = loc. To instead use the S parameterization as in scipy,
    pass ``coords="S"``, but BEWARE this is discontinuous at ``stability=1``
    and has poor geometry for inference.

    This implements a reparametrized sampler :meth:`rsample` , and a relatively
    expensive :meth:`log_prob` calculation by numerical integration which makes
    inference slow (compared to other distributions) , but with better
    convergence properties especially for :math:`\alpha`-stable distributions
    that are skewed (see the ``skew`` parameter below). Faster
    inference can be performed using either likelihood-free algorithms such as
    :class:`~pyro.infer.energy_distance.EnergyDistance`, or reparameterization
    via the :func:`~pyro.poutine.handlers.reparam` handler with one of the
    reparameterizers :class:`~pyro.infer.reparam.stable.LatentStableReparam` ,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , or
    :class:`~pyro.infer.reparam.stable.StableReparam` e.g.::

        with poutine.reparam(config={"x": StableReparam()}):
            pyro.sample("x", Stable(stability, skew, scale, loc))

    or simply wrap in :class:`~pyro.infer.reparam.strategies.MinimalReparam` or
    :class:`~pyro.infer.reparam.strategies.AutoReparam` , e.g.::

        @MinimalReparam()
        def model():
            ...

    [1] S. Borak, W. Hardle, R. Weron (2005).
        Stable distributions.
        https://edoc.hu-berlin.de/bitstream/handle/18452/4526/8.pdf
    [2] J.P. Nolan (1997).
        Numerical calculation of stable densities and distribution functions.
    [3] Rafal Weron (1996).
        On the Chambers-Mallows-Stuck Method for
        Simulating Skewed Stable Random Variables.
    [4] J.P. Nolan (2017).
        Stable Distributions: Models for Heavy Tailed Data.
        https://edspace.american.edu/jpnolan/wp-content/uploads/sites/1720/2020/09/Chap1.pdf

    :param Tensor stability: Levy stability parameter :math:`\alpha\in(0,2]` .
    :param Tensor skew: Skewness :math:`\beta\in[-1,1]` .
    :param Tensor scale: Scale :math:`\sigma > 0` . Defaults to 1.
    :param Tensor loc: Location :math:`\mu_0` when using Nolan's S0
        parametrization [2], or :math:`\mu` when using the S parameterization.
        Defaults to 0.
    :param str coords: Either "S0" (default) to use Nolan's continuous S0
        parametrization, or "S" to use the discontinuous parameterization.
    """

    has_rsample = True
    arg_constraints = {
        "stability": constraints.interval(0, 2),  # half-open (0, 2]
        "skew": constraints.interval(-1, 1),  # closed [-1, 1]
        "scale": constraints.positive,
        "loc": constraints.real,
    }
    support = constraints.real

    def __init__(
        self, stability, skew, scale=1.0, loc=0.0, coords="S0", validate_args=None
    ):
        assert coords in ("S", "S0"), coords
        # Broadcast all four parameters to one common batch shape.
        self.stability, self.skew, self.scale, self.loc = broadcast_all(
            stability, skew, scale, loc
        )
        self.coords = coords
        super().__init__(self.loc.shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution with parameters expanded to ``batch_shape``.

        Every tensor named in ``arg_constraints`` is expanded; ``coords`` and
        the validation flag are copied from this instance.
        """
        new = self._get_checked_instance(Stable, _instance)
        batch_shape = torch.Size(batch_shape)
        for name in self.arg_constraints:
            setattr(new, name, getattr(self, name).expand(batch_shape))
        new.coords = self.coords
        # Skip Stable.__init__ (which would re-broadcast) and call the parent's.
        super(Stable, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        r"""Implemented by numerical integration that is based on the algorithm
        proposed by Chambers, Mallows and Stuck (CMS) for simulating the
        Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
        nonlinear transformation of two independent random variables into
        one stable random variable. The first random variable is uniformly
        distributed while the second is exponentially distributed. The numerical
        integration is performed over the first uniformly distributed random
        variable.
        """
        if self._validate_args:
            self._validate_sample(value)

        # Undo shift and scale
        value = (value - self.loc) / self.scale
        # Remember the working dtype so the result can be cast back to it.
        value_dtype = value.dtype

        # Use double precision math
        alpha = self.stability.double()
        beta = self.skew.double()
        value = value.double()

        alpha, beta, value = broadcast_all(alpha, beta, value)

        log_prob = _stable_log_prob(alpha, beta, value, self.coords)

        # Subtract log(scale) to account for the rescaling of ``value`` above.
        return log_prob.to(dtype=value_dtype) - self.scale.log()

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparametrized samples of shape ``sample_shape + batch_shape``.

        Parameter-free uniform and exponential noise is drawn without gradient
        tracking, then transformed differentiably by ``_standard_stable`` and
        finally shifted by ``loc`` and scaled by ``scale``.
        """
        # Draw parameter-free noise.
        with torch.no_grad():
            shape = self._extended_shape(sample_shape)
            new_empty = self.stability.new_empty
            aux_uniform = new_empty(shape).uniform_(-math.pi / 2, math.pi / 2)
            aux_exponential = new_empty(shape).exponential_()

        # Differentiably transform.
        x = _standard_stable(
            self.stability, self.skew, aux_uniform, aux_exponential, coords=self.coords
        )
        return self.loc + self.scale * x

    @property
    def mean(self):
        """
        Mean of the distribution; NaN wherever ``stability <= 1``.

        In "S0" coords, ``loc`` is shifted by
        ``-scale * skew * tan(pi / 2 * stability)``; in "S" coords it is
        returned as is.
        """
        result = self.loc
        if self.coords == "S0":
            result = (
                result - self.scale * self.skew * (math.pi / 2 * self.stability).tan()
            )
        return result.masked_fill(self.stability <= 1, math.nan)

    @property
    def variance(self):
        """
        Variance of the distribution: ``2 * scale ** 2``, replaced by infinity
        wherever ``stability < 2``.
        """
        var = self.scale * self.scale
        return var.mul(2).masked_fill(self.stability < 2, math.inf)
class StableWithLogProb(Stable):
    r"""
    Same as :class:`Stable` but will not undergo reparameterization by
    :class:`~pyro.infer.reparam.strategies.MinimalReparam` and will fail
    reparametrization by
    :class:`~pyro.infer.reparam.stable.LatentStableReparam` ,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , or
    :class:`~pyro.infer.reparam.stable.StableReparam`.
    """