# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

from pyro.infer import TraceMeanField_ELBO
from pyro.infer.util import torch_backward, torch_item


def conditional(
    Xnew,
    X,
    kernel,
    f_loc,
    f_scale_tril=None,
    Lff=None,
    full_cov=False,
    whiten=False,
    jitter=1e-6,
):
r"""
Given :math:`X_{new}`, predicts loc and covariance matrix of the conditional
multivariate normal distribution
.. math:: p(f^*(X_{new}) \mid X, k, f_{loc}, f_{scale\_tril}).
Here ``f_loc`` and ``f_scale_tril`` are variation parameters of the variational
distribution
.. math:: q(f \mid f_{loc}, f_{scale\_tril}) \sim p(f | X, y),
where :math:`f` is the function value of the Gaussian Process given input :math:`X`
.. math:: p(f(X)) \sim \mathcal{N}(0, k(X, X))
and :math:`y` is computed from :math:`f` by some likelihood function
:math:`p(y|f)`.
In case ``f_scale_tril=None``, we consider :math:`f = f_{loc}` and computes
.. math:: p(f^*(X_{new}) \mid X, k, f).
In case ``f_scale_tril`` is not ``None``, we follow the derivation from reference
[1]. For the case ``f_scale_tril=None``, we follow the popular reference [2].
References:
[1] `Sparse GPs: approximate the posterior, not the model
<https://www.prowler.io/sparse-gps-approximate-the-posterior-not-the-model/>`_
[2] `Gaussian Processes for Machine Learning`,
Carl E. Rasmussen, Christopher K. I. Williams
:param torch.Tensor Xnew: A new input data.
:param torch.Tensor X: An input data to be conditioned on.
:param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object.
:param torch.Tensor f_loc: Mean of :math:`q(f)`. In case ``f_scale_tril=None``,
:math:`f_{loc} = f`.
:param torch.Tensor f_scale_tril: Lower triangular decomposition of covariance
matrix of :math:`q(f)`'s .
:param torch.Tensor Lff: Lower triangular decomposition of :math:`kernel(X, X)`
(optional).
:param bool full_cov: A flag to decide if we want to return full covariance
matrix or just variance.
:param bool whiten: A flag to tell if ``f_loc`` and ``f_scale_tril`` are
already transformed by the inverse of ``Lff``.
:param float jitter: A small positive term which is added into the diagonal part of
a covariance matrix to help stablize its Cholesky decomposition.
:returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
:rtype: tuple(torch.Tensor, torch.Tensor)
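
    Example (a minimal usage sketch; the tensors and the RBF kernel below are
    illustrative placeholders, not part of this module)::

        >>> import pyro.contrib.gp as gp
        >>> X = torch.randn(10, 2)      # training inputs
        >>> Xnew = torch.randn(5, 2)    # test inputs
        >>> kernel = gp.kernels.RBF(input_dim=2)
        >>> f_loc = torch.zeros(10)     # mean of q(f) at X
        >>> loc, var = conditional(Xnew, X, kernel, f_loc, full_cov=False)
        >>> print(loc.shape, var.shape)
        torch.Size([5]) torch.Size([5])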
"""
    # p(f* | Xnew, X, kernel, f_loc, f_scale_tril) ~ N(f* | loc, cov)
    # Kff = Lff @ Lff.T
    # v = inv(Lff) @ f_loc <- whitened f_loc
    # S = inv(Lff) @ f_scale_tril <- whitened f_scale_tril
    # Denote:
    #     W = (inv(Lff) @ Kf*).T
    #     K = W @ S @ S.T @ W.T
    #     Q** = K*f @ inv(Kff) @ Kf* = W @ W.T
    # loc = K*f @ inv(Kff) @ f_loc = W @ v
    # Case 1: f_scale_tril = None
    #     cov = K** - K*f @ inv(Kff) @ Kf* = K** - Q**
    # Case 2: f_scale_tril != None
    #     cov = K** - Q** + K*f @ inv(Kff) @ f_cov @ inv(Kff) @ Kf*
    #         = K** - Q** + W @ S @ S.T @ W.T
    #         = K** - Q** + K
    N = X.size(0)
    M = Xnew.size(0)
    latent_shape = f_loc.shape[:-1]

    if Lff is None:
        Kff = kernel(X).contiguous()
        Kff.view(-1)[:: N + 1] += jitter  # add jitter to the diagonal
        Lff = torch.linalg.cholesky(Kff)
    Kfs = kernel(X, Xnew)

    # convert f_loc_shape from latent_shape x N to N x latent_shape
    f_loc = f_loc.permute(-1, *range(len(latent_shape)))
    # convert f_loc to 2D tensor for packing
    f_loc_2D = f_loc.reshape(N, -1)
    if f_scale_tril is not None:
        # convert f_scale_tril_shape from latent_shape x N x N to N x N x latent_shape
        f_scale_tril = f_scale_tril.permute(-2, -1, *range(len(latent_shape)))
        # convert f_scale_tril to 2D tensor for packing
        f_scale_tril_2D = f_scale_tril.reshape(N, -1)
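
    # When whiten=True, f_loc and f_scale_tril are assumed to be premultiplied
    # by inv(Lff), so only Kfs needs a triangular solve; otherwise f_loc, Kfs
    # (and f_scale_tril) are packed into a single solve_triangular call against Lff.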
    if whiten:
        v_2D = f_loc_2D
        W = torch.linalg.solve_triangular(Lff, Kfs, upper=False).t()
        if f_scale_tril is not None:
            S_2D = f_scale_tril_2D
    else:
        pack = torch.cat((f_loc_2D, Kfs), dim=1)
        if f_scale_tril is not None:
            pack = torch.cat((pack, f_scale_tril_2D), dim=1)

        Lffinv_pack = torch.linalg.solve_triangular(Lff, pack, upper=False)
        # unpack
        v_2D = Lffinv_pack[:, : f_loc_2D.size(1)]
        W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t()
        if f_scale_tril is not None:
            S_2D = Lffinv_pack[:, -f_scale_tril_2D.size(1) :]
    loc_shape = latent_shape + (M,)
    loc = W.matmul(v_2D).t().reshape(loc_shape)

    if full_cov:
        Kss = kernel(Xnew)
        Qss = W.matmul(W.t())
        cov = Kss - Qss
    else:
        Kssdiag = kernel(Xnew, diag=True)
        Qssdiag = W.pow(2).sum(dim=-1)
        # Theoretically, Kss - Qss is non-negative; but due to numerical
        # computation, that might not be the case in practice.
        var = (Kssdiag - Qssdiag).clamp(min=0)
    if f_scale_tril is not None:
        W_S_shape = (Xnew.size(0),) + f_scale_tril.shape[1:]
        W_S = W.matmul(S_2D).reshape(W_S_shape)
        # convert W_S_shape from M x N x latent_shape to latent_shape x M x N
        W_S = W_S.permute(list(range(2, W_S.dim())) + [0, 1])

        if full_cov:
            St_Wt = W_S.transpose(-2, -1)
            K = W_S.matmul(St_Wt)
            cov = cov + K
        else:
            Kdiag = W_S.pow(2).sum(dim=-1)
            var = var + Kdiag
    else:
        if full_cov:
            cov = cov.expand(latent_shape + (M, M))
        else:
            var = var.expand(latent_shape + (M,))

    return (loc, cov) if full_cov else (loc, var)


def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000):
"""
A helper to optimize parameters for a GP module.
:param ~pyro.contrib.gp.models.GPModel gpmodule: A GP module.
:param ~torch.optim.Optimizer optimizer: A PyTorch optimizer instance.
By default, we use Adam with ``lr=0.01``.
:param callable loss_fn: A loss function which takes inputs are
``gpmodule.model``, ``gpmodule.guide``, and returns ELBO loss.
By default, ``loss_fn=TraceMeanField_ELBO().differentiable_loss``.
:param bool retain_graph: An optional flag of ``torch.autograd.backward``.
:param int num_steps: Number of steps to run SVI.
:returns: a list of losses during the training procedure
:rtype: list
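
    Example (a minimal usage sketch; ``X``, ``y``, and the ``GPRegression``
    model below are illustrative placeholders)::

        >>> import pyro.contrib.gp as gp
        >>> X = torch.randn(20, 1)
        >>> y = torch.sin(X).squeeze(-1)
        >>> gpr = gp.models.GPRegression(X, y, gp.kernels.RBF(input_dim=1))
        >>> losses = train(gpr, num_steps=100)
        >>> len(losses)
        100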
"""
    optimizer = (
        torch.optim.Adam(gpmodule.parameters(), lr=0.01)
        if optimizer is None
        else optimizer
    )
    # TODO: add support for JIT loss
    loss_fn = TraceMeanField_ELBO().differentiable_loss if loss_fn is None else loss_fn
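
    # The loss is wrapped in a closure so that optimizers which re-evaluate the
    # objective (e.g. LBFGS) can be used; optimizers like Adam simply call it
    # once per step and return the resulting loss.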
    def closure():
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide)
        torch_backward(loss, retain_graph)
        return loss

    losses = []
    for i in range(num_steps):
        loss = optimizer.step(closure)
        losses.append(torch_item(loss))
    return losses