Source code for pyro.distributions.grouped_normal_normal

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions.utils import broadcast_all

from pyro.distributions import Normal, constraints
from pyro.distributions.torch_distribution import TorchDistribution

# log(sqrt(2 * pi)), written as half of log(2 * pi).
LOG_ROOT_TWO_PI = math.log(2.0 * math.pi) / 2.0

class GroupedNormalNormal(TorchDistribution):
    r"""
    This likelihood, which operates on groups of real-valued scalar observations, is obtained by
    integrating out a latent mean for each group. Both the prior on each latent mean as well as the
    observation likelihood for each data point are univariate Normal distributions.
    The prior means are controlled by `prior_loc` and `prior_scale`. The observation noise of the
    Normal likelihood is controlled by `obs_scale`, which is allowed to vary from observation to
    observation. The tensor of indices `group_idx` connects each observation to one of the groups
    specified by `prior_loc` and `prior_scale`.

    See e.g. Eqn. (55) in ref. [1] for relevant expressions in a simpler case with scalar `obs_scale`.

    Example:

    >>> num_groups = 3
    >>> num_data = 4
    >>> prior_loc = torch.randn(num_groups)
    >>> prior_scale = torch.rand(num_groups)
    >>> obs_scale = torch.rand(num_data)
    >>> group_idx = torch.tensor([1, 0, 2, 1]).long()
    >>> values = torch.randn(num_data)
    >>> gnn = GroupedNormalNormal(prior_loc, prior_scale, obs_scale, group_idx)
    >>> assert gnn.log_prob(values).shape == ()

    References:
    [1] "Conjugate Bayesian analysis of the Gaussian distribution," Kevin P. Murphy.

    :param torch.Tensor prior_loc: Tensor of shape `(num_groups,)` specifying the prior mean of the
        latent of each group. A 0-dimensional tensor is broadcast against `prior_scale`; if both
        are 0-dimensional they describe a single group.
    :param torch.Tensor prior_scale: Tensor of shape `(num_groups,)` specifying the prior scale of
        the latent of each group. 0-dimensional tensors are handled as for `prior_loc`.
    :param torch.Tensor obs_scale: Tensor of shape `(num_data,)` specifying the scale of the
        observation noise of each observation. A 0-dimensional tensor is broadcast to `(num_data,)`.
    :param torch.LongTensor group_idx: Tensor of indices of shape `(num_data,)` and dtype
        `torch.long` linking each observation to one of the `num_groups` groups that are specified
        in `prior_loc` and `prior_scale`.
    :raises ValueError: if any argument has an unsupported number of dimensions, if `group_idx`
        is not an int64 tensor, or if any index lies outside `[0, num_groups - 1]`.
    """

    arg_constraints = {
        "prior_loc": constraints.real,
        "prior_scale": constraints.positive,
        "obs_scale": constraints.positive,
    }
    support = constraints.real

    def __init__(
        self, prior_loc, prior_scale, obs_scale, group_idx, validate_args=None
    ):
        if prior_loc.ndim not in [0, 1] or prior_scale.ndim not in [0, 1]:
            raise ValueError(
                "prior_loc and prior_scale must be broadcastable to 1D tensors of the same shape."
            )

        if obs_scale.ndim not in [0, 1]:
            raise ValueError(
                "obs_scale must be broadcastable to a 1-dimensional tensor."
            )

        # Check the dtype instead of `isinstance(group_idx, torch.LongTensor)`:
        # the legacy `torch.LongTensor` type only matches CPU tensors, so the
        # isinstance test rejects valid int64 index tensors on other devices.
        if group_idx.ndim != 1 or group_idx.dtype != torch.long:
            raise ValueError("group_idx must be a 1-dimensional tensor of indices.")

        prior_loc, prior_scale = broadcast_all(prior_loc, prior_scale)
        if prior_loc.ndim == 0:
            # Both priors were scalars: treat them as one group so that the
            # size/indexing logic below sees a 1-dimensional tensor.
            prior_loc = prior_loc.unsqueeze(0)
            prior_scale = prior_scale.unsqueeze(0)
        obs_scale, group_idx = broadcast_all(obs_scale, group_idx)

        self.prior_loc = prior_loc
        self.prior_scale = prior_scale
        self.obs_scale = obs_scale
        self.group_idx = group_idx

        batch_shape = prior_loc.shape[:-1]
        if batch_shape != torch.Size([]):
            raise ValueError("GroupedNormalNormal only supports trivial batch_shape's.")

        self.num_groups = prior_loc.size(0)
        if group_idx.min().item() < 0 or group_idx.max().item() >= self.num_groups:
            raise ValueError(
                "Each index in group_idx must be an integer in the inclusive range [0, prior_loc.size(0) - 1]."
            )

        # Number of observations assigned to each group, shape (num_groups,).
        self.num_data_per_batch = prior_loc.new_zeros(self.num_groups).scatter_add(
            0, self.group_idx, prior_loc.new_ones(self.group_idx.shape)
        )

        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Not supported by this distribution; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def sample(self, sample_shape=()):
        """Not supported by this distribution; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_posterior(self, value):
        """
        Get a `pyro.distributions.Normal` distribution that encodes the posterior distribution
        over the vector of latents specified by `prior_loc` and `prior_scale` conditioned on the
        observed data specified by `value`.

        :param torch.Tensor value: Observations with the same shape as `group_idx`.
        :returns: A Normal whose `loc` and `scale` have one entry per group.
        :raises ValueError: if `value.shape` differs from `group_idx.shape`.
        """
        if value.shape != self.group_idx.shape:
            raise ValueError(
                "GroupedNormalNormal.get_posterior only supports values that have the same shape as group_idx."
            )

        obs_scale_sq_inv = self.obs_scale.pow(-2)
        prior_scale_sq_inv = self.prior_scale.pow(-2)

        # Per-group sum of observation precisions.
        obs_scale_sq_inv_sum = torch.zeros_like(self.prior_loc).scatter_add(
            0, self.group_idx, obs_scale_sq_inv
        )
        # Posterior precision = prior precision + summed observation precisions.
        precision = prior_scale_sq_inv + obs_scale_sq_inv_sum

        # Per-group sum of precision-weighted observations.
        scaled_value_sum = torch.zeros_like(self.prior_loc).scatter_add(
            0, self.group_idx, value * obs_scale_sq_inv
        )

        loc = (scaled_value_sum + self.prior_loc * prior_scale_sq_inv) / precision
        scale = precision.rsqrt()

        return Normal(loc=loc, scale=scale)

    def log_prob(self, value):
        """
        Compute the joint log marginal likelihood of all observations, with the
        per-group latent means integrated out.

        :param torch.Tensor value: Observations with the same shape as `group_idx`.
        :returns: A 0-dimensional tensor: the log-density summed over all groups.
        :raises ValueError: if `value.shape` differs from `group_idx.shape`.
        """
        if self._validate_args:
            self._validate_sample(value)

        group_idx = self.group_idx

        if value.shape != group_idx.shape:
            raise ValueError(
                "GroupedNormalNormal.log_prob only supports values that have the same shape as group_idx."
            )

        prior_scale_sq = self.prior_scale.pow(2.0)
        obs_scale_sq_inv = self.obs_scale.pow(-2)
        # Per-group sum of observation precisions.
        obs_scale_sq_inv_sum = torch.zeros_like(self.prior_loc).scatter_add(
            0, self.group_idx, obs_scale_sq_inv
        )

        scale_ratio = prior_scale_sq * obs_scale_sq_inv_sum
        delta = value - self.prior_loc[group_idx]
        scaled_delta = delta * obs_scale_sq_inv
        scaled_delta_sum = torch.zeros_like(self.prior_loc).scatter_add(
            0, self.group_idx, scaled_delta
        )

        # Each group's marginal is Normal with covariance
        # diag(obs_scale^2) + prior_scale^2 * ones * ones^T (diagonal plus rank one).
        # result1: the -log(sqrt(2 * pi)) term, once per observation.
        result1 = -(self.num_data_per_batch * LOG_ROOT_TWO_PI).sum()
        # result2: -0.5 * log-determinant of that covariance.
        result2 = -0.5 * torch.log1p(scale_ratio).sum() - self.obs_scale.log().sum()
        # result3: diagonal part of the quadratic form, -0.5 * sum(delta^2 / obs_scale^2).
        result3 = -0.5 * torch.dot(delta, scaled_delta)
        # result4: rank-one correction to the quadratic form.
        numerator = prior_scale_sq * scaled_delta_sum.pow(2)
        result4 = 0.5 * (numerator / (1.0 + scale_ratio)).sum()

        return result1 + result2 + result3 + result4