# Source code for pyro.distributions.lowrank_mvn

from __future__ import absolute_import, division, print_function

import math

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import matrix_triangular_solve_compat

class LowRankMultivariateNormal(TorchDistribution):
    """
    Low Rank Multivariate Normal distribution.

    Implements fast computation of the log probability of a Multivariate Normal
    distribution when the covariance matrix has the form::

        covariance_matrix = W.T @ W + D.

    Here D is a diagonal vector and W is a matrix of size M x N. The
    computation will be beneficial when M << N.

    :param torch.Tensor loc: Mean.
        Must be a 1D or 2D tensor with the last dimension of size N.
    :param torch.Tensor W_term: W term of the covariance matrix.
        Must be a 2D tensor of size M x N.
    :param torch.Tensor D_term: D term of the covariance matrix.
        Must be a 1D tensor of size N.
    :param float trace_term: An optional term to be added to the Mahalanobis
        term according to p(y) = N(y|loc, cov).exp(-1/2 * trace_term).
    """
    arg_constraints = {"loc": constraints.real,
                       "covariance_matrix_D_term": constraints.positive,
                       "scale_tril": constraints.lower_triangular}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, W_term, D_term, trace_term=None):
        # Validate that loc's event dimension, D's length, and W's second
        # dimension all agree on N.
        if loc.shape[-1] != D_term.shape[0]:
            raise ValueError("Expected loc.shape[-1] == D_term.shape[0], but got {} vs {}".format(
                loc.shape[-1], D_term.shape[0]))
        if D_term.shape[0] != W_term.shape[1]:
            raise ValueError("The dimension of D_term must match the second dimension of W_term.")
        if D_term.dim() != 1 or W_term.dim() != 2 or loc.dim() > 2:
            raise ValueError("D_term, W_term must be 1D, 2D tensors respectively and "
                             "loc must be a 1D or 2D tensor.")

        self.loc = loc
        self.covariance_matrix_D_term = D_term
        self.covariance_matrix_W_term = W_term
        # trace_term defaults to 0, i.e. no extra term in the Mahalanobis part.
        self.trace_term = trace_term if trace_term is not None else 0

        batch_shape, event_shape = loc.shape[:-1], loc.shape[-1:]
        super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape)

    @property
    def mean(self):
        # The mean of a Gaussian is its location parameter.
        return self.loc

    @property
    def variance(self):
        # Marginal variances: diag(D + W.T @ W) = D + sum_i W[i, :] ** 2.
        return self.covariance_matrix_D_term + (self.covariance_matrix_W_term ** 2).sum(0)

    @lazy_property
    def scale_tril(self):
        # We use the following formula to increase the numerical stability of
        # the Cholesky decomposition (see GPML section 3.4.3):
        #     D + W.T @ W = D1/2 @ (I + D-1/2 @ W.T @ W @ D-1/2) @ D1/2
        Dsqrt = self.covariance_matrix_D_term.sqrt()
        A = self.covariance_matrix_W_term / Dsqrt
        At_A = A.t().matmul(A)
        N = A.shape[1]
        Id = torch.eye(N, N, out=A.new_empty(N, N))
        K = Id + At_A
        # NOTE: Tensor.potrf is deprecated in newer PyTorch; the equivalent
        # modern call is torch.linalg.cholesky(K).
        L = K.potrf(upper=False)
        # Scale rows back by D1/2 so that scale_tril @ scale_tril.T == D + W.T @ W.
        return Dsqrt.unsqueeze(1) * L

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized sampling: loc + eps @ scale_tril.T with eps ~ N(0, I),
        # so the sample covariance is scale_tril @ scale_tril.T.
        white = self.loc.new_empty(sample_shape + self.loc.shape).normal_()
        return self.loc + torch.matmul(white, self.scale_tril.t())

    def log_prob(self, value):
        # log N(value | loc, D + W.T @ W) - trace_term / 2, computed without
        # materializing the full N x N covariance matrix.
        delta = value - self.loc
        logdet, mahalanobis_squared = self._compute_logdet_and_mahalanobis(
            self.covariance_matrix_D_term, self.covariance_matrix_W_term, delta, self.trace_term)
        normalization_const = 0.5 * (self.event_shape[-1] * math.log(2 * math.pi) + logdet)
        return -(normalization_const + 0.5 * mahalanobis_squared)

    def _compute_logdet_and_mahalanobis(self, D, W, y, trace_term=0):
        """
        Calculates the log determinant and (squared) Mahalanobis term of the
        covariance matrix (D + Wt.W), where D is a diagonal matrix, based on the
        "Woodbury matrix identity" and "matrix determinant lemma"::

            inv(D + Wt.W) = inv(D) - inv(D).Wt.inv(I + W.inv(D).Wt).W.inv(D)
            log|D + Wt.W| = log|Id + Wt.inv(D).W| + log|D|
        """
        W_Dinv = W / D
        M = W.shape[0]
        Id = torch.eye(M, M, out=W.new_empty(M, M))
        # Capacitance matrix K = I + W @ inv(D) @ W.T is only M x M.
        K = Id + W_Dinv.matmul(W.t())
        L = K.potrf(upper=False)
        if y.dim() == 1:
            W_Dinv_y = W_Dinv.matmul(y)
        elif y.dim() == 2:
            # Batched case: solve against columns, transpose back afterwards.
            W_Dinv_y = W_Dinv.matmul(y.t())
        else:
            raise NotImplementedError("LowRankMultivariateNormal distribution does not support "
                                      "computing log_prob for a tensor with more than 2 dimensions.")
        Linv_W_Dinv_y = matrix_triangular_solve_compat(W_Dinv_y, L, upper=False)
        if y.dim() == 2:
            Linv_W_Dinv_y = Linv_W_Dinv_y.t()

        # matrix determinant lemma: log|K| + log|D|, with log|K| from the
        # Cholesky factor's diagonal.
        logdet = 2 * L.diag().log().sum() + D.log().sum()

        # Woodbury: y.T inv(D + Wt.W) y = y.T inv(D) y - ||inv(L) W inv(D) y||^2.
        mahalanobis1 = (y * y / D).sum(-1)
        mahalanobis2 = (Linv_W_Dinv_y * Linv_W_Dinv_y).sum(-1)
        mahalanobis_squared = mahalanobis1 - mahalanobis2 + trace_term

        return logdet, mahalanobis_squared