# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import torch
from . import constraints
from .torch_distribution import TorchDistribution
[docs]class VonMises3D(TorchDistribution):
"""
Spherical von Mises distribution.
This implementation combines the direction parameter and concentration
parameter into a single combined parameter that contains both direction and
magnitude. The ``value`` arg is represented in cartesian coordinates: it
must be a normalized 3-vector that lies on the 2-sphere.
See :class:`~pyro.distributions.VonMises` for a 2D polar coordinate cousin
of this distribution. See :class:`~pyro.distributions.projected_normal` for
a qualitatively similar distribution but implementing more functionality.
Currently only :meth:`log_prob` is implemented.
:param torch.Tensor concentration: A combined location-and-concentration
vector. The direction of this vector is the location, and its
magnitude is the concentration.
"""
arg_constraints = {"concentration": constraints.real}
support = constraints.sphere
def __init__(self, concentration, validate_args=None):
if concentration.dim() < 1 or concentration.shape[-1] != 3:
raise ValueError(
"Expected concentration to have rightmost dim 3, actual shape = {}".format(
concentration.shape
)
)
self.concentration = concentration
batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
[docs] def log_prob(self, value):
if self._validate_args:
if value.dim() < 1 or value.shape[-1] != 3:
raise ValueError(
"Expected value to have rightmost dim 3, actual shape = {}".format(
value.shape
)
)
if not (torch.abs(value.norm(2, -1) - 1) < 1e-6).all():
raise ValueError("direction vectors are not normalized")
scale = self.concentration.norm(2, -1)
log_normalizer = scale.log() - scale.sinh().log() - math.log(4 * math.pi)
return (self.concentration * value).sum(-1) + log_normalizer
[docs] def expand(self, batch_shape):
try:
return super().expand(batch_shape)
except NotImplementedError:
validate_args = self.__dict__.get("_validate_args")
concentration = self.concentration.expand(torch.Size(batch_shape) + (3,))
return type(self)(concentration, validate_args=validate_args)