# Source code for pyro.distributions.von_mises_3d

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math

import torch
from torch.distributions import constraints

from pyro.distributions import TorchDistribution

[docs]class VonMises3D(TorchDistribution):
"""
Spherical von Mises distribution.

This implementation combines the direction parameter and concentration
parameter into a single combined parameter that contains both direction and
magnitude. The value arg is represented in cartesian coordinates: it
must be a normalized 3-vector that lies on the 2-sphere.

See :class:~pyro.distributions.VonMises for a 2D polar coordinate cousin
of this distribution.

Currently only :meth:log_prob is implemented.

:param torch.Tensor concentration: A combined location-and-concentration
vector. The direction of this vector is the location, and its
magnitude is the concentration.
"""
arg_constraints = {'concentration': constraints.real}
support = constraints.real  # TODO implement constraints.sphere or similar

def __init__(self, concentration, validate_args=None):
if concentration.dim() < 1 or concentration.shape[-1] != 3:
raise ValueError('Expected concentration to have rightmost dim 3, actual shape = {}'.format(
concentration.shape))
self.concentration = concentration
batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)

[docs]    def log_prob(self, value):
if self._validate_args:
if value.dim() < 1 or value.shape[-1] != 3:
raise ValueError('Expected value to have rightmost dim 3, actual shape = {}'.format(
value.shape))
if not (torch.abs(value.norm(2, -1) - 1) < 1e-6).all():
raise ValueError('direction vectors are not normalized')
scale = self.concentration.norm(2, -1)
log_normalizer = scale.log() - scale.sinh().log() - math.log(4 * math.pi)
return (self.concentration * value).sum(-1) + log_normalizer

[docs]    def expand(self, batch_shape):
try:
return super().expand(batch_shape)
except NotImplementedError:
validate_args = self.__dict__.get('_validate_args')
concentration = self.concentration.expand(torch.Size(batch_shape) + (3,))
return type(self)(concentration, validate_args=validate_args)