Source code for pyro.contrib.gp.models.sgpr

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

import pyro
import pyro.distributions as dist
from pyro.contrib.gp.models.model import GPModel
from pyro.distributions.util import matrix_triangular_solve_compat
from pyro.params import param_with_module_name


class SparseGPRegression(GPModel):
    u"""
    Sparse Gaussian Process Regression model.

    In :class:`.GPRegression` model, when the number of input data :math:`X` is large,
    the covariance matrix :math:`k(X, X)` will require a lot of computational steps to
    compute its inverse (for log likelihood and for prediction). By introducing an
    additional inducing-input parameter :math:`X_u`, we can reduce computational cost
    by approximating :math:`k(X, X)` by a low-rank Nystr\u00F6m approximation :math:`Q`
    (see reference [1]), where

    .. math:: Q = k(X, X_u) k(X_u, X_u)^{-1} k(X_u, X).

    Given inputs :math:`X`, their noisy observations :math:`y`, and the inducing-input
    parameters :math:`X_u`, the model takes the form:

    .. math::
        u & \\sim \\mathcal{GP}(0, k(X_u, X_u)),\\\\
        f & \\sim q(f \\mid X, X_u) = \\mathbb{E}_{p(u)}q(f\\mid X, X_u, u),\\\\
        y & \\sim f + \\epsilon,

    where :math:`\\epsilon` is Gaussian noise and the conditional distribution
    :math:`q(f\\mid X, X_u, u)` is an approximation of

    .. math:: p(f\\mid X, X_u, u) = \\mathcal{N}(m, k(X, X) - Q),

    whose terms :math:`m` and :math:`k(X, X) - Q` are derived from the joint
    multivariate normal distribution:

    .. math:: [f, u] \\sim \\mathcal{GP}(0, k([X, X_u], [X, X_u])).

    This class implements three approximation methods:

    + Deterministic Training Conditional (DTC):

        .. math:: q(f\\mid X, X_u, u) = \\mathcal{N}(m, 0),

      which in turn implies

        .. math:: f \\sim \\mathcal{N}(0, Q).

    + Fully Independent Training Conditional (FITC):

        .. math:: q(f\\mid X, X_u, u) = \\mathcal{N}(m, diag(k(X, X) - Q)),

      which in turn corrects the diagonal part of the approximation in DTC:

        .. math:: f \\sim \\mathcal{N}(0, Q + diag(k(X, X) - Q)).

    + Variational Free Energy (VFE), which is similar to DTC but has an additional
      `trace_term` in the model's log likelihood. This additional term makes "VFE"
      equivalent to the variational approach in :class:`.SparseVariationalGP`
      (see reference [2]).

    .. note:: This model has :math:`\\mathcal{O}(NM^2)` complexity for training,
        :math:`\\mathcal{O}(NM^2)` complexity for testing. Here, :math:`N` is the number
        of train inputs, :math:`M` is the number of inducing inputs.

    References:

    [1] `A Unifying View of Sparse Approximate Gaussian Process Regression`,
    Joaquin Qui\u00F1onero-Candela, Carl E. Rasmussen

    [2] `Variational learning of inducing variables in sparse Gaussian processes`,
    Michalis Titsias

    :param torch.Tensor X: Input data for training. Its first dimension is the number
        of data points.
    :param torch.Tensor y: Output data for training. Its last dimension is the
        number of data points.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object, which
        is the covariance function :math:`k`.
    :param torch.Tensor Xu: Initial values for inducing points, which are parameters
        of our model.
    :param torch.Tensor noise: Variance of Gaussian noise of this model.
    :param callable mean_function: An optional mean function :math:`m` of this Gaussian
        process. By default, we use zero mean.
    :param str approx: One of approximation methods: "DTC", "FITC", and "VFE"
        (default).
    :param float jitter: A small positive term which is added into the diagonal part of
        a covariance matrix to help stabilize its Cholesky decomposition.
    :param str name: Name of this model.
    """
    def __init__(self, X, y, kernel, Xu, noise=None, mean_function=None,
                 approx=None, jitter=1e-6, name="SGPR"):
        """
        Registers ``Xu`` and ``noise`` as parameters and validates ``approx``.

        :raises ValueError: if ``approx`` is neither ``None`` nor one of
            ``"DTC"``, ``"FITC"``, ``"VFE"``.
        """
        super(SparseGPRegression, self).__init__(X, y, kernel, mean_function, jitter, name)

        self.Xu = Parameter(Xu)

        # Default noise variance is a scalar one created from ``self.X``.
        noise = self.X.new_ones(()) if noise is None else noise
        self.noise = Parameter(noise)
        # The noise variance is constrained to stay above the jitter level.
        self.set_constraint("noise", constraints.greater_than(self.jitter))

        if approx is None:
            self.approx = "VFE"
        elif approx in ["DTC", "FITC", "VFE"]:
            self.approx = approx
        else:
            raise ValueError("The sparse approximation method should be one of "
                             "'DTC', 'FITC', 'VFE'.")

    def model(self):
        """
        Builds the sparse approximation of the training distribution.

        :returns: ``(f_loc, f_var)`` when ``self.y`` is ``None``; otherwise the
            result of the ``pyro.sample`` statement that observes ``self.y``.
        """
        self.set_mode("model")

        Xu = self.get_param("Xu")
        noise = self.get_param("noise")

        # W = inv(Luu) @ Kuf
        # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W
        # Formulas for each approximation method are
        # DTC:  y_cov = Qff + noise,                   trace_term = 0
        # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0
        # VFE:  y_cov = Qff + noise,                   trace_term = tr(Kff-Qff) / noise
        # y_cov = W.T @ W + D
        # trace_term is added into log_prob

        M = Xu.shape[0]
        # Jitter on the diagonal keeps the Cholesky factorization below stable.
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)
        Kuf = self.kernel(Xu, self.X)
        W = matrix_triangular_solve_compat(Kuf, Luu, upper=False)

        D = noise.expand(W.shape[1])
        trace_term = 0
        if self.approx in ("FITC", "VFE"):
            Kffdiag = self.kernel(self.X, diag=True)
            Qffdiag = W.pow(2).sum(dim=0)
            if self.approx == "FITC":
                D = D + Kffdiag - Qffdiag
            else:  # approx = "VFE"
                trace_term = trace_term + (Kffdiag - Qffdiag).sum() / noise

        zero_loc = self.X.new_zeros(self.X.shape[0])
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            # No observations: return the prior mean and the diagonal of y_cov.
            f_var = D + W.pow(2).sum(dim=0)
            return f_loc, f_var

        y_name = param_with_module_name(self.name, "y")
        y_dist = (dist.LowRankMultivariateNormal(f_loc, W, D, trace_term)
                  .expand_by(self.y.shape[:-1])
                  .independent(self.y.dim() - 1))
        return pyro.sample(y_name, y_dist, obs=self.y)

    def guide(self):
        """
        Fetches the parameters in "guide" mode.

        :returns: the tuple ``(Xu, noise)`` obtained from ``self.get_param``.
        """
        self.set_mode("guide")

        Xu = self.get_param("Xu")
        noise = self.get_param("noise")

        return Xu, noise

    def forward(self, Xnew, full_cov=False, noiseless=True):
        r"""
        Computes the mean and covariance matrix (or variance) of Gaussian Process
        posterior on a test input data :math:`X_{new}`:

        .. math:: p(f^* \mid X_{new}, X, y, k, X_u, \epsilon) = \mathcal{N}(loc, cov).

        .. note:: The noise parameter ``noise`` (:math:`\epsilon`), the inducing-point
            parameter ``Xu``, together with kernel's parameters have been learned from
            a training procedure (MCMC or SVI).

        :param torch.Tensor Xnew: Input data for testing. Note that
            ``Xnew.shape[1:]`` must be the same as ``self.X.shape[1:]``.
        :param bool full_cov: A flag to decide if we want to predict full covariance
            matrix or just variance.
        :param bool noiseless: A flag to decide if we want to include noise in the
            prediction output or not.
        :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`.
            With ``full_cov=True`` the second element has shape
            ``y.shape[:-1] + (Xnew.shape[0], Xnew.shape[0])``; otherwise it is the
            variance with shape ``y.shape[:-1] + (Xnew.shape[0],)``.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        self._check_Xnew_shape(Xnew)

        Xu, noise = self.guide()

        # W = inv(Luu) @ Kuf
        # Ws = inv(Luu) @ Kus
        # D as in self.model()
        # K = I + W @ inv(D) @ W.T = L @ L.T
        # S = inv[Kuu + Kuf @ inv(D) @ Kfu]
        #   = inv(Luu).T @ inv[I + inv(Luu) @ Kuf @ inv(D) @ Kfu @ inv(Luu).T] @ inv(Luu)
        #   = inv(Luu).T @ inv[I + W @ inv(D) @ W.T] @ inv(Luu)
        #   = inv(Luu).T @ inv(K) @ inv(Luu)
        #   = inv(Luu).T @ inv(L).T @ inv(L) @ inv(Luu)
        # loc = Ksu @ S @ Kuf @ inv(D) @ y = Ws.T @ inv(L).T @ inv(L) @ W @ inv(D) @ y
        # cov = Kss - Ksu @ inv(Kuu) @ Kus + Ksu @ S @ Kus
        #     = Kss - Ksu @ inv(Kuu) @ Kus + Ws.T @ inv(L).T @ inv(L) @ Ws

        N = self.X.shape[0]
        M = Xu.shape[0]
        num_new = Xnew.shape[0]

        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)
        Kus = self.kernel(Xu, Xnew)
        Kuf = self.kernel(Xu, self.X)

        W = matrix_triangular_solve_compat(Kuf, Luu, upper=False)
        Ws = matrix_triangular_solve_compat(Kus, Luu, upper=False)
        D = noise.expand(N)
        if self.approx == "FITC":
            Kffdiag = self.kernel(self.X, diag=True)
            Qffdiag = W.pow(2).sum(dim=0)
            D = D + Kffdiag - Qffdiag

        # Broadcasting divides along the last dimension of W by D, i.e. W @ inv(D).
        W_Dinv = W / D
        Id = torch.eye(M, M, out=W.new_empty(M, M))
        K = Id + W_Dinv.matmul(W.t())
        L = K.potrf(upper=False)

        # get y_residual and convert it into 2D tensor for packing
        y_residual = self.y - self.mean_function(self.X)
        y_2D = y_residual.reshape(-1, N).t()
        W_Dinv_y = W_Dinv.matmul(y_2D)
        # Pack both right-hand sides so a single triangular solve handles them.
        pack = torch.cat((W_Dinv_y, Ws), dim=1)
        Linv_pack = matrix_triangular_solve_compat(pack, L, upper=False)
        # unpack
        num_y_cols = W_Dinv_y.shape[1]
        Linv_W_Dinv_y = Linv_pack[:, :num_y_cols]
        Linv_Ws = Linv_pack[:, num_y_cols:]

        batch_shape = self.y.shape[:-1]
        loc_shape = batch_shape + (num_new,)
        loc = Linv_W_Dinv_y.t().matmul(Linv_Ws).reshape(loc_shape)

        if full_cov:
            Kss = self.kernel(Xnew)
            if not noiseless:
                Kss = Kss + noise.expand(num_new).diag()
            Qss = Ws.t().matmul(Ws)
            cov = Kss - Qss + Linv_Ws.t().matmul(Linv_Ws)
            cov_shape = batch_shape + (num_new, num_new)
        else:
            Kssdiag = self.kernel(Xnew, diag=True)
            if not noiseless:
                Kssdiag = Kssdiag + noise.expand(num_new)
            Qssdiag = Ws.pow(2).sum(dim=0)
            cov = Kssdiag - Qssdiag + Linv_Ws.pow(2).sum(dim=0)
            # The variance is a vector per output; expanding it to a square
            # (num_new, num_new) shape would silently broadcast it into a matrix.
            cov_shape = batch_shape + (num_new,)

        cov = cov.expand(cov_shape)

        return loc + self.mean_function(Xnew), cov