# Source code for pyro.distributions.gaussian_scale_mixture

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import Categorical, constraints

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import sum_leftmost

class GaussianScaleMixture(TorchDistribution):
    """
    Mixture of Normal distributions with zero mean and diagonal covariance
    matrices.

    That is, this distribution is a mixture with K components, where each
    component distribution is a D-dimensional Normal distribution with zero
    mean and a D-dimensional diagonal covariance matrix. The K different
    covariance matrices are controlled by the parameters `coord_scale` and
    `component_scale`. That is, the covariance matrix of the k'th component is
    given by

        Sigma_ii = (component_scale_k * coord_scale_i) ** 2   (i = 1, ..., D)

    where `component_scale_k` is a positive scale factor and `coord_scale_i`
    are positive scale parameters shared between all K components. The mixture
    weights are controlled by a K-dimensional vector of softmax logits,
    `component_logits`. This distribution implements pathwise derivatives for
    samples from the distribution. This distribution does not currently
    support batched parameters.

    See reference [1] for details on the implementations of the pathwise
    derivative. Please consider citing this reference if you use the pathwise
    derivative in your research.

    [1] Pathwise Derivatives for Multivariate Distributions, Martin Jankowiak &
    Theofanis Karaletsos. arXiv:1806.01856

    Note that this distribution supports both even and odd dimensions, but the
    former should be a bit more precise, since it doesn't use any erfs in the
    backward call. Also note that this distribution does not support D = 1.

    :param torch.tensor coord_scale: D-dimensional vector of scales
    :param torch.tensor component_logits: K-dimensional vector of logits
    :param torch.tensor component_scale: K-dimensional vector of scale multipliers
    """

    has_rsample = True
    arg_constraints = {
        "component_scale": constraints.positive,
        "coord_scale": constraints.positive,
        "component_logits": constraints.real,
    }

    def __init__(self, coord_scale, component_logits, component_scale):
        self.dim = coord_scale.size(0)
        if self.dim < 2:
            raise NotImplementedError("This distribution does not support D = 1")
        assert (
            coord_scale.dim() == 1
        ), "The coord_scale parameter in GaussianScaleMixture should be D dimensional"
        assert (
            component_scale.dim() == 1
        ), "The component_scale parameter in GaussianScaleMixture should be K dimensional"
        assert (
            component_logits.dim() == 1
        ), "The component_logits parameter in GaussianScaleMixture should be K dimensional"
        assert (
            component_logits.shape == component_scale.shape
        ), "The component_logits and component_scale parameters in GaussianScaleMixture should be K dimensional"
        self.coord_scale = coord_scale
        self.component_logits = component_logits
        self.component_scale = component_scale
        # Precomputed polynomial coefficients used by the custom backward pass.
        self.coeffs = self._compute_coeffs()
        # Mixture weights; Categorical normalizes the logits internally.
        self.categorical = Categorical(logits=component_logits)
        super().__init__(event_shape=(self.dim,))

    def _compute_coeffs(self):
        """
        These coefficients are used internally in the backward call.
        """
        dimov2 = int(self.dim / 2)  # this is correct for both even and odd dimensions
        coeffs = torch.ones(dimov2)
        # coeffs[k] = prod_{j=1..k} (D - 2j): a cumulative product built by
        # multiplying each remaining tail slice by the next factor.
        for k in range(dimov2 - 1):
            coeffs[k + 1 :] *= self.dim - 2 * (k + 1)
        return coeffs

    def log_prob(self, value):
        """
        Compute the log density at ``value`` (a single D-dimensional point;
        batched values are not supported).
        """
        assert value.dim() == 1 and value.size(0) == self.dim
        # Squared Mahalanobis-style radius in units of coord_scale.
        epsilon_sqr = torch.pow(value / self.coord_scale, 2.0).sum()
        # log(component_scale ** -D): per-component normalization term.
        component_scale_log_power = self.component_scale.log() * -self.dim
        # logits in Categorical is already normalized
        result = torch.logsumexp(
            component_scale_log_power
            + self.categorical.logits
            + -0.5 * epsilon_sqr / torch.pow(self.component_scale, 2.0),
            dim=-1,
        )  # K
        # Shared Gaussian normalization: -D/2 log(2 pi) - sum_i log coord_scale_i
        result -= 0.5 * math.log(2.0 * math.pi) * float(self.dim)
        result -= self.coord_scale.log().sum()
        return result

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample; gradients flow via the custom pathwise
        derivative implemented in :class:`_GSMSample`.
        """
        which = self.categorical.sample(sample_shape)
        return _GSMSample.apply(
            self.coord_scale,
            self.component_logits,
            self.component_scale,
            self.categorical.probs,
            which,
            sample_shape + torch.Size((self.dim,)),
            self.coeffs,
        )

class _GSMSample(Function):
@staticmethod
def forward(
    ctx, coord_scale, component_logits, component_scale, pis, which, shape, coeffs
):
    """
    Draw a reparameterized mixture sample z = coord_scale * component_scale_k * eps,
    stashing everything the pathwise backward pass needs.
    """
    # Standard-normal noise with the same dtype/device as coord_scale.
    noise = coord_scale.new(shape).normal_()
    # Scale multiplier of the mixture component chosen for each sample,
    # broadcast over the trailing coordinate dimension.
    chosen_scale = component_scale[which].unsqueeze(-1)
    sample = coord_scale * chosen_scale * noise
    ctx.save_for_backward(
        sample, coord_scale, component_logits, component_scale, pis, coeffs
    )
    return sample

@staticmethod
@once_differentiable
(
z,
coord_scale,
component_logits,
component_scale,
pis,
coeffs,
) = ctx.saved_tensors
dim = coord_scale.size(0)
g = grad_output  # l i
g = g.unsqueeze(-2)  # l 1 i

component_scale_sqr = torch.pow(component_scale, 2.0)  # j
epsilons = z / coord_scale  # l i
epsilons_sqr = torch.pow(epsilons, 2.0)  # l i
r_sqr = epsilons_sqr.sum(-1, keepdim=True)  # l
r_sqr_j = r_sqr / component_scale_sqr  # l j
coord_scale_product = coord_scale.prod()
component_scale_power = torch.pow(component_scale, float(dim))

q_j = torch.exp(-0.5 * r_sqr_j) / math.pow(
2.0 * math.pi, 0.5 * float(dim)
)  # l j
q_j /= coord_scale_product * component_scale_power  # l j
q_tot = (pis * q_j).sum(-1, keepdim=True)  # l

Phi_j = torch.exp(-0.5 * r_sqr_j)  # l j
exponents = -torch.arange(1.0, int(dim / 2) + 1.0, 1.0)
if z.dim() > 1:
r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim / 2))  # l j d/2
else:
r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, int(dim / 2))  # l j d/2
r_j_poly = coeffs * torch.pow(r_j_poly, exponents)
Phi_j *= r_j_poly.sum(-1)
if dim % 2 == 1:
root_two = math.sqrt(2.0)
extra_term = (
coeffs[-1]
* math.sqrt(0.5 * math.pi)
* (1.0 - torch.erf(r_sqr_j.sqrt() / root_two))
)  # l j
Phi_j += extra_term * torch.pow(r_sqr_j, -0.5 * float(dim))

logits_grad = (z.unsqueeze(-2) * Phi_j.unsqueeze(-1) * g).sum(-1)  # l j
2.0 * math.pi, -0.5 * float(dim)
)

prefactor = (
pis.unsqueeze(-1) * q_j.unsqueeze(-1) * g / q_tot.unsqueeze(-1)
)  # l j i
coord_scale_grad = sum_leftmost(prefactor * epsilons.unsqueeze(-2), -1)