Source code for pyro.distributions.projected_normal

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch

from pyro.ops.tensor_utils import safe_normalize

from . import constraints
from .torch_distribution import TorchDistribution


class ProjectedNormal(TorchDistribution):
    """
    Projected isotropic normal distribution of arbitrary dimension.

    This distribution over directional data is qualitatively similar to the
    von Mises and von Mises-Fisher distributions, but permits tractable
    variational inference via reparametrized gradients.

    To use this distribution with autoguides, use ``poutine.reparam`` with a
    :class:`~pyro.infer.reparam.projected_normal.ProjectedNormalReparam`
    reparametrizer in the model, e.g.::

        @poutine.reparam(config={"direction": ProjectedNormalReparam()})
        def model():
            direction = pyro.sample("direction", ProjectedNormal(torch.zeros(3)))
            ...

    or simply wrap in :class:`~pyro.infer.reparam.strategies.MinimalReparam` or
    :class:`~pyro.infer.reparam.strategies.AutoReparam`, e.g.::

        @MinimalReparam()
        def model():
            ...

    .. note:: This implements :meth:`log_prob` only for dimensions {2,3,4}.

    [1] D. Hernandez-Stumpfhauser, F.J. Breidt, M.J. van der Woerd (2017)
        "The General Projected Normal Distribution of Arbitrary Dimension:
        Modeling and Bayesian Inference"
        https://projecteuclid.org/euclid.ba/1453211962

    :param torch.Tensor concentration: A combined location-and-concentration
        vector. The direction of this vector is the location, and its
        magnitude is the concentration.
    """

    arg_constraints = {"concentration": constraints.real_vector}
    support = constraints.sphere
    has_rsample = True
    _log_prob_impls = {}  # maps dim -> function(concentration, value)

    def __init__(self, concentration, *, validate_args=None):
        assert concentration.dim() >= 1
        self.concentration = concentration
        batch_shape = concentration.shape[:-1]
        event_shape = concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
    @staticmethod
    def infer_shapes(concentration):
        batch_shape = concentration[:-1]
        event_shape = concentration[-1:]
        return batch_shape, event_shape
    def expand(self, batch_shape, _instance=None):
        batch_shape = torch.Size(batch_shape)
        new = self._get_checked_instance(ProjectedNormal, _instance)
        new.concentration = self.concentration.expand(batch_shape + (-1,))
        super(ProjectedNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        return new
    @property
    def mean(self):
        """
        Note this is the mean in the sense of a centroid in the submanifold
        that minimizes expected squared geodesic distance.
        """
        return safe_normalize(self.concentration)

    @property
    def mode(self):
        return safe_normalize(self.concentration)
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        x = self.concentration.new_empty(shape).normal_()
        x = x + self.concentration
        x = safe_normalize(x)
        return x
    def log_prob(self, value):
        if self._validate_args:
            event_shape = value.shape[-1:]
            if event_shape != self.event_shape:
                raise ValueError(
                    f"Expected event shape {self.event_shape}, "
                    f"but got {event_shape}"
                )
            self._validate_sample(value)
        dim = int(self.concentration.size(-1))
        try:
            impl = self._log_prob_impls[dim]
        except KeyError:
            msg = f"ProjectedNormal.log_prob() is not implemented for dim = {dim}."
            if value.requires_grad:  # For latent variables but not observations.
                msg += " Consider using poutine.reparam with ProjectedNormalReparam."
            raise NotImplementedError(msg)
        return impl(self.concentration, value)
    @classmethod
    def _register_log_prob(cls, dim, fn=None):
        if fn is None:
            return lambda fn: cls._register_log_prob(dim, fn)
        cls._log_prob_impls[dim] = fn
        return fn
def _dot(x, y):
    return (x[..., None, :] @ y[..., None])[..., 0, 0]


def _safe_log(x):
    return x.clamp(min=torch.finfo(x.dtype).eps).log()


@ProjectedNormal._register_log_prob(dim=2)
def _log_prob_2(concentration, value):
    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a univariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t.square()
    r2 = _dot(concentration, concentration) - t2
    perp_part = r2.mul(-0.5) - 0.5 * math.log(2 * math.pi)
    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = (t + Sqrt[2/Pi]/E^(t^2/2) + t Erf[t/Sqrt[2]])/2
    # = (Sqrt[2/Pi]/E^(t^2/2) + t (1 + Erf[t/Sqrt[2]]))/2
    # = (Sqrt[2/Pi]/E^(t^2/2) + t Erfc[-t/Sqrt[2]])/2
    para_part = _safe_log(
        (t2.mul(-0.5).exp().mul((2 / math.pi) ** 0.5) + t * (t * -(0.5**0.5)).erfc())
        / 2
    )
    return para_part + perp_part


@ProjectedNormal._register_log_prob(dim=3)
def _log_prob_3(concentration, value):
    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a bivariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t.square()
    r2 = _dot(concentration, concentration) - t2
    perp_part = r2.mul(-0.5) - math.log(2 * math.pi)
    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x^2/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) (1 + Erf[t/Sqrt[2]]))/2
    # = t/(E^(t^2/2) Sqrt[2 Pi]) + ((1 + t^2) Erfc[-t/Sqrt[2]])/2
    para_part = _safe_log(
        t * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
        + (1 + t2) * (t * -(0.5**0.5)).erfc() / 2
    )
    return para_part + perp_part


@ProjectedNormal._register_log_prob(dim=4)
def _log_prob_4(concentration, value):
    # We integrate along a ray, factorizing the integrand as a product of:
    # a truncated normal distribution over coordinate t parallel to the ray, and
    # a trivariate normal distribution over coordinate r perpendicular to the ray.
    t = _dot(concentration, value)
    t2 = t.square()
    r2 = _dot(concentration, concentration) - t2
    perp_part = r2.mul(-0.5) - 1.5 * math.log(2 * math.pi)
    # This is the log of a definite integral, computed by mathematica:
    # Integrate[x^3/(E^((x-t)^2/2) Sqrt[2 Pi]), {x, 0, Infinity}]
    # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) (1 + Erf[t/Sqrt[2]]))/2
    # = (2 + t^2)/(E^(t^2/2) Sqrt[2 Pi]) + (t (3 + t^2) Erfc[-t/Sqrt[2]])/2
    para_part = _safe_log(
        (2 + t2) * t2.mul(-0.5).exp() / (2 * math.pi) ** 0.5
        + t * (3 + t2) * (t * -(0.5**0.5)).erfc() / 2
    )
    return para_part + perp_part
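

# Editorial usage sketch (not part of the Pyro module above): a minimal example,
# under the assumption that only the public API defined in this listing plus the
# imports at the top are available. It draws reparametrized samples on the
# 2-sphere, evaluates log_prob via the dim=3 implementation registered above,
# and does a rough Monte Carlo check that the density integrates to ~1 over the
# sphere's surface measure. The concentration values are arbitrary.
if __name__ == "__main__":
    concentration = torch.tensor([1.0, 2.0, 3.0])  # direction * magnitude
    d = ProjectedNormal(concentration)
    x = d.rsample(torch.Size([5]))  # 5 unit vectors on the 2-sphere in R^3
    assert x.shape == (5, 3)
    assert torch.allclose(x.norm(dim=-1), torch.ones(5), atol=1e-5)
    logp = d.log_prob(x)  # dispatches to _log_prob_3
    assert logp.shape == (5,)
    # Uniform samples on S^2 via normalized Gaussians; the surface area of S^2
    # is 4*pi, so the estimate below should be close to 1.
    u = safe_normalize(torch.randn(100000, 3))
    mass = d.log_prob(u).exp().mean() * 4 * math.pi
    print("approximate total probability mass on S^2:", mass.item())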