# Source code for pyro.contrib.gp.util

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import torch

from pyro.infer import TraceMeanField_ELBO
from pyro.infer.util import torch_backward, torch_item

def conditional(
    Xnew,
    X,
    kernel,
    f_loc,
    f_scale_tril=None,
    Lff=None,
    full_cov=False,
    whiten=False,
    jitter=1e-6,
):
    r"""
    Given :math:`X_{new}`, predicts loc and covariance matrix of the conditional
    multivariate normal distribution

    .. math:: p(f^*(X_{new}) \mid X, k, f_{loc}, f_{scale\_tril}).

    Here ``f_loc`` and ``f_scale_tril`` are variational parameters of the
    variational distribution

    .. math:: q(f \mid f_{loc}, f_{scale\_tril}) \sim p(f \mid X, y),

    where :math:`f` is the function value of the Gaussian Process given input
    :math:`X`

    .. math:: p(f(X)) \sim \mathcal{N}(0, k(X, X))

    and :math:`y` is computed from :math:`f` by some likelihood function
    :math:`p(y \mid f)`.

    In case ``f_scale_tril=None``, we consider :math:`f = f_{loc}` and compute

    .. math:: p(f^*(X_{new}) \mid X, k, f).

    In case ``f_scale_tril`` is not ``None``, we follow the derivation from
    reference [1]. For the case ``f_scale_tril=None``, we follow the popular
    reference [2].

    References:

    [1] `Sparse GPs: approximate the posterior, not the model
    <https://www.prowler.io/sparse-gps-approximate-the-posterior-not-the-model/>`_

    [2] `Gaussian Processes for Machine Learning`,
    Carl E. Rasmussen, Christopher K. I. Williams

    :param torch.Tensor Xnew: A new input data.
    :param torch.Tensor X: An input data to be conditioned on.
    :param ~pyro.contrib.gp.kernels.kernel.Kernel kernel: A Pyro kernel object.
    :param torch.Tensor f_loc: Mean of :math:`q(f)`. In case
        ``f_scale_tril=None``, :math:`f_{loc} = f`.
    :param torch.Tensor f_scale_tril: Lower triangular decomposition of
        covariance matrix of :math:`q(f)`'s.
    :param torch.Tensor Lff: Lower triangular decomposition of
        :math:`kernel(X, X)` (optional).
    :param bool full_cov: A flag to decide if we want to return full covariance
        matrix or just variance.
    :param bool whiten: A flag to tell if ``f_loc`` and ``f_scale_tril`` are
        already transformed by the inverse of ``Lff``.
    :param float jitter: A small positive term which is added into the diagonal
        part of a covariance matrix to help stabilize its Cholesky decomposition.
    :returns: loc and covariance matrix (or variance) of :math:`p(f^*(X_{new}))`
    :rtype: tuple(torch.Tensor, torch.Tensor)
    """
    # p(f* | Xnew, X, kernel, f_loc, f_scale_tril) ~ N(f* | loc, cov)
    # Kff = Lff @ Lff.T
    # v = inv(Lff) @ f_loc  <- whitened f_loc
    # S = inv(Lff) @ f_scale_tril  <- whitened f_scale_tril
    # Denote:
    #     W = (inv(Lff) @ Kf*).T
    #     K = W @ S @ S.T @ W.T
    #     Q** = K*f @ inv(Kff) @ Kf* = W @ W.T
    # loc = K*f @ inv(Kff) @ f_loc = W @ v
    # Case 1: f_scale_tril = None
    #     cov = K** - K*f @ inv(Kff) @ Kf* = K** - Q**
    # Case 2: f_scale_tril != None
    #     cov = K** - Q** + K*f @ inv(Kff) @ f_cov @ inv(Kff) @ Kf*
    #         = K** - Q** + W @ S @ S.T @ W.T
    #         = K** - Q** + K

    N = X.size(0)
    M = Xnew.size(0)
    latent_shape = f_loc.shape[:-1]

    if Lff is None:
        Kff = kernel(X).contiguous()
        Kff.view(-1)[:: N + 1] += jitter  # add jitter to the diagonal
        Lff = torch.linalg.cholesky(Kff)
    Kfs = kernel(X, Xnew)

    # convert f_loc_shape from latent_shape x N to N x latent_shape
    f_loc = f_loc.permute(-1, *range(len(latent_shape)))
    # convert f_loc to 2D tensor for packing
    f_loc_2D = f_loc.reshape(N, -1)
    if f_scale_tril is not None:
        # convert f_scale_tril_shape from latent_shape x N x N to N x N x latent_shape
        f_scale_tril = f_scale_tril.permute(-2, -1, *range(len(latent_shape)))
        # convert f_scale_tril to 2D tensor for packing
        f_scale_tril_2D = f_scale_tril.reshape(N, -1)

    if whiten:
        v_2D = f_loc_2D
        W = torch.linalg.solve_triangular(Lff, Kfs, upper=False).t()
        if f_scale_tril is not None:
            S_2D = f_scale_tril_2D
    else:
        # pack f_loc, Kfs (and f_scale_tril) so a single triangular solve
        # against Lff handles all of them at once
        pack = torch.cat((f_loc_2D, Kfs), dim=1)
        if f_scale_tril is not None:
            pack = torch.cat((pack, f_scale_tril_2D), dim=1)

        Lffinv_pack = torch.linalg.solve_triangular(Lff, pack, upper=False)
        # unpack
        v_2D = Lffinv_pack[:, : f_loc_2D.size(1)]
        W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t()
        if f_scale_tril is not None:
            S_2D = Lffinv_pack[:, -f_scale_tril_2D.size(1) :]

    loc_shape = latent_shape + (M,)
    loc = W.matmul(v_2D).t().reshape(loc_shape)

    if full_cov:
        Kss = kernel(Xnew)
        Qss = W.matmul(W.t())
        cov = Kss - Qss
    else:
        Kssdiag = kernel(Xnew, diag=True)
        Qssdiag = W.pow(2).sum(dim=-1)
        # Theoretically, Kss - Qss is non-negative; but due to numerical
        # computation, that might not be the case in practice.
        var = (Kssdiag - Qssdiag).clamp(min=0)

    if f_scale_tril is not None:
        W_S_shape = (Xnew.size(0),) + f_scale_tril.shape[1:]
        W_S = W.matmul(S_2D).reshape(W_S_shape)
        # convert W_S_shape from M x N x latent_shape to latent_shape x M x N
        W_S = W_S.permute(list(range(2, W_S.dim())) + [0, 1])

        if full_cov:
            St_Wt = W_S.transpose(-2, -1)
            K = W_S.matmul(St_Wt)
            cov = cov + K
        else:
            Kdiag = W_S.pow(2).sum(dim=-1)
            var = var + Kdiag
    else:
        if full_cov:
            cov = cov.expand(latent_shape + (M, M))
        else:
            var = var.expand(latent_shape + (M,))

    return (loc, cov) if full_cov else (loc, var)

[docs]def train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000):
"""
A helper to optimize parameters for a GP module.

:param ~pyro.contrib.gp.models.GPModel gpmodule: A GP module.
:param ~torch.optim.Optimizer optimizer: A PyTorch optimizer instance.
By default, we use Adam with lr=0.01.
:param callable loss_fn: A loss function which takes inputs are
gpmodule.model, gpmodule.guide, and returns ELBO loss.
By default, loss_fn=TraceMeanField_ELBO().differentiable_loss.
:param bool retain_graph: An optional flag of torch.autograd.backward.
:param int num_steps: Number of steps to run SVI.
:returns: a list of losses during the training procedure
:rtype: list
"""
optimizer = (
if optimizer is None
else optimizer
)
# TODO: add support for JIT loss
loss_fn = TraceMeanField_ELBO().differentiable_loss if loss_fn is None else loss_fn

def closure():