Source code for pyro.distributions.omt_mvn

from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints

from pyro.distributions.torch import MultivariateNormal
from pyro.distributions.util import sum_leftmost


class OMTMultivariateNormal(MultivariateNormal):
    """
    Multivariate normal (Gaussian) distribution whose reparameterized samples
    provide OMT gradients with respect to both parameters.

    Computing the gradient w.r.t. the Cholesky factor costs O(D^3), although
    the resulting gradient variance is generally expected to be lower.

    This is a distribution over vectors whose elements share a joint Gaussian
    density.

    :param torch.Tensor loc: Mean; must be 1-dimensional.
    :param torch.Tensor scale_tril: Cholesky of Covariance matrix; must be
        2-dimensional.
    """
    arg_constraints = {
        "loc": constraints.real,
        "scale_tril": constraints.lower_triangular,
    }

    def __init__(self, loc, scale_tril):
        # Batched parameters are rejected up front.
        assert loc.dim() == 1, "OMTMultivariateNormal loc must be 1-dimensional"
        assert scale_tril.dim() == 2, "OMTMultivariateNormal scale_tril must be 2-dimensional"
        # The parent constructor is handed the full covariance L @ L^T.
        full_cov = scale_tril.mm(scale_tril.t())
        super(OMTMultivariateNormal, self).__init__(loc, full_cov)
        # Store the caller's factor; rsample passes it to the custom autograd function.
        self.scale_tril = scale_tril

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample via ``_OMTMVNSample``.

        :param torch.Size sample_shape: Leading shape of the draw; the shape of
            ``loc`` is appended to it to form the shape of the result.
        :return: Whatever ``_OMTMVNSample.apply`` returns for ``loc``,
            ``scale_tril`` and the combined shape.
        """
        output_shape = sample_shape + self.loc.shape
        return _OMTMVNSample.apply(self.loc, self.scale_tril, output_shape)
def _inverse_of_upper_triangular(upper, identity):
    """
    Return the inverse of the upper-triangular matrix ``upper`` by solving
    ``upper @ X = identity``.

    The triangular solver has been renamed across torch releases
    (``trtrs`` -> ``triangular_solve`` -> ``linalg.solve_triangular``), so the
    newest available entry point is used.

    :param torch.Tensor upper: Square upper-triangular matrix.
    :param torch.Tensor identity: Identity matrix of matching size.
    :return: The solution ``X`` of ``upper @ X = identity``.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve_triangular"):
        # Note the argument order: coefficient matrix first, right-hand side second.
        return linalg.solve_triangular(upper, identity, upper=True)
    if hasattr(torch, "triangular_solve"):
        return torch.triangular_solve(identity, upper, upper=True, transpose=False)[0]
    # Oldest spelling; same (rhs, matrix) argument order as triangular_solve.
    return torch.trtrs(identity, upper, transpose=False, upper=True)[0]


class _OMTMVNSample(Function):
    """
    Custom autograd function that draws ``loc + white @ scale_tril^T`` in the
    forward pass and supplies a hand-written gradient for ``loc`` and
    ``scale_tril`` in the backward pass.
    """

    @staticmethod
    def forward(ctx, loc, scale_tril, shape):
        """
        Draw the sample.

        :param ctx: Autograd context used to stash tensors for ``backward``.
        :param torch.Tensor loc: Mean.
        :param torch.Tensor scale_tril: Cholesky factor.
        :param shape: Shape of the standard-normal noise to draw.
        :return: ``loc + white @ scale_tril^T``.
        """
        # Standard-normal noise allocated via `loc.new_empty`.
        white = loc.new_empty(shape).normal_()
        z = torch.matmul(white, scale_tril.t())
        ctx.save_for_backward(z, white, scale_tril)
        return loc + z

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Compute the gradients for ``loc`` and ``scale_tril``.

        :param ctx: Autograd context holding ``(z, white, scale_tril)``.
        :param torch.Tensor grad_output: Gradient w.r.t. the forward output.
        :return: ``(loc_grad, L_grad, None)``; the trailing ``None`` corresponds
            to the non-tensor ``shape`` argument of ``forward``.
        """
        jitter = 1.0e-8  # do i really need this?
        z, epsilon, L = ctx.saved_tensors

        dim = L.shape[0]
        g = grad_output
        # NOTE(review): presumably reduces every dimension except the last,
        # giving a gradient shaped like `loc` -- confirm against `sum_leftmost`.
        loc_grad = sum_leftmost(grad_output, -1)

        # Build the identity straight into a buffer allocated from `g`; no
        # extra tensor-from-tensor copy is needed.
        identity = torch.eye(dim, out=g.new_empty(dim, dim))
        # R = L^T is upper triangular; R_inv = (L^T)^{-1}.
        R_inv = _inverse_of_upper_triangular(L.t(), identity)

        # First contribution: outer products over the trailing two axes,
        # reduced with `sum_leftmost(..., -2)`.
        z_ja = z.unsqueeze(-1)
        g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

        # Sigma_inv = L^{-T} L^{-1} = (L L^T)^{-1}.
        Sigma_inv = torch.mm(R_inv, R_inv.t())
        # `jitter` is a scalar, so it is added to every entry of Sigma_inv.
        V, D, _ = torch.svd(Sigma_inv + jitter)
        # Pairwise sums D_a + D_b of the singular values.
        D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

        # Place the rotated z and g on the diagonals of (dim x dim) matrices.
        expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
        z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple)
        g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple)

        Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
        # Rotate back out of the basis given by V, then symmetrize.
        Y = torch.mm(V, torch.mm(Y, V.t()))
        Y = Y + Y.t()

        # Second contribution to the gradient.
        Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv))
        diff_L_ab += 0.5 * Tr_xi_Y
        # Only the lower triangle is returned as the gradient for scale_tril.
        L_grad = torch.tril(diff_L_ab)

        return loc_grad, L_grad, None