# Source code for pyro.contrib.gp.models.vgp

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
import pyro.distributions as dist
from pyro.contrib import autoname
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like

class VariationalGP(GPModel):
r"""
Variational Gaussian Process model.

This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs\
:math:X and their noisy observations :math:y, the model takes the form

.. math::
f &\sim \mathcal{GP}(0, k(X, X)),\\
y & \sim p(y) = p(y \mid f) p(f),

where :math:p(y \mid f) is the likelihood.

    We will use a variational approach in this model: the posterior
    :math:`p(f \mid y)` is approximated by a multivariate normal distribution
    :math:`q(f)` with two parameters ``f_loc`` and ``f_scale_tril``, which are
    learned during the variational inference process.

    .. note:: This model can be seen as a special version of the
        :class:`.SparseVariationalGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for both training
        and testing, where :math:`N` is the number of train inputs. The size of the
        variational parameters is :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: A tensor of input data for training. Its first dimension
        is the number of data points.
    :param torch.Tensor y: A tensor of output data for training. Its last dimension
        is the number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A
        likelihood object.
    :param callable mean_function: An optional mean function :math:`m` of this
        Gaussian process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes (``batch_shape`` of
        :math:`q(f)`). By default, it equals the output batch shape ``y.shape[:-1]``.
        For multi-class classification problems, ``latent_shape[-1]`` should
        correspond to the number of classes.
    :param bool whiten: A flag to tell if variational parameters ``f_loc`` and
        ``f_scale_tril`` are transformed by the inverse of ``Lff``, where ``Lff`` is
        the lower triangular decomposition of :math:`kernel(X, X)`. Enabling this
        flag will help optimization.
    :param float jitter: A small positive term which is added to the diagonal part of
        a covariance matrix to help stabilize its Cholesky decomposition.
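
    A minimal usage sketch (the RBF kernel, Gaussian likelihood, and Adam
    settings below are illustrative choices, not requirements of this class):

    Example::

        >>> import torch
        >>> import pyro
        >>> import pyro.contrib.gp as gp
        >>>
        >>> X = torch.randn(20, 2)  # 20 data points, 2 features
        >>> y = torch.randn(20)
        >>> kernel = gp.kernels.RBF(input_dim=2)
        >>> likelihood = gp.likelihoods.Gaussian()
        >>> vgp = gp.models.VariationalGP(X, y, kernel, likelihood, whiten=True)
        >>> optimizer = torch.optim.Adam(vgp.parameters(), lr=0.01)
        >>> loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
        >>> for step in range(100):
        ...     optimizer.zero_grad()
        ...     loss = loss_fn(vgp.model, vgp.guide)
        ...     loss.backward()
        ...     optimizer.step()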
"""
    def __init__(self, X, y, kernel, likelihood, mean_function=None,
                 latent_shape=None, whiten=False, jitter=1e-6):
        super(VariationalGP, self).__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        N = self.X.size(0)
        f_loc = self.X.new_zeros(self.latent_shape + (N,))
        self.f_loc = Parameter(f_loc)

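        # Initialize q(f) to N(0, I): f_scale_tril starts as the identity matrix,
        # broadcast over latent_shape.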
        identity = eye_like(self.X, N)
        f_scale_tril = identity.repeat(self.latent_shape + (1, 1))
        self.f_scale_tril = Parameter(f_scale_tril)
        self.set_constraint("f_scale_tril", constraints.lower_cholesky)

        self.whiten = whiten
        self._sample_latent = True

[docs]    @autoname.scope(prefix="VGP")
def model(self):
self.set_mode("model")

        N = self.X.size(0)
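        # Kff = k(X, X) is the N x N covariance matrix over the training inputs.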
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
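            # Un-whiten: f_loc and f_scale_tril parameterize the whitened process
            # u = Lff^{-1} f, so multiply by Lff to recover the parameters of f.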
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample("f",
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
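        # Marginal variances: f_var[..., i] = (S S^T)[i, i] with S = f_scale_tril,
        # i.e. the squared norm of each row of the scale matrix.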
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)

[docs]    @autoname.scope(prefix="VGP")
def guide(self):
self.set_mode("guide")

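        # The variational posterior is q(f) = N(f_loc, f_scale_tril f_scale_tril^T);
        # under whiten=True these parameters live in the whitened space and the
        # model un-whitens them.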
pyro.sample("f",
dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril)
.to_event(self.f_loc.dim()-1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of the Gaussian
        Process posterior on test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc`` and ``f_scale_tril``, together
            with the kernel's parameters, have been learned from a training
            procedure (MCMC or SVI).

        :param torch.Tensor Xnew: A tensor of input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict the full
            covariance matrix or just the variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
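
        A short prediction sketch (assuming ``vgp`` is a trained instance of this
        class and ``Xtest`` is hypothetical test data whose feature shape matches
        ``self.X``):

        Example::

            >>> Xtest = torch.randn(5, 2)
            >>> loc, var = vgp(Xtest, full_cov=False)
            >>> loc, cov = vgp(Xtest, full_cov=True)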
"""
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        loc, cov = conditional(Xnew, self.X, self.kernel, self.f_loc, self.f_scale_tril,
                               full_cov=full_cov, whiten=self.whiten, jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov