Source code for

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
import pyro.distributions as dist
from import GPModel
from import conditional
from pyro.distributions.util import eye_like
from pyro.nn.module import PyroParam, pyro_method

class VariationalGP(GPModel):
    r"""
    Variational Gaussian Process model.

    This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs
    :math:`X` and their noisy observations :math:`y`, the model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We will use a variational approach in this model by approximating :math:`q(f)` to
    the posterior :math:`p(f\mid y)`. Precisely, :math:`q(f)` will be a multivariate
    normal distribution with two parameters ``f_loc`` and ``f_scale_tril``, which will
    be learned during a variational inference process.

    .. note:: This model can be seen as a special version of
        :class:`.SparseVariationalGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training,
        :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs. Size of variational parameters is :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param kernel: A Pyro kernel object, which is the covariance function :math:`k`.
    :param Likelihood likelihood: A likelihood object.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes (`batch_shape` of
        :math:`q(f)`). By default, it equals the output batch shape ``y.shape[:-1]``.
        For multi-class classification problems, ``latent_shape[-1]`` should
        correspond to the number of classes.
    :param bool whiten: A flag to tell if variational parameters ``f_loc`` and
        ``f_scale_tril`` are transformed by the inverse of ``Lff``, where ``Lff`` is
        the lower triangular decomposition of :math:`kernel(X, X)`. Enabling this flag
        will help optimization.
    :param float jitter: A small positive term which is added to the diagonal part of
        a covariance matrix to help stabilize its Cholesky decomposition.
    """

    def __init__(
        self,
        X,
        y,
        kernel,
        likelihood,
        mean_function=None,
        latent_shape=None,
        whiten=False,
        jitter=1e-6,
    ):
        """
        Store the likelihood and create the variational parameters ``f_loc``
        (initialised to zeros) and ``f_scale_tril`` (initialised to identity
        matrices), one per latent process.
        """
        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood

        # Fall back to the batch shape of ``y`` (or an empty shape when there are
        # no observations) if the caller did not fix the latent shape.
        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N,))
        self.f_loc = Parameter(f_loc)

        # One N x N identity per latent process; registered with a lower-Cholesky
        # constraint.
        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = PyroParam(f_scale_tril, constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True

    @pyro_method
    def model(self):
        """
        Sample the latent site ``f`` from its prior and evaluate the likelihood.

        :returns: ``(f_loc, f_var)`` when ``self.y`` is ``None``; otherwise the
            result of ``self.likelihood(f_loc, f_var, self.y)``.
        """
        self.set_mode("model")

        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        # Stepping through the flattened matrix with stride N + 1 visits exactly
        # the diagonal entries of an N x N matrix, so this adds jitter in place.
        Kff.view(-1)[:: N + 1] += self.jitter  # add jitter to the diagonal
        Lff = torch.linalg.cholesky(Kff)

        zero_loc = self.X.new_zeros(self.f_loc.shape)

        if self.whiten:
            # Whitened parameterisation: the prior has identity scale and the
            # variational parameters are mapped through Lff.
            identity = eye_like(self.X, N)
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(zero_loc, scale_tril=identity).to_event(
                    zero_loc.dim() - 1
                ),
            )
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(zero_loc, scale_tril=Lff).to_event(
                    zero_loc.dim() - 1
                ),
            )
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        # Row-wise sum of squares of the scale factor, i.e. the diagonal of
        # ``f_scale_tril @ f_scale_tril.T``.
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

    @pyro_method
    def guide(self):
        """
        Sample the latent site ``f`` from the variational distribution
        ``MultivariateNormal(f_loc, scale_tril=f_scale_tril)``.
        """
        self.set_mode("guide")
        self._load_pyro_samples()

        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril).to_event(
                self.f_loc.dim() - 1
            ),
        )

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``, together with
            kernel's parameters have been learned from a training procedure (MCMC or
            SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(
            Xnew,
            self.X,
            self.kernel,
            self.f_loc,
            self.f_scale_tril,
            full_cov=full_cov,
            whiten=self.whiten,
            jitter=self.jitter,
        )
        # ``conditional`` is not given the mean function, so add it back here.
        return loc + self.mean_function(Xnew), cov