# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from pyro.distributions.torch import Chi2
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape
class MultivariateStudentT(TorchDistribution):
    """
    Creates a multivariate Student's t-distribution parameterized by degree of
    freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale_tril`.

    :param ~torch.Tensor df: degrees of freedom; a non-tensor value is
        converted with ``loc.new_tensor(df)``
    :param ~torch.Tensor loc: mean of the distribution
    :param ~torch.Tensor scale_tril: scale of the distribution, which is
        a lower triangular matrix with positive diagonal entries
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real_vector,
        "scale_tril": constraints.lower_cholesky,
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, df, loc, scale_tril, validate_args=None):
        dim = loc.size(-1)
        assert scale_tril.shape[-2:] == (
            dim,
            dim,
        ), "scale_tril must have trailing shape ({0}, {0}), got {1}".format(
            dim, tuple(scale_tril.shape[-2:])
        )
        if not isinstance(df, torch.Tensor):
            # Inherit dtype and device from ``loc`` for plain Python numbers.
            df = loc.new_tensor(df)
        batch_shape = broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2])
        event_shape = torch.Size((dim,))
        self.df = df.expand(batch_shape)
        self.loc = loc.expand(batch_shape + event_shape)
        # Keep the scale as given; the broadcast view is built lazily in
        # ``scale_tril`` only when it is first requested.
        self._unbroadcasted_scale_tril = scale_tril
        self._chi2 = Chi2(self.df)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @lazy_property
    def scale_tril(self):
        """
        The scale factor expanded to ``batch_shape + event_shape + event_shape``.
        """
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """
        The scale matrix ``scale_tril @ scale_tril.T`` expanded to
        ``batch_shape + event_shape + event_shape``.
        """
        # NB: this is not covariance of this distribution;
        # the actual covariance is df / (df - 2) * covariance_matrix
        return torch.matmul(
            self._unbroadcasted_scale_tril,
            self._unbroadcasted_scale_tril.transpose(-1, -2),
        ).expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """
        The inverse of :attr:`covariance_matrix`, obtained by a Cholesky solve
        against the identity and expanded to
        ``batch_shape + event_shape + event_shape``.
        """
        identity = torch.eye(
            self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @staticmethod
    def infer_shapes(df, loc, scale_tril):
        """
        Infers ``(batch_shape, event_shape)`` from the *shapes* of the
        arguments, without constructing any tensors.

        :param df: shape of ``df``
        :param loc: shape of ``loc``
        :param scale_tril: shape of ``scale_tril``
        :return: a pair ``(batch_shape, event_shape)``
        """
        event_shape = loc[-1:]
        batch_shape = broadcast_shape(df, loc[:-1], scale_tril[:-2])
        return batch_shape, event_shape

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution whose parameters are expanded to
        ``batch_shape``.

        :param batch_shape: the desired batch shape
        :param _instance: optional pre-allocated instance, forwarded to
            ``_get_checked_instance``
        :return: the expanded distribution
        """
        new = self._get_checked_instance(MultivariateStudentT, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        scale_shape = loc_shape + self.event_shape
        new.df = self.df.expand(batch_shape)
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        # Only carry over lazy properties that were already materialized on
        # this instance; anything else is recomputed on demand by ``new``.
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(scale_shape)
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(scale_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(scale_shape)
        new._chi2 = self._chi2.expand(batch_shape)
        super(MultivariateStudentT, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """
        Draws a reparameterized sample of shape
        ``sample_shape + batch_shape + event_shape``.

        Standard normal noise is scaled by ``rsqrt(chi2 / df)``, transformed
        by :attr:`scale_tril` and shifted by :attr:`loc`.
        """
        shape = self._extended_shape(sample_shape)
        X = torch.empty(shape, dtype=self.df.dtype, device=self.df.device).normal_()
        Z = self._chi2.rsample(sample_shape)
        Y = X * torch.rsqrt(Z / self.df).unsqueeze(-1)
        return self.loc + self.scale_tril.matmul(Y.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _solve_lower_triangular(tril, rhs):
        """Solves ``tril @ X = rhs`` for ``X`` with lower triangular ``tril``."""
        # ``Tensor.triangular_solve`` is deprecated in favor of
        # ``torch.linalg.solve_triangular``; prefer the latter when this
        # PyTorch build provides it and fall back to the legacy call otherwise.
        solve = getattr(getattr(torch, "linalg", None), "solve_triangular", None)
        if solve is not None:
            return solve(tril, rhs, upper=False)
        return rhs.triangular_solve(tril, upper=False).solution

    def log_prob(self, value):
        """
        Evaluates the log density at ``value``.

        :param ~torch.Tensor value: points at which to evaluate the density;
            checked with ``_validate_sample`` when argument validation is on
        :return: the log density, with the event dimension reduced
        """
        if self._validate_args:
            self._validate_sample(value)
        n = self.loc.size(-1)
        # Whitened residual: y = scale_tril^{-1} (value - loc).
        y = self._solve_lower_triangular(
            self.scale_tril, (value - self.loc).unsqueeze(-1)
        ).squeeze(-1)
        # Log normalizer; the first term is log|det(scale_tril)|.
        Z = (
            self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
            + 0.5 * n * self.df.log()
            + 0.5 * n * math.log(math.pi)
            + torch.lgamma(0.5 * self.df)
            - torch.lgamma(0.5 * (self.df + n))
        )
        return -0.5 * (self.df + n) * torch.log1p(y.pow(2).sum(-1) / self.df) - Z

    @property
    def mean(self):
        """
        A copy of :attr:`loc`, with entries set to ``nan`` wherever
        ``df <= 1``.
        """
        m = self.loc.clone()
        m[self.df <= 1, :] = float("nan")
        return m

    @property
    def variance(self):
        """
        The per-coordinate variance ``df / (df - 2) * diag(scale_tril @
        scale_tril.T)``, set to ``inf`` where ``1 < df <= 2`` and to ``nan``
        where ``df <= 1``.
        """
        m = self.scale_tril.pow(2).sum(-1) * (self.df / (self.df - 2)).unsqueeze(-1)
        m[(self.df <= 2) & (self.df > 1), :] = float("inf")
        m[self.df <= 1, :] = float("nan")
        return m