# Source code for pyro.distributions.von_mises

from __future__ import absolute_import, division, print_function

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions import TorchDistribution

def _eval_poly(y, coef):
    """
    Evaluate a polynomial in ``y`` using Horner's rule.

    :param y: evaluation point(s); anything supporting ``+`` and ``*`` with
        Python floats (e.g. a float or a ``torch.Tensor``).
    :param coef: non-empty sequence of coefficients in increasing order of
        degree, i.e. ``coef[0] + coef[1] * y + coef[2] * y ** 2 + ...``.
    :returns: the value of the polynomial at ``y``.
    """
    coef = list(coef)
    result = coef.pop()
    while coef:
        result = coef.pop() + y * result
    return result


# Abramowitz & Stegun 9.8.1: I0(x) ~= poly(t ** 2) with t = x / 3.75, for |x| <= 3.75.
_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
# Abramowitz & Stegun 9.8.2: x ** 0.5 * exp(-x) * I0(x) ~= poly(1 / t), for x >= 3.75.
_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
               -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]


def _log_modified_bessel_fn_0(x):
    """
    Returns ``log(I0(x))`` for ``x > 0``, where ``I0`` is the modified Bessel
    function of the first kind of order zero.

    The small-argument polynomial approximation is used where ``x < 3.75`` and
    the large-argument approximation is used where ``x >= 3.75``.

    :param torch.Tensor x: tensor of positive values.
    :returns: tensor with the same shape as ``x``.
    """
    mask = x < 3.75

    # Feed each branch only inputs inside its own domain. Otherwise the
    # unselected branch can evaluate to inf/nan (e.g. 3.75 / x for tiny x,
    # or poly((x / 3.75) ** 2) overflowing for huge x), and torch.where's
    # backward pass turns 0 * inf into nan gradients.
    x_small = x.clamp(max=3.75)
    small = _eval_poly((x_small / 3.75).pow(2), _COEF_SMALL).log()

    x_large = x.clamp(min=3.75)
    large = (x_large - 0.5 * x_large.log()
             + _eval_poly(3.75 / x_large, _COEF_LARGE).log())

    # Previously the mask was computed but ignored, so `large` was always
    # returned, which is badly wrong for x < 3.75.
    return torch.where(mask, small, large)

class VonMises(TorchDistribution):
    """
    A circular von Mises distribution.

    Angles are parameterized in polar coordinates. Both ``loc`` and the
    ``value`` passed to :meth:`log_prob` may be any real number (convenient
    for unconstrained optimization); they are interpreted as angles modulo
    2 pi.

    See :class:`~pyro.distributions.VonMises3D` for a 3D cartesian-coordinate
    cousin of this distribution.

    Currently only :meth:`log_prob` is implemented.

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """
    arg_constraints = {
        'loc': constraints.real,
        'concentration': constraints.positive,
    }
    support = constraints.real

    def __init__(self, loc, concentration, validate_args=None):
        loc, concentration = broadcast_all(loc, concentration)
        # Parameters must be set before the base __init__, which may validate them.
        self.loc = loc
        self.concentration = concentration
        super(VonMises, self).__init__(loc.shape, torch.Size(), validate_args)

    def log_prob(self, value):
        """
        Log density ``concentration * cos(value - loc) - log(2 pi I0(concentration))``.
        """
        unnormalized = self.concentration * torch.cos(value - self.loc)
        return unnormalized - math.log(2 * math.pi) - _log_modified_bessel_fn_0(self.concentration)

    def expand(self, batch_shape):
        """
        Return a copy of this distribution broadcast to ``batch_shape``.

        The base-class implementation is tried first. If it raises
        ``NotImplementedError``, a new instance is built from the expanded
        parameters, carrying over the validation setting.
        """
        try:
            return super(VonMises, self).expand(batch_shape)
        except NotImplementedError:
            pass
        validate_args = self.__dict__.get('_validate_args')
        return type(self)(self.loc.expand(batch_shape),
                          self.concentration.expand(batch_shape),
                          validate_args=validate_args)