import torch
class WelfordCovariance:
    """
    Implements Welford's online scheme for estimating (co)variance
    (see :math:`[1]`). Useful for adapting diagonal and dense mass
    structures for HMC.

    Samples are consumed one at a time through :meth:`update`; only the
    running mean and the accumulated sum of centered products are stored,
    so memory use does not grow with the number of samples.

    :param bool diagonal: If ``True`` (default), accumulate element-wise
        variances only. If ``False``, accumulate a full covariance matrix
        from the outer products of the centered samples.

    **References**

    [1] `The Art of Computer Programming`,
    Donald E. Knuth
    """

    def __init__(self, diagonal: bool = True) -> None:
        self.diagonal = diagonal
        self.reset()

    def reset(self) -> None:
        """Discard all accumulated statistics and start from zero samples."""
        # Plain floats are used as neutral starting values; the first call
        # to `update` turns them into tensors shaped like the sample.
        self._mean = 0.
        self._m2 = 0.
        self.n_samples = 0

    def update(self, sample: torch.Tensor) -> None:
        """
        Fold one new sample into the running mean and (co)variance sums.

        :param torch.Tensor sample: The new observation. In dense mode
            (``diagonal=False``) it must be 1-D, as required by
            :func:`torch.outer`.
        """
        self.n_samples += 1
        # Welford's update needs the deviation from the mean both before
        # and after the mean has absorbed the new sample.
        delta_pre = sample - self._mean
        self._mean = self._mean + delta_pre / self.n_samples
        delta_post = sample - self._mean

        if self.diagonal:
            self._m2 += delta_pre * delta_post
        else:
            # `torch.outer` is the documented replacement for the
            # deprecated `torch.ger` alias.
            self._m2 += torch.outer(delta_post, delta_pre)

    def get_covariance(self, regularize: bool = True) -> torch.Tensor:
        """
        Return the current sample (co)variance estimate.

        :param bool regularize: If ``True`` (default), shrink the estimate
            by ``n / (n + 5)`` and add ``1e-3 * 5 / (n + 5)`` to the
            variances (the diagonal in dense mode), where ``n`` is the
            number of samples seen so far.
        :returns: The variances when ``diagonal=True``, otherwise the
            covariance matrix.
        :raises RuntimeError: If fewer than two samples have been seen.
        """
        if self.n_samples < 2:
            raise RuntimeError('Insufficient samples to estimate covariance')

        # Division allocates a new tensor, so the in-place diagonal update
        # below never touches the accumulator `self._m2`.
        cov = self._m2 / (self.n_samples - 1)
        if not regularize:
            return cov

        # Regularization from stan
        scaled_cov = (self.n_samples / (self.n_samples + 5.)) * cov
        shrinkage = 1e-3 * (5. / (self.n_samples + 5.0))
        if self.diagonal:
            return scaled_cov + shrinkage

        # `diagonal()` returns a view, so `add_` updates `scaled_cov` in place.
        scaled_cov.diagonal().add_(shrinkage)
        return scaled_cov