# Source code for pyro.distributions.avf_mvn

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints

from pyro.distributions.torch import MultivariateNormal
from pyro.distributions.util import sum_leftmost

class AVFMultivariateNormal(MultivariateNormal):
    """
    Multivariate normal (Gaussian) distribution with transport equation inspired
    control variates (adaptive velocity fields).

    A distribution over vectors in which all the elements have a joint Gaussian
    density.

    :param torch.Tensor loc: D-dimensional mean vector.
    :param torch.Tensor scale_tril: Cholesky of Covariance matrix; D x D matrix.
    :param torch.Tensor control_var: 2 x L x D tensor that parameterizes the
        control variate; L is an arbitrary positive integer. This parameter
        needs to be learned (i.e. adapted) to achieve lower variance gradients.
        In a typical use case this parameter will be adapted concurrently with
        the `loc` and `scale_tril` that define the distribution.
    :raises ValueError: if `loc` is not 1-dimensional, if `scale_tril` is not a
        2-dimensional D x D tensor, or if `control_var` is not of size
        2 x L x D.

    Example usage::

        control_var = torch.full((2, 1, D), 0.1, requires_grad=True)
        opt_cv = torch.optim.Adam([control_var], lr=0.1, betas=(0.5, 0.999))

        for _ in range(1000):
            d = AVFMultivariateNormal(loc, scale_tril, control_var)
            z = d.rsample()
            cost = torch.pow(z, 2.0).sum()
            cost.backward()
            opt_cv.step()
            opt_cv.zero_grad()

    """

    arg_constraints = {
        "loc": constraints.real,
        "scale_tril": constraints.lower_triangular,
        "control_var": constraints.real,
    }

    def __init__(self, loc, scale_tril, control_var):
        if loc.dim() != 1:
            raise ValueError("AVFMultivariateNormal loc must be 1-dimensional")
        if scale_tril.dim() != 2:
            raise ValueError("AVFMultivariateNormal scale_tril must be 2-dimensional")

        event_dim = loc.size(0)

        # Reject a scale_tril whose size disagrees with loc up front, so the
        # mismatch is reported here instead of surfacing later as an opaque
        # shape error when sampling.
        if scale_tril.size(0) != event_dim or scale_tril.size(1) != event_dim:
            raise ValueError(
                "AVFMultivariateNormal scale_tril must be of size D x D, where D is "
                "the dimension of the location parameter loc"
            )

        if (
            control_var.dim() != 3
            or control_var.size(0) != 2
            or control_var.size(2) != event_dim
        ):
            raise ValueError(
                "control_var should be of size 2 x L x D, where D is the "
                "dimension of the location parameter loc"
            )

        # Stored before the base initializer runs, as in the original ordering
        # (presumably so base-class argument validation can see it -- confirm).
        self.control_var = control_var
        super().__init__(loc, scale_tril=scale_tril)

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a sample of shape ``sample_shape + loc.shape``.

        Sampling is delegated to ``_AVFMVNSample.apply``, a custom autograd
        ``Function`` that receives ``loc``, ``scale_tril`` and ``control_var``.

        :param torch.Size sample_shape: leading shape of the requested sample.
        :return: the value returned by ``_AVFMVNSample.apply``.
        """
        return _AVFMVNSample.apply(
            self.loc, self.scale_tril, self.control_var, sample_shape + self.loc.shape
        )

class _AVFMVNSample(Function):
@staticmethod
def forward(ctx, loc, scale_tril, control_var, shape):
    """
    Sample ``loc + white @ scale_tril.T`` with standard-normal noise ``white``.

    :param ctx: autograd context; receives ``scale_tril``, ``control_var`` and
        the noise via ``ctx.save_for_backward`` (in that order).
    :param torch.Tensor loc: mean; also determines dtype and device of the noise.
    :param torch.Tensor scale_tril: matrix whose transpose right-multiplies the
        noise (``.t()`` is used, so it must be at most 2-dimensional).
    :param torch.Tensor control_var: not used in the forward computation; only
        saved on ``ctx`` for the backward pass.
    :param shape: shape of the noise tensor to draw.
    :return: ``loc`` plus the transformed noise.
    """
    # Standard-normal noise, matching loc's dtype and device.
    white = torch.randn(shape, dtype=loc.dtype, device=loc.device)
    z = torch.matmul(white, scale_tril.t())
    # The noise itself (not the sample) is kept for the backward pass.
    ctx.save_for_backward(scale_tril, control_var, white)
    return loc + z

@staticmethod
@once_differentiable
L, control_var, epsilon = ctx.saved_tensors
B, C = control_var

# compute the rep trick gradient
epsilon_jb = epsilon.unsqueeze(-2)
g_ja = g.unsqueeze(-1)
diff_L_ab = sum_leftmost(g_ja * epsilon_jb, -2)

# modulate the velocity fields with infinitesimal rotations, i.e. apply the control variate
gL = torch.matmul(g, L)
eps_gL_ab = sum_leftmost(gL.unsqueeze(-1) * epsilon.unsqueeze(-2), -2)
xi_ab = eps_gL_ab - eps_gL_ab.t()
BC_lab = B.unsqueeze(-1) * C.unsqueeze(-2)
diff_L_ab += (xi_ab.unsqueeze(0) * BC_lab).sum(0)