# Source code for pyro.distributions.gaussian_scale_mixture

from __future__ import absolute_import, division, print_function
import math

import torch
class GaussianScaleMixture(TorchDistribution):
    """
    Mixture of Normal distributions with zero mean and diagonal covariance
    matrices.

    That is, this distribution is a mixture with K components, where each
    component distribution is a D-dimensional Normal distribution with zero
    mean and a D-dimensional diagonal covariance matrix. The K different
    covariance matrices are controlled by the parameters `coord_scale` and
    `component_scale`. That is, the covariance matrix of the k'th component is
    given by

        Sigma_ii = (component_scale_k * coord_scale_i) ** 2   (i = 1, ..., D)

    where `component_scale_k` is a positive scale factor and `coord_scale_i`
    are positive scale parameters shared between all K components. The mixture
    weights are controlled by a K-dimensional vector of softmax logits,
    `component_logits`. This distribution implements pathwise derivatives for
    samples from the distribution.

    This distribution does not currently support batched parameters.

    See reference [1] for details on the implementations of the pathwise
    derivative. Please consider citing this reference if you use the pathwise
    derivative in your research.

    [1] Pathwise Derivatives for Multivariate Distributions,
    Martin Jankowiak & Theofanis Karaletsos. arXiv:1806.01856

    Note that this distribution supports both even and odd dimensions, but the
    former should be a bit more precise, since it doesn't use any erfs in the
    backward call. Also note that this distribution does not support D = 1.

    :param torch.tensor coord_scale: D-dimensional vector of scales
    :param torch.tensor component_logits: K-dimensional vector of logits
    :param torch.tensor component_scale: K-dimensional vector of scale multipliers
    """
    has_rsample = True
    arg_constraints = {"component_scale": constraints.positive,
                       "coord_scale": constraints.positive,
                       "component_logits": constraints.real}

    def __init__(self, coord_scale, component_logits, component_scale):
        self.dim = coord_scale.size(0)
        # The pathwise-derivative machinery is only implemented for D >= 2.
        if self.dim < 2:
            raise NotImplementedError('This distribution does not support D = 1')
        # NOTE(review): batched parameters are unsupported, hence the strict
        # one-dimensionality checks below.
        assert(coord_scale.dim() == 1), "The coord_scale parameter in GaussianScaleMixture should be D dimensional"
        assert(component_scale.dim() == 1), \
            "The component_scale parameter in GaussianScaleMixture should be K dimensional"
        assert(component_logits.dim() == 1), \
            "The component_logits parameter in GaussianScaleMixture should be K dimensional"
        assert(component_logits.shape == component_scale.shape), \
            "The component_logits and component_scale parameters in GaussianScaleMixture should be K dimensional"
        self.coord_scale = coord_scale
        self.component_logits = component_logits
        self.component_scale = component_scale
        # Precompute the backward-pass coefficients once; they depend only on D.
        self.coeffs = self._compute_coeffs()
        # Mixture weights: a Categorical over the K components.
        self.categorical = Categorical(logits=component_logits)
        super(GaussianScaleMixture, self).__init__(event_shape=(self.dim,))

    def _compute_coeffs(self):
        """
        These coefficients are used internally in the backward call.

        :return: a tensor of floor(D / 2) cumulative-product coefficients,
            where entry k is the product of (D - 2), (D - 4), ..., down to
            (D - 2k).
        """
        dimov2 = self.dim // 2  # this is correct for both even and odd dimensions
        coeffs = torch.ones(dimov2)
        for k in range(dimov2 - 1):
            coeffs[k + 1:] *= self.dim - 2 * (k + 1)
        return coeffs