# Source code for pyro.distributions.avf_mvn

from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints

from pyro.distributions.torch import MultivariateNormal
from pyro.distributions.util import sum_leftmost


class AVFMultivariateNormal(MultivariateNormal):
    """Multivariate normal (Gaussian) distribution with transport equation inspired
    control variates (adaptive velocity fields).

    A distribution over vectors in which all the elements have a joint Gaussian density.

    :param torch.Tensor loc: D-dimensional mean vector.
    :param torch.Tensor scale_tril: Cholesky factor of the covariance matrix; D x D matrix.
    :param torch.Tensor control_var: 2 x L x D tensor that parameterizes the control variate;
        L is an arbitrary positive integer. This parameter needs to be learned (i.e. adapted) to
        achieve lower variance gradients. In a typical use case this parameter will be adapted
        concurrently with the `loc` and `scale_tril` that define the distribution.
    :raises ValueError: if ``loc`` is not 1-dimensional, ``scale_tril`` is not a D x D matrix,
        or ``control_var`` is not of size 2 x L x D.

    Example usage::

        control_var = (0.1 * torch.ones(2, 1, D)).requires_grad_()
        opt_cv = torch.optim.Adam([control_var], lr=0.1, betas=(0.5, 0.999))

        for _ in range(1000):
            d = AVFMultivariateNormal(loc, scale_tril, control_var)
            z = d.rsample()
            cost = torch.pow(z, 2.0).sum()
            cost.backward()
            opt_cv.step()
            opt_cv.zero_grad()

    """
    arg_constraints = {"loc": constraints.real, "scale_tril": constraints.lower_triangular,
                       "control_var": constraints.real}

    def __init__(self, loc, scale_tril, control_var):
        if loc.dim() != 1:
            raise ValueError("AVFMultivariateNormal loc must be 1-dimensional")
        if scale_tril.dim() != 2:
            raise ValueError("AVFMultivariateNormal scale_tril must be 2-dimensional")
        # Checked explicitly so a size mismatch fails here at construction rather than later
        # as an opaque matmul/broadcast error inside rsample.
        if scale_tril.shape != (loc.size(0), loc.size(0)):
            raise ValueError("AVFMultivariateNormal scale_tril must be of size D x D, where D is the dimension of loc")  # noqa: E501
        if control_var.dim() != 3 or control_var.size(0) != 2 or control_var.size(2) != loc.size(0):
            raise ValueError("control_var should be of size 2 x L x D, where D is the dimension of the location parameter loc")  # noqa: E501
        # Set before the base-class __init__: presumably needed because the base class's
        # argument validation (when enabled) reads every key of arg_constraints, including
        # control_var.
        self.control_var = control_var
        super(AVFMultivariateNormal, self).__init__(loc, scale_tril=scale_tril)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample whose backward pass applies the control variate.

        :param sample_shape: shape of the batch of samples to draw; any sequence of ints
            (``torch.Size``, tuple or list) is accepted.
        :returns: sample tensor of shape ``sample_shape + loc.shape``.
        """
        # torch.Size(...) normalises tuples/lists so the concatenation below always works.
        return _AVFMVNSample.apply(self.loc, self.scale_tril, self.control_var,
                                   torch.Size(sample_shape) + self.loc.shape)
class _AVFMVNSample(Function):
    """Autograd function that draws a multivariate-normal sample with a custom gradient.

    Forward: the reparameterized draw ``loc + white @ scale_tril.t()`` with ``white``
    standard normal noise. Backward: the pathwise gradients for ``loc`` and
    ``scale_tril``, where the ``scale_tril`` gradient is shifted by a term built from
    ``control_var``, plus a gradient for ``control_var`` itself.
    """

    @staticmethod
    def forward(ctx, loc, scale_tril, control_var, shape):
        """Draw ``loc + white @ scale_tril.t()`` with ``white`` ~ N(0, I) of size ``shape``.

        :param loc: mean; its dtype and device are used for the noise.
        :param scale_tril: scale matrix ``L`` (expected lower triangular).
        :param control_var: control-variate parameters; only saved here for backward.
        :param shape: full shape of the noise to draw (a size, not a tensor).
        :returns: ``loc + z``, the sample.
        """
        white = torch.randn(shape, dtype=loc.dtype, device=loc.device)
        z = torch.matmul(white, scale_tril.t())
        # The noise is kept because backward needs it for the gradient w.r.t. scale_tril.
        ctx.save_for_backward(scale_tril, control_var, white)
        return loc + z

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for the forward inputs ``(loc, scale_tril, control_var, shape)``.

        Index conventions in the local names: ``a`` and ``b`` index the event dimension,
        ``l`` indexes the rows of the control-variate factors B and C, and ``j`` stands
        for the leading (sample) dims that get summed out.

        :param grad_output: upstream gradient ``g`` w.r.t. the forward output.
        :returns: ``(loc_grad, L_grad, diff_CV, None)``; ``None`` because ``shape`` is
            not a tensor input.
        """
        L, control_var, epsilon = ctx.saved_tensors
        # Unpacks control_var along its first dim, so that dim must have size 2.
        B, C = control_var
        g = grad_output
        # The output is loc + z, so the loc gradient is g with the leading dims summed out;
        # sum_leftmost(x, -1) is expected to keep only the last (event) dim.
        loc_grad = sum_leftmost(grad_output, -1)

        # compute the rep trick gradient
        # z_a = sum_b L_ab * eps_b gives dz_a/dL_ab = eps_b, so the plain pathwise gradient
        # is the outer product g_a * eps_b summed over the leading (sample) dims.
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = sum_leftmost(g_ja * epsilon_jb, -2)

        # modulate the velocity fields with infinitesimal rotations, i.e. apply the control variate
        gL = torch.matmul(g, L)
        eps_gL_ab = sum_leftmost(gL.unsqueeze(-1) * epsilon.unsqueeze(-2), -2)
        # Antisymmetric part (xi_ab == -xi_ba): the "infinitesimal rotation".
        xi_ab = eps_gL_ab - eps_gL_ab.t()
        # BC_lab = B_la * C_lb: one rank-1 matrix per row l of the control-variate factors.
        BC_lab = B.unsqueeze(-1) * C.unsqueeze(-2)
        # Add sum_l xi_ab * B_la * C_lb (elementwise in a, b) to the pathwise gradient.
        diff_L_ab += (xi_ab.unsqueeze(0) * BC_lab).sum(0)
        # Keep only the lower triangle, since scale_tril is expected to be lower triangular.
        L_grad = torch.tril(diff_L_ab)

        # compute control_var grads
        # diff_B_la = sum_b L_grad_ab * xi_ab * C_lb and diff_C_lb = sum_a L_grad_ab * xi_ab * B_la.
        # Holding xi and the pathwise term fixed, this is the gradient of 0.5 * ||L_grad||_F^2
        # w.r.t. B and C. Presumably this lets an optimizer on control_var shrink the second
        # moment (hence variance) of the scale_tril gradient estimate -- confirm against the
        # AVF paper.
        diff_B = (L_grad.unsqueeze(0) * C.unsqueeze(-2) * xi_ab.unsqueeze(0)).sum(2)
        diff_C = (L_grad.t().unsqueeze(0) * B.unsqueeze(-2) * xi_ab.t().unsqueeze(0)).sum(2)
        # Stacked back into the same 2 x ... layout as control_var.
        diff_CV = torch.stack([diff_B, diff_C])
        return loc_grad, L_grad, diff_CV, None