# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.distributions import constraints
from torch.nn import Parameter
import pyro
import pyro.distributions as dist
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.distributions.util import eye_like
from pyro.nn.module import PyroParam, pyro_method
class VariationalGP(GPModel):
    r"""
    Variational Gaussian Process model.

    This model handles both Gaussian and non-Gaussian likelihoods. Given inputs
    :math:`X` and their noisy observations :math:`y`, the model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y & \sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    A variational approach is used: a distribution :math:`q(f)` approximates the
    posterior :math:`p(f\mid y)`. Precisely, :math:`q(f)` is a multivariate normal
    distribution with two parameters ``f_loc`` and ``f_scale_tril``, which are
    learned during a variational inference process.

    .. note:: This model can be seen as a special version of the
        :class:`.SparseVariationalGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training and
        :math:`\mathcal{O}(N^3)` complexity for testing, where :math:`N` is the
        number of train inputs. The size of the variational parameters is
        :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: Input data for training. Its first dimension is the
        number of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A
        likelihood object.
    :param callable mean_function: An optional mean function :math:`m` of this
        Gaussian process. By default, a zero mean is used.
    :param torch.Size latent_shape: Shape for latent processes (`batch_shape` of
        :math:`q(f)`). By default, it equals the output batch shape
        ``y.shape[:-1]``. For multi-class classification problems,
        ``latent_shape[-1]`` should correspond to the number of classes.
    :param bool whiten: A flag telling whether the variational parameters ``f_loc``
        and ``f_scale_tril`` are transformed by the inverse of ``Lff``, where
        ``Lff`` is the lower triangular decomposition of :math:`kernel(X, X)`.
        Enabling this flag helps optimization.
    :param float jitter: A small positive term added to the diagonal of a
        covariance matrix to help stabilize its Cholesky decomposition.
    """

    def __init__(
        self,
        X,
        y,
        kernel,
        likelihood,
        mean_function=None,
        latent_shape=None,
        whiten=False,
        jitter=1e-6,
    ):
        """Validate the data tensors and create the variational parameters."""
        assert isinstance(
            X, torch.Tensor
        ), "X needs to be a torch Tensor instead of a {}".format(type(X))
        if y is not None:
            assert isinstance(
                y, torch.Tensor
            ), "y needs to be a torch Tensor instead of a {}".format(type(y))

        super().__init__(X, y, kernel, mean_function, jitter)

        self.likelihood = likelihood

        # Fall back to the batch shape of ``y`` (empty when there is no ``y``)
        # whenever the caller does not give an explicit latent shape.
        if self.y is not None:
            default_shape = self.y.shape[:-1]
        else:
            default_shape = torch.Size([])
        if latent_shape is not None:
            self.latent_shape = latent_shape
        else:
            self.latent_shape = default_shape

        num_points = self.X.size(0)

        # q(f) starts with a zero mean ...
        self.f_loc = Parameter(self.X.new_zeros(self.latent_shape + (num_points,)))

        # ... and one identity scale factor per latent process, constrained to
        # remain a valid lower Cholesky factor.
        eye = eye_like(self.X, num_points)
        self.f_scale_tril = PyroParam(
            eye.repeat(self.latent_shape + (1, 1)), constraints.lower_cholesky
        )

        self.whiten = whiten
        # NOTE(review): this flag is not read anywhere in this class; presumably
        # it is consumed by code outside it -- confirm.
        self._sample_latent = True

    @pyro_method
    def model(self):
        r"""
        Sample the latent function ``f`` from its prior and pass the variational
        mean and variance on to the likelihood.

        :returns: the tuple ``(f_loc, f_var)`` when no ``y`` is set, otherwise
            whatever ``self.likelihood(f_loc, f_var, self.y)`` returns.
        """
        self.set_mode("model")

        num_points = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        # Stepping by N + 1 through the flattened N x N matrix visits exactly its
        # diagonal entries, so this adds jitter to the diagonal in place.
        Kff.view(-1)[:: num_points + 1] += self.jitter
        Lff = torch.linalg.cholesky(Kff)

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        event_dims = zero_loc.dim() - 1
        whitened = self.whiten

        # With whitening the prior uses an identity scale; otherwise it uses the
        # Cholesky factor of the (jittered) kernel matrix.
        prior_scale_tril = eye_like(self.X, num_points) if whitened else Lff
        site_name = self._pyro_get_fullname("f")
        prior = dist.MultivariateNormal(zero_loc, scale_tril=prior_scale_tril)
        pyro.sample(site_name, prior.to_event(event_dims))

        if whitened:
            # Map the whitened variational parameters back through Lff.
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        # Row-wise sum of squares of a scale factor S is the diagonal of S S^T.
        f_var = f_scale_tril.pow(2).sum(dim=-1)

        if self.y is None:
            return f_loc, f_var
        return self.likelihood(f_loc, f_var, self.y)

    @pyro_method
    def guide(self):
        r"""
        Sample the latent function ``f`` from the variational distribution
        :math:`q(f)` parameterized by ``f_loc`` and ``f_scale_tril``.
        """
        self.set_mode("guide")
        self._load_pyro_samples()

        site_name = self._pyro_get_fullname("f")
        q_f = dist.MultivariateNormal(self.f_loc, scale_tril=self.f_scale_tril)
        pyro.sample(site_name, q_f.to_event(self.f_loc.dim() - 1))

    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of the Gaussian
        Process posterior on test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: The variational parameters ``f_loc`` and ``f_scale_tril``,
            together with the kernel's parameters, have been learned from a
            training procedure (MCMC or SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide whether to predict the full
            covariance matrix or just the variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)
        self.set_mode("guide")

        posterior_loc, posterior_cov = conditional(
            Xnew,
            self.X,
            self.kernel,
            self.f_loc,
            self.f_scale_tril,
            full_cov=full_cov,
            whiten=self.whiten,
            jitter=self.jitter,
        )
        # ``conditional`` is not given the mean function, so the mean at the
        # test inputs is added here.
        posterior_loc = posterior_loc + self.mean_function(Xnew)
        return posterior_loc, posterior_cov