# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable, Union
import torch
from pyro.ops.welford import WelfordCovariance
class StreamingStats(ABC):
    """
    Abstract base class for streamable statistics of trees of tensors.

    Derived classes must implement :meth:`update`, :meth:`merge`, and
    :meth:`get`.
    """

    @abstractmethod
    def update(self, sample) -> None:
        """
        Fold a single sample into the running state.

        This mutates ``self`` and returns nothing. Updates should be
        independent of order, i.e. samples should be exchangeable.

        :param sample: A sample value which is a nested dictionary of
            :class:`torch.Tensor` leaves. This can have arbitrary nesting and
            shape, but the shape is assumed to be constant across calls to
            ``.update()``.
        """
        raise NotImplementedError

    @abstractmethod
    def merge(self, other) -> "StreamingStats":
        """
        Merge two aggregate statistics, e.g. from different MCMC chains.

        This is a pure function: it returns a new :class:`StreamingStats`
        object and does not modify either ``self`` or ``other``.

        :param other: Another streaming stats instance of the same type.
        """
        # Subclasses that delegate here via ``super().merge(other)`` get the
        # type check first, then the "not implemented" error.
        assert isinstance(other, type(self))
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Any:
        """
        Return the aggregate statistic accumulated so far.
        """
        raise NotImplementedError
class CountStats(StreamingStats):
    """
    Statistic that records only how many samples have been seen.

    For example::

        >>> stats = CountStats()
        >>> stats.update(torch.randn(3, 3))
        >>> stats.get()
        {'count': 1}
    """

    def __init__(self):
        self.count = 0
        super().__init__()

    def update(self, sample) -> None:
        """
        Register one more sample; the sample's value itself is ignored.

        :param sample: Any sample value (unused).
        """
        self.count = self.count + 1

    def merge(self, other: "CountStats") -> "CountStats":
        """
        Return a new :class:`CountStats` whose count is the sum of both counts.

        Neither ``self`` nor ``other`` is modified.

        :param other: Another instance of the same type.
        """
        assert isinstance(other, type(self))
        combined = CountStats()
        combined.count = self.count + other.count
        return combined

    def get(self) -> Dict[str, int]:
        """
        :returns: A dictionary with keys ``count: int``.
        :rtype: dict
        """
        total = self.count
        return {"count": total}
class StatsOfDict(StreamingStats):
    """
    Statistics of samples that are dictionaries with constant set of keys.

    For example the following are equivalent::

        # Version 1. Hand encode statistics.
        >>> a_stats = CountStats()
        >>> b_stats = CountMeanStats()
        >>> a_stats.update(torch.tensor(0.))
        >>> b_stats.update(torch.tensor([1., 2.]))
        >>> summary = {"a": a_stats.get(), "b": b_stats.get()}

        # Version 2. Collect samples into dictionaries.
        >>> stats = StatsOfDict({"a": CountStats, "b": CountMeanStats})
        >>> stats.update({"a": torch.tensor(0.), "b": torch.tensor([1., 2.])})
        >>> summary = stats.get()
        >>> summary
        {'a': {'count': 1}, 'b': {'count': 1, 'mean': tensor([1., 2.])}}

    :param dict types: Dictionary mapping key to type of statistic that should
        be recorded for values corresponding to that key. ``None`` (the
        default) is equivalent to an empty dictionary.
    :param default: Default type of statistics of values of the dictionary.
        Defaults to the inexpensive :class:`CountStats`.
    """

    def __init__(
        self,
        # ``None`` rather than ``{}`` so that no mutable object is shared
        # between calls as a default argument.
        types: Union[Dict[Hashable, Callable[[], StreamingStats]], None] = None,
        default: Callable[[], StreamingStats] = CountStats,
    ):
        # Keys without an explicit entry in ``types`` get a ``default()``
        # statistic created lazily on first access.
        self.stats: Dict[Hashable, StreamingStats] = defaultdict(default)
        if types is not None:
            self.stats.update({k: v() for k, v in types.items()})
        super().__init__()

    def update(self, sample: Dict[Hashable, Any]) -> None:
        """
        Update the per-key statistics from one dictionary-valued sample.

        :param dict sample: Mapping from key to the value passed on to that
            key's statistic.
        """
        for k, v in sample.items():
            self.stats[k].update(v)

    def merge(self, other: "StatsOfDict") -> "StatsOfDict":
        """
        Return a new :class:`StatsOfDict` combining ``self`` and ``other``.

        Keys present in both are merged via the per-key ``merge()``; keys
        present in only one side are deep-copied. Neither input is modified.
        Keys of ``self`` keep their order; keys found only in ``other`` are
        appended in ``other``'s insertion order.

        :param other: Another instance of the same type.
        """
        assert isinstance(other, type(self))
        # Start from a deep copy so the result owns its own default factory
        # and its own copies of the statistics that exist only in ``self``.
        result = copy.deepcopy(self)
        # Walk ``other.stats`` directly (not a set union) so the order of
        # newly added keys is deterministic. ``in`` does not trigger the
        # defaultdict factory, so no spurious entries are created.
        for k, other_stat in other.stats.items():
            if k in self.stats:
                result.stats[k] = self.stats[k].merge(other_stat)
            else:
                result.stats[k] = copy.deepcopy(other_stat)
        return result

    def get(self) -> Dict[Hashable, Any]:
        """
        :returns: A dictionary of statistics. The keys of this dictionary are
            the same as the keys of the samples from which this object is
            updated.
        :rtype: dict
        """
        return {k: v.get() for k, v in self.stats.items()}
