# Source code for pyro.distributions.gaussian_scale_mixture

from __future__ import absolute_import, division, print_function

import math

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints, Categorical

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import sum_leftmost

class GaussianScaleMixture(TorchDistribution):
    """
    Mixture of Normal distributions with zero mean and diagonal covariance
    matrices.

    That is, this distribution is a mixture with K components, where each
    component distribution is a D-dimensional Normal distribution with zero
    mean and a D-dimensional diagonal covariance matrix. The K different
    covariance matrices are controlled by the parameters `coord_scale` and
    `component_scale`.  That is, the covariance matrix of the k'th component is
    given by

    Sigma_ii = (component_scale_k * coord_scale_i) ** 2   (i = 1, ..., D)

    where `component_scale_k` is a positive scale factor and `coord_scale_i`
    are positive scale parameters shared between all K components. The mixture
    weights are controlled by a K-dimensional vector of softmax logits,
    `component_logits`. This distribution implements pathwise derivatives for
    samples from the distribution.  This distribution does not currently
    support batched parameters.

    See reference [1] for details on the implementations of the pathwise
    derivative. Please consider citing this reference if you use the pathwise
    derivative in your research.

    [1] Pathwise Derivatives for Multivariate Distributions, Martin Jankowiak &
    Theofanis Karaletsos. arXiv:1806.01856

    Note that this distribution supports both even and odd dimensions, but the
    former should be a bit higher precision, since it doesn't use any erfs
    in the backward call. Also note that this distribution does not support
    D = 1.

    :param torch.tensor coord_scale: D-dimensional vector of scales
    :param torch.tensor component_logits: K-dimensional vector of logits
    :param torch.tensor component_scale: K-dimensional vector of scale multipliers
    """
    has_rsample = True
    arg_constraints = {"component_scale": constraints.positive, "coord_scale": constraints.positive,
                       "component_logits": constraints.real}

    def __init__(self, coord_scale, component_logits, component_scale):
        """
        Validate the (unbatched) parameters and precompute the helpers used by
        ``log_prob`` and ``rsample``.

        :raises NotImplementedError: if the leading size of ``coord_scale`` is below 2
        :raises AssertionError: if any parameter is not 1-dimensional, or if
            ``component_logits`` and ``component_scale`` differ in shape
        """
        self.dim = coord_scale.size(0)
        if self.dim < 2:
            raise NotImplementedError('This distribution does not support D = 1')
        assert coord_scale.dim() == 1, \
            "The coord_scale parameter in GaussianScaleMixture should be D dimensional"
        assert component_scale.dim() == 1, \
            "The component_scale parameter in GaussianScaleMixture should be K dimensional"
        assert component_logits.dim() == 1, \
            "The component_logits parameter in GaussianScaleMixture should be K dimensional"
        assert component_logits.shape == component_scale.shape, \
            "The component_logits and component_scale parameters in GaussianScaleMixture should be K dimensional"
        self.coord_scale = coord_scale
        self.component_logits = component_logits
        self.component_scale = component_scale
        self.coeffs = self._compute_coeffs()
        self.categorical = Categorical(logits=component_logits)
        super(GaussianScaleMixture, self).__init__(event_shape=(self.dim,))

    def _compute_coeffs(self):
        """
        These coefficients are used internally in the backward call.

        Returns a tensor of length ``D // 2`` whose j'th entry is the product
        ``(D - 2) * (D - 4) * ... * (D - 2 * j)`` (the empty product, 1, for j = 0).
        """
        dimov2 = self.dim // 2  # floor division is correct for both even and odd dimensions
        coeffs = torch.ones(dimov2)
        for k in range(dimov2 - 1):
            # every entry after position k picks up the next factor of the running product
            coeffs[k + 1:] *= self.dim - 2 * (k + 1)
        return coeffs

    def log_prob(self, value):
        """
        Evaluate the log density of the mixture at a single point.

        The sum over components is carried out in log space so that points far
        from the origin yield a finite log density instead of ``-inf`` caused by
        every ``exp`` term underflowing to zero.

        :param torch.Tensor value: 1-dimensional tensor of length D
        :return: scalar tensor holding the log density at ``value``
        :raises AssertionError: if ``value`` is not 1-dimensional of length D
        """
        assert value.dim() == 1 and value.size(0) == self.dim
        epsilon_sqr = torch.pow(value / self.coord_scale, 2.0).sum()
        # Per-component log terms: log(pi_k) - D * log(c_k) - 0.5 * eps^2 / c_k^2.
        # Categorical exposes `logits` already normalized, i.e. equal to log(pi_k).
        log_terms = self.categorical.logits \
            - float(self.dim) * torch.log(self.component_scale) \
            - 0.5 * epsilon_sqr / torch.pow(self.component_scale, 2.0)  # K
        result = torch.logsumexp(log_terms, dim=-1)
        # Normalization shared by all components: (2 pi)^(-D/2) / prod_i coord_scale_i
        result = result - 0.5 * math.log(2.0 * math.pi) * float(self.dim)
        result = result - torch.log(self.coord_scale).sum()
        return result

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparameterized samples.

        A component index is drawn from the categorical for every element of
        ``sample_shape``; the sample itself is produced by ``_GSMSample``,
        which receives the requested output shape ``sample_shape + (D,)``.

        :param torch.Size sample_shape: shape of the batch of samples to draw
        :return: the value returned by ``_GSMSample.apply``
        """
        which = self.categorical.sample(sample_shape)
        return _GSMSample.apply(self.coord_scale, self.component_logits, self.component_scale, self.categorical.probs,
                                which, sample_shape + torch.Size((self.dim,)), self.coeffs)

class _GSMSample(Function):
@staticmethod
def forward(ctx, coord_scale, component_logits, component_scale, pis, which, shape, coeffs):
    """
    Generate the sample: standard normal noise of the requested ``shape``,
    multiplied by ``coord_scale`` and by the scale of the component that
    ``which`` selects.

    The sample and the tensors received here (except ``which``) are stored on
    ``ctx`` via ``save_for_backward``.

    :return: the scaled noise tensor
    """
    # Noise is allocated through `coord_scale.new`, so it is created from coord_scale
    noise = coord_scale.new(shape).normal_()
    # Trailing unit axis lets the selected component scale broadcast against coord_scale
    total_scale = coord_scale * component_scale[which].unsqueeze(-1)
    sample = total_scale * noise
    ctx.save_for_backward(sample, coord_scale, component_logits, component_scale, pis, coeffs)
    return sample

@staticmethod
@once_differentiable
z, coord_scale, component_logits, component_scale, pis, coeffs = ctx.saved_tensors
dim = coord_scale.size(0)
g = grad_output  # l i
g = g.unsqueeze(-2)  # l 1 i

component_scale_sqr = torch.pow(component_scale, 2.0)  # j
epsilons = z / coord_scale  # l i
epsilons_sqr = torch.pow(epsilons, 2.0)  # l i
r_sqr = epsilons_sqr.sum(-1, keepdim=True)  # l
r_sqr_j = r_sqr / component_scale_sqr  # l j
coord_scale_product = coord_scale.prod()
component_scale_power = torch.pow(component_scale, float(dim))

q_j = torch.exp(-0.5 * r_sqr_j) / math.pow(2.0 * math.pi, 0.5 * float(dim))  # l j
q_j /= coord_scale_product * component_scale_power  # l j
q_tot = (pis * q_j).sum(-1, keepdim=True)  # l

Phi_j = torch.exp(-0.5 * r_sqr_j)  # l j
exponents = - torch.arange(1., int(dim/2) + 1., 1.)
if z.dim() > 1:
r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim/2))  # l j d/2
else:
r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, int(dim/2))  # l j d/2
r_j_poly = coeffs * torch.pow(r_j_poly, exponents)
Phi_j *= r_j_poly.sum(-1)
if dim % 2 == 1:
root_two = math.sqrt(2.0)
extra_term = coeffs[-1] * math.sqrt(0.5 * math.pi) * (1.0 - torch.erf(r_sqr_j.sqrt() / root_two))  # l j
Phi_j += extra_term * torch.pow(r_sqr_j, -0.5 * float(dim))

logits_grad = (z.unsqueeze(-2) * Phi_j.unsqueeze(-1) * g).sum(-1)  # l j