Source code for

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributions as torchdist
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from import GPModel
from import conditional
from pyro.nn.module import PyroParam, pyro_method
from pyro.util import warn_if_nan

[docs]class GPRegression(GPModel): r""" Gaussian Process Regression model. The core of a Gaussian Process is a covariance function :math:`k` which governs the similarity between input points. Given :math:`k`, we can establish a distribution over functions :math:`f` by a multivarite normal distribution .. math:: p(f(X)) = \mathcal{N}(0, k(X, X)), where :math:`X` is any set of input points and :math:`k(X, X)` is a covariance matrix whose entries are outputs :math:`k(x, z)` of :math:`k` over input pairs :math:`(x, z)`. This distribution is usually denoted by .. math:: f \sim \mathcal{GP}(0, k). .. note:: Generally, beside a covariance matrix :math:`k`, a Gaussian Process can also be specified by a mean function :math:`m` (which is a zero-value function by default). In that case, its distribution will be .. math:: p(f(X)) = \mathcal{N}(m(X), k(X, X)). Given inputs :math:`X` and their noisy observations :math:`y`, the Gaussian Process Regression model takes the form .. math:: f &\sim \mathcal{GP}(0, k(X, X)),\\ y & \sim f + \epsilon, where :math:`\epsilon` is Gaussian noise. .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training, :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number of train inputs. Reference: [1] `Gaussian Processes for Machine Learning`, Carl E. Rasmussen, Christopher K. I. Williams :param torch.Tensor X: A input data for training. Its first dimension is the number of data points. :param torch.Tensor y: An output data for training. Its last dimension is the number of data points. :param kernel: A Pyro kernel object, which is the covariance function :math:`k`. :param torch.Tensor noise: Variance of Gaussian noise of this model. :param callable mean_function: An optional mean function :math:`m` of this Gaussian process. By default, we use zero mean. :param float jitter: A small positive term which is added into the diagonal part of a covariance matrix to help stablize its Cholesky decomposition. """ def __init__(self, X, y, kernel, noise=None, mean_function=None, jitter=1e-6): assert isinstance( X, torch.Tensor ), "X needs to be a torch Tensor instead of a {}".format(type(X)) if y is not None: assert isinstance( y, torch.Tensor ), "y needs to be a torch Tensor instead of a {}".format(type(y)) super().__init__(X, y, kernel, mean_function, jitter) noise = self.X.new_tensor(1.0) if noise is None else noise self.noise = PyroParam(noise, constraints.positive)
[docs] @pyro_method def model(self): self.set_mode("model") N = self.X.size(0) Kff = self.kernel(self.X) Kff.view(-1)[:: N + 1] += self.jitter + self.noise # add noise to diagonal Lff = torch.linalg.cholesky(Kff) zero_loc = self.X.new_zeros(self.X.size(0)) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = Lff.pow(2).sum(dim=-1) return f_loc, f_var else: return pyro.sample( self._pyro_get_fullname("y"), dist.MultivariateNormal(f_loc, scale_tril=Lff) .expand_by(self.y.shape[:-1]) .to_event(self.y.dim() - 1), obs=self.y, )
[docs] @pyro_method def guide(self): self.set_mode("guide") self._load_pyro_samples()
[docs] def forward(self, Xnew, full_cov=False, noiseless=True): r""" Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data :math:`X_{new}`: .. math:: p(f^* \mid X_{new}, X, y, k, \epsilon) = \mathcal{N}(loc, cov). .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with kernel's parameters have been learned from a training procedure (MCMC or SVI). :param torch.Tensor Xnew: A input data for testing. Note that ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``. :param bool full_cov: A flag to decide if we want to predict full covariance matrix or just variance. :param bool noiseless: A flag to decide if we want to include noise in the prediction output or not. :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))` :rtype: tuple(torch.Tensor, torch.Tensor) """ self._check_Xnew_shape(Xnew) self.set_mode("guide") N = self.X.size(0) Kff = self.kernel(self.X).contiguous() Kff.view(-1)[:: N + 1] += self.jitter + self.noise # add noise to the diagonal Lff = torch.linalg.cholesky(Kff) y_residual = self.y - self.mean_function(self.X) loc, cov = conditional( Xnew, self.X, self.kernel, y_residual, None, Lff, full_cov, jitter=self.jitter, ) if full_cov and not noiseless: M = Xnew.size(0) cov = cov.contiguous() cov.view(-1, M * M)[:, :: M + 1] += self.noise # add noise to the diagonal if not full_cov and not noiseless: cov = cov + self.noise return loc + self.mean_function(Xnew), cov
[docs] def iter_sample(self, noiseless=True): r""" Iteratively constructs a sample from the Gaussian Process posterior. Recall that at test input points :math:`X_{new}`, the posterior is multivariate Gaussian distributed with mean and covariance matrix given by :func:`forward`. This method samples lazily from this multivariate Gaussian. The advantage of this approach is that later query points can depend upon earlier ones. Particularly useful when the querying is to be done by an optimisation routine. .. note:: The noise parameter ``noise`` (:math:`\epsilon`) together with kernel's parameters have been learned from a training procedure (MCMC or SVI). :param bool noiseless: A flag to decide if we want to add sampling noise to the samples beyond the noise inherent in the GP posterior. :returns: sampler :rtype: function """ noise = self.noise.detach() X = self.X.clone().detach() y = self.y.clone().detach() N = X.size(0) Kff = self.kernel(X).contiguous() Kff.view(-1)[:: N + 1] += noise # add noise to the diagonal outside_vars = {"X": X, "y": y, "N": N, "Kff": Kff} def sample_next(xnew, outside_vars): """Repeatedly samples from the Gaussian process posterior, conditioning on previously sampled values. """ warn_if_nan(xnew) # Variables from outer scope X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"] # Compute Cholesky decomposition of kernel matrix Lff = torch.linalg.cholesky(Kff) y_residual = y - self.mean_function(X) # Compute conditional mean and variance loc, cov = conditional( xnew, X, self.kernel, y_residual, None, Lff, False, jitter=self.jitter ) if not noiseless: cov = cov + noise ynew = torchdist.Normal( loc + self.mean_function(xnew), cov.sqrt() ).rsample() # Update kernel matrix N = outside_vars["N"] Kffnew = Kff.new_empty(N + 1, N + 1) Kffnew[:N, :N] = Kff cross = self.kernel(X, xnew).squeeze() end = self.kernel(xnew, xnew).squeeze() Kffnew[N, :N] = cross Kffnew[:N, N] = cross # No noise, just jitter for numerical stability Kffnew[N, N] = end + self.jitter # Heuristic to avoid adding degenerate points if Kffnew.logdet() > -15.0: outside_vars["Kff"] = Kffnew outside_vars["N"] += 1 outside_vars["X"] =, xnew)) outside_vars["y"] =, ynew)) return ynew return lambda xnew: sample_next(xnew, outside_vars)