# Source code for pyro.distributions.multivariate_studentt

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.torch import Chi2
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class MultivariateStudentT(TorchDistribution):
    """
    Creates a multivariate Student's t-distribution parameterized by degree of
    freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale_tril`.

    :param ~torch.Tensor df: degrees of freedom
    :param ~torch.Tensor loc: mean of the distribution
    :param ~torch.Tensor scale_tril: scale of the distribution, which is
        a lower triangular matrix with positive diagonal entries
    """
    arg_constraints = {'df': constraints.positive,
                       'loc': constraints.real_vector,
                       'scale_tril': constraints.lower_cholesky}
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, df, loc, scale_tril, validate_args=None):
        """
        Broadcast the parameters against each other and set up the
        auxiliary Chi2 distribution used for sampling.

        :param df: degrees of freedom; a non-tensor value is converted with
            ``loc.new_tensor``.
        :param ~torch.Tensor loc: location; its last dimension is the
            event dimension.
        :param ~torch.Tensor scale_tril: scale factor whose two trailing
            dimensions must both equal the event dimension.
        :param validate_args: forwarded unchanged to the parent constructor.
        :raises AssertionError: if the trailing shape of ``scale_tril`` is
            not ``(dim, dim)`` where ``dim = loc.size(-1)``.
        """
        dim = loc.size(-1)
        assert scale_tril.shape[-2:] == (dim, dim), (
            "scale_tril must have trailing shape ({0}, {0}) to match loc, "
            "but got {1}".format(dim, tuple(scale_tril.shape[-2:])))
        if not isinstance(df, torch.Tensor):
            df = loc.new_tensor(df)
        batch_shape = broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2])
        event_shape = torch.Size((dim,))
        self.df = df.expand(batch_shape)
        self.loc = loc.expand(batch_shape + event_shape)
        # Keep the scale factor as given; the broadcast view is only built
        # on demand by the ``scale_tril`` lazy property.
        self._unbroadcasted_scale_tril = scale_tril
        self._chi2 = Chi2(self.df)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @lazy_property
    def scale_tril(self):
        """
        The scale factor expanded to ``batch_shape + event_shape + event_shape``.
        """
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        """
        ``scale_tril @ scale_tril^T`` expanded to
        ``batch_shape + event_shape + event_shape``.
        """
        # NB: this is not covariance of this distribution;
        # the actual covariance is df / (df - 2) * covariance_matrix
        return (torch.matmul(self._unbroadcasted_scale_tril,
                             self._unbroadcasted_scale_tril.transpose(-1, -2))
                .expand(self._batch_shape + self._event_shape + self._event_shape))

    @lazy_property
    def precision_matrix(self):
        """
        Result of ``torch.cholesky_solve`` applied to the identity matrix and
        the unbroadcasted scale factor, expanded to
        ``batch_shape + event_shape + event_shape``.
        """
        identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype)
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape)

    @staticmethod
    def infer_shapes(df, loc, scale_tril):
        """
        Compute ``(batch_shape, event_shape)`` from the parameter shapes.

        The arguments are sliced and broadcast directly, i.e. they are
        treated as shapes rather than tensors.

        :param df: shape of ``df``.
        :param loc: shape of ``loc``; its last entry gives the event shape.
        :param scale_tril: shape of ``scale_tril``; its last two entries are
            dropped before broadcasting.
        :return: tuple ``(batch_shape, event_shape)``.
        """
        event_shape = loc[-1:]
        batch_shape = broadcast_shape(df, loc[:-1], scale_tril[:-2])
        return batch_shape, event_shape

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution with parameters expanded to ``batch_shape``.

        :param batch_shape: the desired batch shape.
        :param _instance: optional instance to populate, forwarded to
            ``_get_checked_instance``.
        :return: the expanded :class:`MultivariateStudentT`.
        """
        new = self._get_checked_instance(MultivariateStudentT, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        scale_shape = loc_shape + self.event_shape
        new.df = self.df.expand(batch_shape)
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        # Carry over the matrix properties only when they are already stored
        # in this instance's __dict__, so nothing is computed eagerly here.
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(scale_shape)
        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(scale_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(scale_shape)
        new._chi2 = self._chi2.expand(batch_shape)
        # Initialise the base class without validation, then copy the
        # validation flag from this instance.
        super(MultivariateStudentT, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample.

        Standard normal noise ``X`` is scaled by ``rsqrt(Z / df)``, where
        ``Z`` is drawn from the internal Chi2 distribution, then transformed
        by ``scale_tril`` and shifted by ``loc``.

        :param sample_shape: leading shape of the requested sample.
        :return: tensor of shape ``sample_shape + batch_shape + event_shape``
            as given by ``_extended_shape``.
        """
        shape = self._extended_shape(sample_shape)
        X = torch.empty(shape, dtype=self.df.dtype, device=self.df.device).normal_()
        Z = self._chi2.rsample(sample_shape)
        # One scalar factor per batch element, broadcast over the event dim.
        Y = X * torch.rsqrt(Z / self.df).unsqueeze(-1)
        return self.loc + self.scale_tril.matmul(Y.unsqueeze(-1)).squeeze(-1)

    def log_prob(self, value):
        """
        Evaluate the log density at ``value``.

        :param ~torch.Tensor value: points at which to evaluate the density;
            validated with ``_validate_sample`` when argument validation is
            enabled.
        :return: ``-0.5 * (df + n) * log1p(|y|^2 / df) - Z`` where ``y``
            solves ``scale_tril @ y = value - loc``, ``n`` is the event size
            and ``Z`` is the log normalizer computed below.
        """
        if self._validate_args:
            self._validate_sample(value)
        n = self.loc.size(-1)
        # Whiten the residual by solving the lower-triangular system.
        y = (value - self.loc).unsqueeze(-1).triangular_solve(
            self.scale_tril, upper=False).solution.squeeze(-1)
        # Log normalizer: log-determinant of scale_tril plus the
        # df-dependent constant terms.
        Z = (self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) +
             0.5 * n * self.df.log() +
             0.5 * n * math.log(math.pi) +
             torch.lgamma(0.5 * self.df) -
             torch.lgamma(0.5 * (self.df + n)))
        return -0.5 * (self.df + n) * torch.log1p(y.pow(2).sum(-1) / self.df) - Z

    @property
    def mean(self):
        """
        A copy of ``loc`` with entries set to NaN wherever ``df <= 1``.
        """
        m = self.loc.clone()
        m[self.df <= 1, :] = float('nan')
        return m

    @property
    def variance(self):
        """
        Per-coordinate variance.

        Computed as the row-wise sum of squared ``scale_tril`` entries times
        ``df / (df - 2)``; set to ``inf`` where ``1 < df <= 2`` and to NaN
        where ``df <= 1``.
        """
        m = self.scale_tril.pow(2).sum(-1) * (self.df / (self.df - 2)).unsqueeze(-1)
        m[(self.df <= 2) & (self.df > 1), :] = float('inf')
        m[self.df <= 1, :] = float('nan')
        return m