# Source code for pyro.distributions.empirical

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints

from pyro.distributions.torch import Categorical
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import copy_docs_from


@copy_docs_from(TorchDistribution)
class Empirical(TorchDistribution):
    r"""
    Empirical distribution associated with the sampled data. Note that the shape
    requirement for `log_weights` is that its shape must match the leftmost shape
    of `samples`. Samples are aggregated along the ``aggregation_dim``, which is
    the rightmost dim of `log_weights`.

    Example:

    >>> emp_dist = Empirical(torch.randn(2, 3, 10), torch.ones(2, 3))
    >>> emp_dist.batch_shape
    torch.Size([2])
    >>> emp_dist.event_shape
    torch.Size([10])

    >>> single_sample = emp_dist.sample()
    >>> single_sample.shape
    torch.Size([2, 10])
    >>> batch_sample = emp_dist.sample((100,))
    >>> batch_sample.shape
    torch.Size([100, 2, 10])

    >>> emp_dist.log_prob(single_sample).shape
    torch.Size([2])
    >>> # Vectorized samples cannot be scored by log_prob.
    >>> with pyro.validation_enabled():
    ...     emp_dist.log_prob(batch_sample).shape
    Traceback (most recent call last):
    ...
    ValueError: ``value.shape`` must be torch.Size([2, 10])

    :param torch.Tensor samples: samples from the empirical distribution.
    :param torch.Tensor log_weights: log weights corresponding to the samples.
        Its shape must be a non-empty prefix of ``samples.shape``; the
        rightmost dim of ``log_weights`` indexes the individual samples.
    """

    arg_constraints = {}
    support = constraints.real
    has_enumerate_support = True

    def __init__(self, samples, log_weights, validate_args=None):
        self._samples = samples
        self._log_weights = log_weights
        sample_shape, weight_shape = samples.size(), log_weights.size()
        # A 0-dim ``log_weights`` has no dim to aggregate over; fail here with a
        # clear message rather than deep inside ``Categorical``.
        if log_weights.dim() == 0:
            raise ValueError("``log_weights`` must have at least one dimension")
        # ``log_weights`` must be a prefix of ``samples.shape``.  (A longer
        # ``weight_shape`` can never equal the truncated slice, so this single
        # comparison also rejects that case.)
        if weight_shape != sample_shape[: len(weight_shape)]:
            raise ValueError(
                "The shape of ``log_weights`` ({}) must match "
                "the leftmost shape of ``samples`` ({})".format(
                    weight_shape, sample_shape
                )
            )
        # The rightmost dim of ``log_weights`` is the dim the samples are
        # aggregated (sampled / averaged) over.
        self._aggregation_dim = log_weights.dim() - 1
        event_shape = sample_shape[len(weight_shape) :]
        self._categorical = Categorical(logits=self._log_weights)
        super().__init__(
            batch_shape=weight_shape[:-1],
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @property
    def sample_size(self):
        """
        Number of samples that constitute the empirical distribution.

        This is ``log_weights.numel()``, so for a batched distribution it counts
        the samples of all batch elements together.

        :return int: number of samples collected.
        """
        return self._log_weights.numel()

    def _samples_aggregation_first(self):
        """
        Return the stored samples with the aggregation dim moved to the front.

        The stored layout ``batch_shape + (num_samples,) + event_shape`` becomes
        ``(num_samples,) + batch_shape + event_shape`` (a view, no copy).
        """
        return self._samples.movedim(self._aggregation_dim, 0)

    def sample(self, sample_shape=torch.Size()):
        # Indices into the aggregation dim, shape ``sample_shape + batch_shape``.
        sample_idx = self._categorical.sample(sample_shape)
        # num_samples x batch_shape x event_shape
        samples = self._samples_aggregation_first()
        # Flatten ``sample_shape`` into one leading dim and add singleton event
        # dims so the indices can be expanded to the shape ``gather`` requires.
        sample_idx = sample_idx.reshape(
            (-1,) + self.batch_shape + (1,) * len(self.event_shape)
        )
        sample_idx = sample_idx.expand((-1,) + samples.shape[1:])
        return samples.gather(0, sample_idx).reshape(sample_shape + samples.shape[1:])

    def log_prob(self, value):
        """
        Returns the log of the probability mass function evaluated at ``value``.
        Note that this currently only supports scoring values with empty
        ``sample_shape``.

        :param torch.Tensor value: scalar or tensor value to be scored, of shape
            ``batch_shape + event_shape``.
        :return: log of the total normalized weight of the stored samples that
            are exactly equal to ``value`` (``-inf`` if none match).
        """
        if self._validate_args:
            expected_shape = self.batch_shape + self.event_shape
            if value.shape != expected_shape:
                raise ValueError("``value.shape`` must be {}".format(expected_shape))
        if self.batch_shape:
            # Insert a singleton at the aggregation dim so ``value`` broadcasts
            # against every stored sample of its own batch element.
            value = value.unsqueeze(self._aggregation_dim)
        selection_mask = self._samples.eq(value)
        # A stored sample matches only if every one of its event entries
        # matches; reduce the event dims away one at a time.
        for _ in range(len(self.event_shape)):
            selection_mask = selection_mask.all(dim=-1)
        probs = self._categorical.probs
        return (probs * selection_mask.type_as(probs)).sum(dim=-1).log()

    def _weighted_mean(self, value, keepdim=False):
        """
        Weighted average of ``value`` over the aggregation dim.

        ``log_weights`` are padded with trailing singleton dims so they broadcast
        against ``value``; they are shifted by their max before exponentiating so
        that large log weights do not overflow.
        """
        weights = self._log_weights.reshape(
            self._log_weights.size()
            + torch.Size([1] * (value.dim() - self._log_weights.dim()))
        )
        dim = self._aggregation_dim
        max_weight = weights.max(dim=dim, keepdim=True)[0]
        relative_probs = (weights - max_weight).exp()
        return (value * relative_probs).sum(
            dim=dim, keepdim=keepdim
        ) / relative_probs.sum(dim=dim, keepdim=keepdim)

    def _check_not_discrete(self, statistic):
        """
        Raise ``ValueError`` when a moment is requested for integer samples.

        Only ``torch.int32`` and ``torch.int64`` samples are treated as discrete.

        :param str statistic: capitalized name of the requested moment, used as
            the start of the error message.
        """
        if self._samples.dtype in (torch.int32, torch.int64):
            raise ValueError(
                statistic + " for discrete empirical distribution undefined. "
                "Consider converting samples to ``torch.float32`` "
                "or ``torch.float64``. If these are samples from a "
                "`Categorical` distribution, consider converting to a "
                "`OneHotCategorical` distribution."
            )

    @property
    def event_shape(self):
        return self._event_shape

    @property
    def mean(self):
        self._check_not_discrete("Mean")
        return self._weighted_mean(self._samples)

    @property
    def variance(self):
        self._check_not_discrete("Variance")
        # Re-insert the aggregation dim so the mean broadcasts over the samples.
        mean = self.mean.unsqueeze(self._aggregation_dim)
        deviation_squared = torch.pow(self._samples - mean, 2)
        return self._weighted_mean(deviation_squared)

    @property
    def log_weights(self):
        return self._log_weights

    def enumerate_support(self, expand=True):
        """
        Enumerate the stored samples along the leftmost dim.

        Returns a tensor of shape ``(num_samples,) + batch_shape + event_shape``.
        Every stored sample is enumerated, so duplicates appear once per
        occurrence. Because the stored samples can differ between batch
        elements, the result is always fully expanded; ``expand=False`` cannot
        yield a smaller tensor and is ignored.
        """
        return self._samples_aggregation_first()