# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import warnings
from math import pi
import torch
from torch import broadcast_shapes
from torch.distributions import Uniform
from pyro.distributions import constraints
from .torch_distribution import TorchDistribution
[docs]class SineSkewed(TorchDistribution):
"""Sine Skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
base distribution.
Torus distributions are distributions with support on products of circles
(i.e., ⨂^d S^1 where S^1=[-pi,pi) ). So, a 0-torus is a point, the 1-torus is a circle,
and the 2-torus is commonly associated with the donut shape.
The Sine Skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
For example with a von Mises distribution over a circle (1-torus), the Sine Skewed von Mises Distribution has one
skew parameter. The skewness parameters can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`.
For example, the following will produce a uniform prior over skewness for the 2-torus,::
def model(obs):
# Sine priors
phi_loc = pyro.sample('phi_loc', VonMises(pi, 2.))
psi_loc = pyro.sample('psi_loc', VonMises(-pi / 2, 2.))
phi_conc = pyro.sample('phi_conc', Beta(halpha_phi, beta_prec_phi - halpha_phi))
psi_conc = pyro.sample('psi_conc', Beta(halpha_psi, beta_prec_psi - halpha_psi))
corr_scale = pyro.sample('corr_scale', Beta(2., 5.))
# SS prior
skew_phi = pyro.sample('skew_phi', Uniform(-1., 1.))
psi_bound = 1 - skew_phi.abs()
skew_psi = pyro.sample('skew_psi', Uniform(-1., 1.))
skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=-1)
assert skewness.shape == (num_mix_comp, 2)
with pyro.plate('obs_plate'):
sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
phi_concentration=1000 * phi_conc,
psi_concentration=1000 * psi_conc,
weighted_correlation=corr_scale)
return pyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)
To ensure the skewing does not alter the normalization constant of the (Sine Bivaraite von Mises) base
distribution the skewness parameters are constraint. The constraint requires the sum of the absolute values of
skewness to be less than or equal to one.
So for the above snippet it must hold that::
skew_phi.abs()+skew_psi.abs() <= 1
We handle this in the prior by computing psi_bound and use it to scale skew_psi.
We do **not** use psi_bound as::
skew_psi = pyro.sample('skew_psi', Uniform(-psi_bound, psi_bound))
as it would make the support for the Uniform distribution dynamic.
In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood, but use as
latent variables it will lead to slow inference for 2 and higher dim toruses. This is because the base_dist
cannot be reparameterized.
.. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).
.. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
must be less than or equal to one. See eq. 2.1 in [1].
** References: **
1. Sine-skewed toroidal distributions and their application in protein bioinformatics
Ameijeiras-Alonso, J., Ley, C. (2019)
:param torch.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
distributions include: 1D :class:`~pyro.distributions.VonMises`,
:class:`~pyro.distributions.SineBivariateVonMises`, 1D :class:`~pyro.distributions.ProjectedNormal`, and
:class:`~pyro.distributions.Uniform` (-pi, pi).
:param torch.tensor skewness: skewness of the distribution.
"""
arg_constraints = {
"skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
}
support = constraints.independent(constraints.real, 1)
def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None):
assert (
base_dist.event_shape == skewness.shape[-1:]
), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."
if (skewness.abs().sum(-1) > 1.0).any():
warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning)
batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
event_shape = skewness.shape[-1:]
self.skewness = skewness.broadcast_to(batch_shape + event_shape)
self.base_dist = base_dist.expand(batch_shape)
super().__init__(batch_shape, event_shape, validate_args=validate_args)
if self._validate_args and base_dist.mean.device != skewness.device:
raise ValueError(
f"base_density: {base_dist.__class__.__name__} and SineSkewed "
f"must be on same device."
)
def __repr__(self):
args_string = ", ".join(
[
"{}: {}".format(
p,
(
getattr(self, p)
if getattr(self, p).numel() == 1
else getattr(self, p).size()
),
)
for p in self.arg_constraints.keys()
]
)
return (
self.__class__.__name__
+ "("
+ f"base_density: {str(self.base_dist)}, "
+ args_string
+ ")"
)
[docs] def sample(self, sample_shape=torch.Size()):
bd = self.base_dist
ys = bd.sample(sample_shape)
u = Uniform(0.0, self.skewness.new_ones(())).sample(
sample_shape + self.batch_shape
)
# Section 2.3 step 3 in [1]
mask = u <= 0.5 + 0.5 * (
self.skewness * torch.sin((ys - bd.mean) % (2 * pi))
).sum(-1)
mask = mask[..., None]
samples = (torch.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
return samples
[docs] def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
# Eq. 2.1 in [1]
skew_prob = torch.log1p(
(self.skewness * torch.sin((value - self.base_dist.mean) % (2 * pi))).sum(
-1
)
)
return self.base_dist.log_prob(value) + skew_prob
[docs] def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
new = self._get_checked_instance(SineSkewed, _instance)
base_dist = self.base_dist.expand(batch_shape)
new.base_dist = base_dist
new.skewness = self.skewness.expand(batch_shape + (-1,))
super(SineSkewed, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new