class StackStats(StreamingStats):
    """
    Statistic collecting a stream of tensors into a single stacked tensor.

    Every sample is retained, so memory use grows with the number of samples.
    """

    def __init__(self):
        # Samples are kept in arrival order and only stacked in :meth:`get`.
        self.samples = []
        # Call the base initializer like every other statistic in this module
        # so that cooperative multiple inheritance keeps working.
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Append one sample to the collection.

        The tensor is stored as-is (it is neither copied nor detached).

        :param torch.Tensor sample: The tensor to record.
        """
        assert isinstance(sample, torch.Tensor)
        self.samples.append(sample)

    def merge(self, other: "StackStats") -> "StackStats":
        """
        Return a new :class:`StackStats` holding the samples of ``self``
        followed by the samples of ``other``.

        Neither input is modified; the sample tensors themselves are shared,
        not copied.

        :param other: Another instance of the same type.
        """
        assert isinstance(other, type(self))
        result = StackStats()
        result.samples = self.samples + other.samples
        return result

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``samples: torch.Tensor``.
        :rtype: dict
        """
        if not self.samples:
            return {"count": 0}
        return {"count": len(self.samples), "samples": torch.stack(self.samples)}
class CountMeanStats(StreamingStats):
    """
    Statistic tracking the count and mean of a single :class:`torch.Tensor`.
    """

    def __init__(self):
        self.count = 0
        # Starts as the scalar 0 and becomes a tensor after the first update.
        self.mean = 0
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Fold one sample into the running mean.

        :param torch.Tensor sample: The tensor to average; it is detached
            before use.
        """
        assert isinstance(sample, torch.Tensor)
        self.count += 1
        # Rebind instead of using ``+=``: an in-place add would mutate the
        # very tensor previously handed out by :meth:`get`, silently changing
        # results the caller already holds.
        self.mean = self.mean + (sample.detach() - self.mean) / self.count

    def merge(self, other: "CountMeanStats") -> "CountMeanStats":
        """
        Return a new :class:`CountMeanStats` whose mean is the count-weighted
        average of both means.

        Neither ``self`` nor ``other`` is modified.

        :param other: Another instance of the same type.
        """
        assert isinstance(other, type(self))
        result = CountMeanStats()
        result.count = self.count + other.count
        # ``max(..., 1)`` avoids dividing by zero when both sides are empty.
        p = self.count / max(result.count, 1)
        q = other.count / max(result.count, 1)
        result.mean = p * self.mean + q * other.mean
        return result

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``mean: torch.Tensor``.
        :rtype: dict
        """
        if self.count == 0:
            return {"count": 0}
        return {"count": self.count, "mean": self.mean}
class CountMeanVarianceStats(StreamingStats):
    """
    Statistic tracking the count, mean, and (diagonal) variance of a single
    :class:`torch.Tensor`.
    """

    def __init__(self):
        # Shape of the samples; stays ``None`` until the first update.
        self.shape = None
        self.welford = WelfordCovariance(diagonal=True)
        super().__init__()

    def update(self, sample: torch.Tensor) -> None:
        """
        Fold one sample into the running statistics.

        The first sample fixes the expected shape; later samples must match.

        :param torch.Tensor sample: The tensor to record; it is detached and
            flattened before being passed to the Welford accumulator.
        """
        assert isinstance(sample, torch.Tensor)
        if self.shape is None:
            self.shape = sample.shape
        assert sample.shape == self.shape
        flattened = sample.detach().reshape(-1)
        self.welford.update(flattened)

    def merge(self, other: "CountMeanVarianceStats") -> "CountMeanVarianceStats":
        """
        Return a new instance combining the statistics of ``self`` and
        ``other``.

        Neither input is modified. If one side has seen no samples, a deep
        copy of the other side is returned.

        :param other: Another instance of the same type.
        """
        assert isinstance(other, type(self))
        if self.shape is None:
            return copy.deepcopy(other)
        if other.shape is None:
            return copy.deepcopy(self)

        merged = copy.deepcopy(self)
        ours, theirs = self.welford, other.welford
        total = ours.n_samples + theirs.n_samples

        merged.welford.n_samples = total
        # Mean: average of both means, weighted by each side's sample count.
        merged.welford._mean = (
            ours.n_samples / total * ours._mean
            + theirs.n_samples / total * theirs._mean
        )
        # Sum of squared deviations: both partial sums plus a correction for
        # the gap between the two means.
        merged.welford._m2 = (
            ours._m2
            + theirs._m2
            + (ours.n_samples * theirs.n_samples / total)
            * (ours._mean - theirs._mean) ** 2
        )
        return merged

    def get(self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        :returns: A dictionary with keys ``count: int`` and (if any samples
            have been collected) ``mean: torch.Tensor`` and ``variance:
            torch.Tensor``.
        :rtype: dict
        """
        if self.shape is None:
            return {"count": 0}
        accumulator = self.welford
        # Statistics are stored flattened; restore the original sample shape.
        # NOTE(review): behaviour with very few samples depends on
        # WelfordCovariance.get_covariance -- confirm against that class.
        return {
            "count": accumulator.n_samples,
            "mean": accumulator._mean.reshape(self.shape),
            "variance": accumulator.get_covariance(regularize=False).reshape(
                self.shape
            ),
        }
# Names exported by ``from <module> import *``.
# NOTE: the order is deliberate -- entries are grouped logically for Sphinx
# rather than alphabetically, so do not re-sort this list.
__all__ = [
"StreamingStats",
"StatsOfDict",
"StackStats",
"CountStats",
"CountMeanStats",
"CountMeanVarianceStats",
]