Source code for pyro.contrib.gp.models.vgp

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
import pyro.distributions as dist
from pyro.contrib.gp.models.model import GPModel
from pyro.contrib.gp.util import conditional
from pyro.params import param_with_module_name


class VariationalGP(GPModel):
    r"""
    Variational Gaussian Process model.

    This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs
    :math:`X` and their noisy observations :math:`y`, the model takes the form

    .. math::
        f &\sim \mathcal{GP}(0, k(X, X)),\\
        y &\sim p(y) = p(y \mid f) p(f),

    where :math:`p(y \mid f)` is the likelihood.

    We use a variational approach in this model: the posterior :math:`p(f \mid y)` is
    approximated by a distribution :math:`q(f)`. Precisely, :math:`q(f)` will be a
    multivariate normal distribution with two parameters ``f_loc`` and
    ``f_scale_tril``, which will be learned during a variational inference process.

    .. note:: This model can be seen as a special version of the
        :class:`.SparseVariationalGP` model with :math:`X_u = X`.

    .. note:: This model has :math:`\mathcal{O}(N^3)` complexity for training and
        :math:`\mathcal{O}(N^3)` complexity for testing. Here, :math:`N` is the number
        of train inputs. The size of the variational parameters is
        :math:`\mathcal{O}(N^2)`.

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the number
        of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param ~pyro.contrib.gp.likelihoods.likelihood.Likelihood likelihood: A likelihood
        object.
    :param callable mean_function: An optional mean function :math:`m` of this
        Gaussian process. By default, we use zero mean.
    :param torch.Size latent_shape: Shape for latent processes (``batch_shape`` of
        :math:`q(f)`). By default, it equals the output batch shape ``y.shape[:-1]``.
        For multi-class classification problems, ``latent_shape[-1]`` should
        correspond to the number of classes.
    :param bool whiten: A flag to tell if variational parameters ``f_loc`` and
        ``f_scale_tril`` are transformed by the inverse of ``Lff``, where ``Lff`` is
        the lower triangular decomposition of :math:`kernel(X, X)`. Enabling this flag
        will help optimization.
    :param float jitter: A small positive term which is added into the diagonal part
        of a covariance matrix to help stabilize its Cholesky decomposition.
    :param str name: Name of this model.
    """
    def __init__(self, X, y, kernel, likelihood, mean_function=None,
                 latent_shape=None, whiten=False, jitter=1e-6, name="VGP"):
        super(VariationalGP, self).__init__(X, y, kernel, mean_function, jitter, name)
        self.likelihood = likelihood
        self.whiten = whiten

        y_batch_shape = self.y.shape[:-1] if self.y is not None else torch.Size([])
        self.latent_shape = latent_shape if latent_shape is not None else y_batch_shape

        # variational parameters of q(f): mean f_loc and Cholesky factor f_scale_tril
        N = self.X.shape[0]
        f_loc_shape = self.latent_shape + (N,)
        f_loc = self.X.new_zeros(f_loc_shape)
        self.f_loc = Parameter(f_loc)

        f_scale_tril_shape = self.latent_shape + (N, N)
        Id = torch.eye(N, out=self.X.new_empty(N, N))
        f_scale_tril = Id.expand(f_scale_tril_shape)
        self.f_scale_tril = Parameter(f_scale_tril)
        self.set_constraint("f_scale_tril", constraints.lower_cholesky)

        self._sample_latent = True
    def model(self):
        self.set_mode("model")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        N = self.X.shape[0]
        Kff = self.kernel(self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) * self.jitter)
        Lff = Kff.potrf(upper=False)

        zero_loc = self.X.new_zeros(f_loc.shape)
        f_name = param_with_module_name(self.name, "f")

        if self.whiten:
            # whitened parameterization: the latent prior is a standard normal and
            # the variational parameters are rescaled by Lff below
            Id = torch.eye(N, out=self.X.new_empty(N, N))
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(f_scale_tril)
        else:
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .independent(zero_loc.dim() - 1))

        # marginal variances of q(f), passed to the likelihood together with f_loc
        f_var = f_scale_tril.pow(2).sum(dim=-1)

        if self.whiten:
            f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1)
        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)
    def guide(self):
        self.set_mode("guide")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        if self._sample_latent:
            f_name = param_with_module_name(self.name, "f")
            pyro.sample(f_name,
                        dist.MultivariateNormal(f_loc, scale_tril=f_scale_tril)
                            .independent(f_loc.dim() - 1))
        return f_loc, f_scale_tril
    def forward(self, Xnew, full_cov=False):
        r"""
        Computes the mean and covariance matrix (or variance) of the Gaussian Process
        posterior on the test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril})
            = \mathcal{N}(loc, cov).

        .. note:: Variational parameters ``f_loc``, ``f_scale_tril``, together with
            the kernel's parameters, have been learned from a training procedure
            (MCMC or SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that ``Xnew.shape[1:]``
            must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict the full
            covariance matrix or just the variance.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)

        # avoid sampling the unnecessary latent f
        self._sample_latent = False
        f_loc, f_scale_tril = self.guide()
        self._sample_latent = True

        loc, cov = conditional(Xnew, self.X, self.kernel, f_loc, f_scale_tril,
                               full_cov=full_cov, whiten=self.whiten,
                               jitter=self.jitter)
        return loc + self.mean_function(Xnew), cov
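

# A minimal usage sketch, assuming an RBF kernel and Gaussian likelihood from
# pyro.contrib.gp, toy one-dimensional data, and arbitrary SVI settings; the
# specific values are illustrative only, not prescribed by this model. It shows
# the intended workflow: variational training via SVI over model/guide, then
# prediction via forward().
if __name__ == "__main__":
    import pyro.contrib.gp as gp
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    X = torch.linspace(-3, 3, 20)                # 20 one-dimensional training inputs
    y = torch.sin(X) + 0.2 * torch.randn(20)     # noisy observations of sin(x)

    kernel = gp.kernels.RBF(input_dim=1)         # covariance function k
    likelihood = gp.likelihoods.Gaussian()       # Gaussian observation noise
    vgp = VariationalGP(X, y, kernel, likelihood, whiten=True)

    # Maximize the ELBO over kernel hyperparameters and the variational
    # parameters f_loc, f_scale_tril.
    svi = SVI(vgp.model, vgp.guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
    for step in range(1000):
        svi.step()

    # Posterior predictive mean and variance at new inputs.
    Xnew = torch.linspace(-3, 3, 50)
    loc, var = vgp(Xnew, full_cov=False)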