# Source code for pyro.ops.welford

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

class WelfordCovariance:
    """
    Implements Welford's online scheme for estimating (co)variance (see [1]).
    Useful for adapting diagonal and dense mass structures for HMC.

    :param bool diagonal: If ``True`` (the default) only the elementwise
        variance is accumulated; if ``False`` the full covariance matrix is
        accumulated from outer products of the samples.

    **References**

    [1] `The Art of Computer Programming`,
        Donald E. Knuth
    """

    def __init__(self, diagonal=True):
        self.diagonal = diagonal
        self.reset()

    def reset(self):
        """
        Discards all accumulated statistics and sets the sample count to zero.
        """
        # Plain Python floats act as neutral starting values: the first
        # ``update`` replaces them with tensors via ordinary arithmetic.
        self._mean = 0.0
        self._m2 = 0.0
        self.n_samples = 0

    def update(self, sample):
        """
        Folds one new observation into the running mean and the running sum
        of (co-)deviations.

        :param torch.Tensor sample: The new observation. In dense mode
            (``diagonal=False``) it is passed to :func:`torch.outer` and so
            must be a 1-D tensor.
        """
        self.n_samples += 1
        # Deviation from the mean *before* and *after* this sample is
        # included; their product is Welford's increment to the second moment.
        delta_pre = sample - self._mean
        self._mean = self._mean + delta_pre / self.n_samples
        delta_post = sample - self._mean

        if self.diagonal:
            self._m2 += delta_pre * delta_post
        else:
            # ``torch.outer`` is the supported name for the deprecated
            # ``torch.ger`` alias; the result is identical.
            self._m2 += torch.outer(delta_post, delta_pre)

    def get_covariance(self, regularize=True):
        """
        Computes the current (co)variance estimate.

        :param bool regularize: If ``True`` (the default), the estimate is
            scaled by ``n / (n + 5)`` and ``1e-3 * 5 / (n + 5)`` is added to
            the variances (the diagonal, in dense mode), where ``n`` is the
            number of samples seen.
        :returns: The elementwise variance if ``diagonal`` is set, otherwise
            the covariance matrix.
        :rtype: torch.Tensor
        :raises RuntimeError: If fewer than two samples have been seen.
        """
        if self.n_samples < 2:
            raise RuntimeError("Insufficient samples to estimate covariance")
        # Division allocates a fresh tensor, so nothing below touches the
        # accumulator ``self._m2``.
        cov = self._m2 / (self.n_samples - 1)
        if regularize:
            # Regularization from stan
            scaled_cov = (self.n_samples / (self.n_samples + 5.0)) * cov
            shrinkage = 1e-3 * (5.0 / (self.n_samples + 5.0))
            if self.diagonal:
                cov = scaled_cov + shrinkage
            else:
                # ``diagonal()`` returns a view, so the in-place add updates
                # only the diagonal entries of ``scaled_cov``.
                scaled_cov.diagonal().add_(shrinkage)
                cov = scaled_cov
        return cov

class WelfordArrowheadCovariance:
    """
    Like :class:`WelfordCovariance` but generalized to the arrowhead structure:
    the first ``head_size`` rows of the covariance are tracked in full, while
    for the remaining coordinates only the diagonal is tracked.

    :param int head_size: Number of leading coordinates whose full covariance
        rows are accumulated. With the default of ``0`` only the diagonal is
        tracked.
    """

    def __init__(self, head_size=0):
        self.head_size = head_size
        self.reset()

    def reset(self):
        """
        Discards all accumulated statistics and sets the sample count to zero.
        """
        self._mean = 0.0
        self._m2_top = 0.0  # upper part, shape: head_size x matrix_size
        self._m2_bottom_diag = 0.0  # lower right part, shape: (matrix_size - head_size)
        self.n_samples = 0

    def update(self, sample):
        """
        Folds one new observation into the running mean and the running sums
        of (co-)deviations.

        :param torch.Tensor sample: The new observation. It is sliced along
            its first dimension and passed to :func:`torch.outer`, so it must
            be a 1-D tensor.
        """
        self.n_samples += 1
        # Deviation from the mean *before* and *after* this sample is
        # included; their product is Welford's increment to the second moment.
        delta_pre = sample - self._mean
        self._mean = self._mean + delta_pre / self.n_samples
        delta_post = sample - self._mean
        if self.head_size > 0:
            # ``torch.outer`` is the supported name for the deprecated
            # ``torch.ger`` alias; the result is identical.
            self._m2_top = self._m2_top + torch.outer(
                delta_post[: self.head_size], delta_pre
            )
        else:
            # No head rows: keep an empty (0 x matrix_size) placeholder so
            # that ``get_covariance`` still returns a tensor for ``top``.
            self._m2_top = sample.new_empty(0, sample.size(0))
        self._m2_bottom_diag = (
            self._m2_bottom_diag
            + delta_post[self.head_size :] * delta_pre[self.head_size :]
        )

    def get_covariance(self, regularize=True):
        """
        Gets the covariance in arrowhead form: (top, bottom_diag) where top = cov[:head_size]
        and bottom_diag = cov.diag()[head_size:].

        :param bool regularize: If ``True`` (the default), both parts are
            scaled by ``n / (n + 5)`` and ``1e-3 * 5 / (n + 5)`` is added to
            the variances (the main diagonal of ``top`` and every entry of
            ``bottom_diag``), where ``n`` is the number of samples seen.
        :returns: The tuple ``(top, bottom_diag)``.
        :rtype: tuple
        :raises RuntimeError: If fewer than two samples have been seen.
        """
        if self.n_samples < 2:
            raise RuntimeError("Insufficient samples to estimate covariance")
        # Division allocates fresh tensors, so nothing below touches the
        # accumulators.
        top = self._m2_top / (self.n_samples - 1)
        bottom_diag = self._m2_bottom_diag / (self.n_samples - 1)
        if regularize:
            scale = self.n_samples / (self.n_samples + 5.0)
            shrinkage = 1e-3 * (5.0 / (self.n_samples + 5.0))
            top = top * scale
            # ``diagonal()`` is a view onto the entries top[i, i]; it is empty
            # when ``head_size`` is 0, making the in-place add a no-op.
            top.diagonal().add_(shrinkage)
            bottom_diag = bottom_diag * scale + shrinkage

        return top, bottom_diag