Source code for pyro.distributions.asymmetriclaplace

# Copyright Contributors to the Pyro project.

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property

from .torch_distribution import TorchDistribution

class AsymmetricLaplace(TorchDistribution):
    """
    Asymmetric version of the :class:`~pyro.distributions.Laplace`
    distribution.

    To the left of ``loc`` this acts like an
    ``-Exponential(1/(asymmetry*scale))``; to the right of ``loc`` this acts
    like an ``Exponential(asymmetry/scale)``. The density is continuous so the
    left and right densities at ``loc`` agree.

    :param loc: Location parameter, i.e. the mode.
    :param scale: Scale parameter = geometric mean of left and right scales.
    :param asymmetry: Square of ratio of left to right scales.
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "asymmetry": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, asymmetry, *, validate_args=None):
        self.loc, self.scale, self.asymmetry = broadcast_all(loc, scale, asymmetry)
        super().__init__(self.loc.shape, validate_args=validate_args)

    @lazy_property
    def left_scale(self):
        """Scale of the exponential tail to the left of ``loc``."""
        return self.scale * self.asymmetry

    @lazy_property
    def right_scale(self):
        """Scale of the exponential tail to the right of ``loc``."""
        return self.scale / self.asymmetry

    def expand(self, batch_shape, _instance=None):
        """
        Return a new instance whose parameters are expanded to ``batch_shape``.

        The lazily computed ``left_scale``/``right_scale`` are not copied; the
        new instance recomputes them from its expanded parameters on demand.
        """
        new = self._get_checked_instance(AsymmetricLaplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new.asymmetry = self.asymmetry.expand(batch_shape)
        super(AsymmetricLaplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        """
        Log density at ``value``.

        :param value: Points at which to evaluate the density.
        :returns: ``-|value - loc| / side_scale - log(left_scale + right_scale)``
            where ``side_scale`` is ``left_scale`` below ``loc`` and
            ``right_scale`` otherwise.
        """
        if self._validate_args:
            self._validate_sample(value)
        z = value - self.loc
        # Each side decays with its own scale; both halves share the single
        # normalizer 1 / (left_scale + right_scale), making the density
        # continuous at loc.
        z = -z.abs() / torch.where(z < 0, self.left_scale, self.right_scale)
        return z - (self.left_scale + self.right_scale).log()

    def rsample(self, sample_shape=torch.Size()):
        """
        Reparameterized sample of shape ``sample_shape + batch_shape``.

        Draws two independent standard exponentials ``u, v`` and returns
        ``loc - left_scale * u + right_scale * v``.
        """
        shape = self._extended_shape(sample_shape)
        u, v = self.loc.new_empty((2,) + shape).exponential_()
        return self.loc - self.left_scale * u + self.right_scale * v

    @property
    def mean(self):
        # E[loc - L*u + R*v] with u, v ~ Exp(1) is loc + R - L. This equals
        # loc + (R**2 - L**2) / (L + R) exactly, without the cancellation.
        return self.loc + (self.right_scale - self.left_scale)

    @property
    def variance(self):
        # Var[-L*u + R*v] = L**2 * Var[u] + R**2 * Var[v] = L**2 + R**2,
        # identical to the two-component mixture formula
        # p*L**2 + q*R**2 + p*q*(L+R)**2 with p = L/(L+R), q = R/(L+R).
        return self.left_scale ** 2 + self.right_scale ** 2

class SoftAsymmetricLaplace(TorchDistribution):
    """
    Soft asymmetric version of the :class:`~pyro.distributions.Laplace`
    distribution.

    This has a smooth (infinitely differentiable) density with two asymmetric
    asymptotically exponential tails, one on the left and one on the right. In
    the limit of softness → 0, this converges in distribution to the
    :class:`AsymmetricLaplace` distribution.

    This is equivalent to the sum of three random variables ``z - u + v`` where::

        z ~ Normal(loc, scale * softness)
        u ~ Exponential(1 / (scale * asymmetry))
        v ~ Exponential(asymmetry / scale)

    This is also equivalent to the sum of two random variables ``z + a`` where::

        z ~ Normal(loc, scale * softness)
        a ~ AsymmetricLaplace(0, scale, asymmetry)

    :param loc: Location parameter, i.e. the mode.
    :param scale: Scale parameter = geometric mean of left and right scales.
    :param asymmetry: Square of ratio of left to right scales. Defaults to 1.
    :param softness: Scale parameter of the Gaussian smoother. Defaults to 1.
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "asymmetry": constraints.positive,
        "softness": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, asymmetry=1.0, softness=1.0, *, validate_args=None):
        self.loc, self.scale, self.asymmetry, self.softness = broadcast_all(
            loc,
            scale,
            asymmetry,
            softness,
        )
        super().__init__(self.loc.shape, validate_args=validate_args)

    @lazy_property
    def left_scale(self):
        """Scale of the left exponential component."""
        return self.scale * self.asymmetry

    @lazy_property
    def right_scale(self):
        """Scale of the right exponential component."""
        return self.scale / self.asymmetry

    @lazy_property
    def soft_scale(self):
        """Standard deviation of the Gaussian smoothing component."""
        return self.scale * self.softness

    def expand(self, batch_shape, _instance=None):
        """
        Return a new instance whose parameters are expanded to ``batch_shape``.

        Lazily computed scales are not copied; the new instance recomputes them
        from its expanded parameters on demand.
        """
        # Must reference this class, not AsymmetricLaplace. Otherwise
        # _get_checked_instance rejects the mismatched __init__, and
        # super(AsymmetricLaplace, new) fails because new is not an
        # AsymmetricLaplace.
        new = self._get_checked_instance(SoftAsymmetricLaplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        new.asymmetry = self.asymmetry.expand(batch_shape)
        new.softness = self.softness.expand(batch_shape)
        super(SoftAsymmetricLaplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        """
        Log density at ``value``: the log of the Gaussian-smoothed asymmetric
        Laplace density, computed in log space via ``_logerfc`` and
        ``torch.logaddexp``.
        """
        if self._validate_args:
            self._validate_sample(value)

        # Standardize.
        x = (value - self.loc) / self.scale
        L = self.asymmetry
        R = self.asymmetry.reciprocal()
        S = self.softness
        SS = S * S
        S2 = S * math.sqrt(2)
        Lx = L * x
        Rx = R * x

        # This is the sum of two integrals which are proportional to:
        # left = Integrate[e^(-t/L - ((x+t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}]
        #      = 1/2 e^((2 L x + S^2)/(2 L^2)) erfc((L x + S^2)/(sqrt(2) L S))
        # right = Integrate[e^(-t/R - ((x-t)/S)^2/2)/sqrt(2 pi)/S, {t,0,Infinity}]
        #       = 1/2 e^((S^2 - 2 R x)/(2 R^2)) erfc((S^2 - R x)/(sqrt(2) R S))
        # The two terms are summed in log space (logaddexp) for stability.
        # The result is normalized by 1/(L+R), and 1/scale is the Jacobian
        # of the standardization above.
        return (
            math.log(0.5)
            + torch.logaddexp(
                (SS / 2 + Lx) / L ** 2 + _logerfc((SS + Lx) / (L * S2)),
                (SS / 2 - Rx) / R ** 2 + _logerfc((SS - Rx) / (R * S2)),
            )
            - (L + R).log()
            - self.scale.log()
        )

    def rsample(self, sample_shape=torch.Size()):
        """
        Reparameterized sample ``loc + soft_scale*z - left_scale*u + right_scale*v``
        with ``z`` standard normal and ``u, v`` standard exponential.
        """
        shape = self._extended_shape(sample_shape)
        z = self.loc.new_empty(shape).normal_()
        u, v = self.loc.new_empty((2,) + shape).exponential_()
        return (
            self.loc + self.soft_scale * z - self.left_scale * u + self.right_scale * v
        )

    @property
    def mean(self):
        # The Gaussian component has zero mean offset, so this matches
        # AsymmetricLaplace: loc + R - L (== loc + (R**2 - L**2) / (L + R)).
        return self.loc + (self.right_scale - self.left_scale)

    @property
    def variance(self):
        # Independent components add variances: L**2 + R**2 from the
        # exponentials plus soft_scale**2 from the Gaussian smoother.
        return self.left_scale ** 2 + self.right_scale ** 2 + self.soft_scale ** 2

def _logerfc(x):
try:
# Requires https://github.com/pytorch/pytorch/issues/31945