# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.distributions import constraints
from pyro.distributions.torch import Categorical
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import copy_docs_from


@copy_docs_from(TorchDistribution)
class Empirical(TorchDistribution):
r"""
Empirical distribution associated with the sampled data. Note that the shape
requirement for `log_weights` is that its shape must match the leftmost shape
of `samples`. Samples are aggregated along the ``aggregation_dim``, which is
the rightmost dim of `log_weights`.
Example:
>>> emp_dist = Empirical(torch.randn(2, 3, 10), torch.ones(2, 3))
>>> emp_dist.batch_shape
torch.Size([2])
>>> emp_dist.event_shape
torch.Size([10])
>>> single_sample = emp_dist.sample()
>>> single_sample.shape
torch.Size([2, 10])
>>> batch_sample = emp_dist.sample((100,))
>>> batch_sample.shape
torch.Size([100, 2, 10])
>>> emp_dist.log_prob(single_sample).shape
torch.Size([2])
>>> # Vectorized samples cannot be scored by log_prob.
>>> with pyro.validation_enabled():
... emp_dist.log_prob(batch_sample).shape
Traceback (most recent call last):
...
ValueError: ``value.shape`` must be torch.Size([2, 10])
:param torch.Tensor samples: samples from the empirical distribution.
:param torch.Tensor log_weights: log weights (optional) corresponding
to the samples.
"""
arg_constraints = {}
support = constraints.real
has_enumerate_support = True

    def __init__(self, samples, log_weights, validate_args=None):
self._samples = samples
self._log_weights = log_weights
sample_shape, weight_shape = samples.size(), log_weights.size()
if (
weight_shape > sample_shape
or weight_shape != sample_shape[: len(weight_shape)]
):
raise ValueError(
"The shape of ``log_weights`` ({}) must match "
"the leftmost shape of ``samples`` ({})".format(
weight_shape, sample_shape
)
)
self._aggregation_dim = log_weights.dim() - 1
event_shape = sample_shape[len(weight_shape) :]
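        # Illustration (mirroring the docstring example, not extra logic):
        # for samples of shape (2, 3, 10) and log_weights of shape (2, 3),
        # aggregation_dim is 1 (the rightmost weight dim, of size 3),
        # event_shape is (10,), and the batch_shape passed below is (2,).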
self._categorical = Categorical(logits=self._log_weights)
super().__init__(
batch_shape=weight_shape[:-1],
event_shape=event_shape,
validate_args=validate_args,
)

    @property
def sample_size(self):
"""
        Number of samples that constitute the empirical distribution.

        :return int: number of samples collected.
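
        Example (shapes as in the class docstring)::

            >>> Empirical(torch.randn(2, 3, 10), torch.ones(2, 3)).sample_size
            6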
"""
return self._log_weights.numel()

    def sample(self, sample_shape=torch.Size()):
sample_idx = self._categorical.sample(
sample_shape
) # sample_shape x batch_shape
# reorder samples to bring aggregation_dim to the front:
# batch_shape x num_samples x event_shape -> num_samples x batch_shape x event_shape
samples = (
self._samples.unsqueeze(0)
.transpose(0, self._aggregation_dim + 1)
.squeeze(self._aggregation_dim + 1)
)
# make sample_idx.shape compatible with samples.shape: sample_shape_numel x batch_shape x event_shape
sample_idx = sample_idx.reshape(
(-1,) + self.batch_shape + (1,) * len(self.event_shape)
)
sample_idx = sample_idx.expand((-1,) + samples.shape[1:])
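        # Illustration (not extra logic): for samples of shape (2, 3, 10) and
        # sample_shape (100,), ``samples`` above is (3, 2, 10); ``sample_idx``
        # is reshaped to (100, 2, 1), expanded to (100, 2, 10), and the gather
        # along dim 0 then picks one stored sample per draw and batch element.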
return samples.gather(0, sample_idx).reshape(sample_shape + samples.shape[1:])

    def log_prob(self, value):
"""
        Returns the log of the probability mass function evaluated at ``value``.
        Note that this currently only supports scoring values with empty
        ``sample_shape``.

        :param torch.Tensor value: scalar or tensor value to be scored.
"""
if self._validate_args:
if value.shape != self.batch_shape + self.event_shape:
raise ValueError(
"``value.shape`` must be {}".format(
self.batch_shape + self.event_shape
)
)
if self.batch_shape:
value = value.unsqueeze(self._aggregation_dim)
selection_mask = self._samples.eq(value)
# Get a mask for all entries in the ``weights`` tensor
# that correspond to ``value``.
for _ in range(len(self.event_shape)):
selection_mask = selection_mask.min(dim=-1)[0]
selection_mask = selection_mask.type(self._categorical.probs.type())
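        # The min over event dims implements a logical AND: the mask is 1
        # exactly where every event coordinate of a stored sample equals
        # ``value``, so the sum below accumulates the normalized weights of
        # all matching samples.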
return (self._categorical.probs * selection_mask).sum(dim=-1).log()

    def _weighted_mean(self, value, keepdim=False):
weights = self._log_weights.reshape(
self._log_weights.size()
+ torch.Size([1] * (value.dim() - self._log_weights.dim()))
)
dim = self._aggregation_dim
max_weight = weights.max(dim=dim, keepdim=True)[0]
relative_probs = (weights - max_weight).exp()
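        # Subtracting the max before exponentiating is the standard
        # exp-normalize (softmax) trick; the shift cancels in the ratio below,
        # so the weighted average is computed without overflow.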
return (value * relative_probs).sum(
dim=dim, keepdim=keepdim
) / relative_probs.sum(dim=dim, keepdim=keepdim)

    @property
def event_shape(self):
return self._event_shape

    @property
def mean(self):
if self._samples.dtype in (torch.int32, torch.int64):
raise ValueError(
"Mean for discrete empirical distribution undefined. "
+ "Consider converting samples to ``torch.float32`` "
+ "or ``torch.float64``. If these are samples from a "
+ "`Categorical` distribution, consider converting to a "
+ "`OneHotCategorical` distribution."
)
return self._weighted_mean(self._samples)

    @property
def variance(self):
if self._samples.dtype in (torch.int32, torch.int64):
raise ValueError(
"Variance for discrete empirical distribution undefined. "
+ "Consider converting samples to ``torch.float32`` "
+ "or ``torch.float64``. If these are samples from a "
+ "`Categorical` distribution, consider converting to a "
+ "`OneHotCategorical` distribution."
)
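        # Weighted population variance: E_w[(x - E_w[x])^2], using the same
        # normalized weights as ``mean``.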
mean = self.mean.unsqueeze(self._aggregation_dim)
deviation_squared = torch.pow(self._samples - mean, 2)
return self._weighted_mean(deviation_squared)

    @property
def log_weights(self):
return self._log_weights

    def enumerate_support(self, expand=True):
# Empirical does not support batching, so expanding is a no-op.
return self._samples
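

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the upstream module): exercises
    # only the public API defined above, checking the shapes documented in
    # the class docstring.
    samples = torch.randn(2, 3, 10)
    log_weights = torch.ones(2, 3)
    emp_dist = Empirical(samples, log_weights)
    assert emp_dist.batch_shape == torch.Size([2])
    assert emp_dist.event_shape == torch.Size([10])
    assert emp_dist.sample_size == 6
    assert emp_dist.sample().shape == torch.Size([2, 10])
    assert emp_dist.sample((100,)).shape == torch.Size([100, 2, 10])
    assert emp_dist.mean.shape == torch.Size([2, 10])
    assert emp_dist.variance.shape == torch.Size([2, 10])
    print("Empirical smoke test passed.")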