import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints

from pyro.distributions.torch import MultivariateNormal
from pyro.distributions.util import eye_like, sum_leftmost

class OMTMultivariateNormal(MultivariateNormal):
    """Multivariate normal (Gaussian) distribution with OMT gradients w.r.t. both
    parameters. Note the gradient computation w.r.t. the Cholesky factor has cost
    O(D^3), although the resulting gradient variance is generally expected to be lower.

    A distribution over vectors in which all the elements have a joint Gaussian
    density.

    Only unbatched parameters are supported: ``loc`` must be 1-dimensional and
    ``scale_tril`` must be 2-dimensional.

    :param torch.Tensor loc: Mean.
    :param torch.Tensor scale_tril: Cholesky of Covariance matrix.
    :raises ValueError: if ``loc`` is not 1-dimensional or ``scale_tril`` is not
        2-dimensional.
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale_tril": constraints.lower_triangular,
    }

    def __init__(self, loc, scale_tril):
        if loc.dim() != 1:
            raise ValueError("OMTMultivariateNormal loc must be 1-dimensional")
        if scale_tril.dim() != 2:
            raise ValueError("OMTMultivariateNormal scale_tril must be 2-dimensional")
        super().__init__(loc, scale_tril=scale_tril)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample whose backward pass uses OMT gradients.

        :param sample_shape: Leading sample shape; any sequence of ints accepted
            by :class:`torch.Size` (``torch.Size``, tuple or list).
        :returns: A tensor of shape ``sample_shape + loc.shape``.
        """
        # Normalize first so that e.g. a plain list works; ``list + torch.Size``
        # would otherwise raise TypeError.
        shape = torch.Size(sample_shape) + self.loc.shape
        return _OMTMVNSample.apply(self.loc, self.scale_tril, shape)
class _OMTMVNSample(Function): @staticmethod def forward(ctx, loc, scale_tril, shape): white = torch.randn(shape, dtype=loc.dtype, device=loc.device) z = torch.matmul(white, scale_tril.t()) ctx.save_for_backward(z, white, scale_tril) return loc + z @staticmethod @once_differentiable def backward(ctx, grad_output): jitter = 1.0e-8 # do i really need this? z, epsilon, L = ctx.saved_tensors dim = L.shape[0] g = grad_output loc_grad = sum_leftmost(grad_output, -1) identity = eye_like(g, dim) R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0] z_ja = z.unsqueeze(-1) g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2) epsilon_jb = epsilon.unsqueeze(-2) g_ja = g.unsqueeze(-1) diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2) Sigma_inv =, R_inv.t()) V, D, _ = torch.svd(Sigma_inv + jitter) D_outer = D.unsqueeze(-1) + D.unsqueeze(0) expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim]) z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple) g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple) Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2) Y =,, V.t())) Y = Y + Y.t() Tr_xi_Y =, Y), R_inv) -,, R_inv)) diff_L_ab += 0.5 * Tr_xi_Y L_grad = torch.tril(diff_L_ab) return loc_grad, L_grad, None