Gaussian Processes¶
See the Gaussian Processes tutorial for an introduction.
Models¶
GPModel¶
- class GPModel(X, y, kernel, mean_function=None, jitter=1e-06)[source]¶
Bases:
pyro.contrib.gp.parameterized.Parameterized
Base class for Gaussian Process models.
The core of a Gaussian Process is a covariance function \(k\) which governs the similarity between input points. Given \(k\), we can establish a distribution over functions \(f\) by a multivariate normal distribution
\[p(f(X)) = \mathcal{N}(0, k(X, X)),\]where \(X\) is any set of input points and \(k(X, X)\) is a covariance matrix whose entries are outputs \(k(x, z)\) of \(k\) over input pairs \((x, z)\). This distribution is usually denoted by
\[f \sim \mathcal{GP}(0, k).\]Note
Generally, besides the covariance function \(k\), a Gaussian Process can also be specified by a mean function \(m\) (which is the zero function by default). In that case, its distribution will be
\[p(f(X)) = \mathcal{N}(m(X), k(X, X)).\]
Gaussian Process models are Parameterized subclasses, so their parameters can be learned, given priors, or fixed by using the corresponding methods from Parameterized. A typical way to define a Gaussian Process model is:
>>> X = torch.tensor([[1., 5, 3], [4, 3, 7]])
>>> y = torch.tensor([2., 1])
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> kernel.variance = pyro.nn.PyroSample(dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
>>> kernel.lengthscale = pyro.nn.PyroSample(dist.Uniform(torch.tensor(1.0), torch.tensor(3.0)))
>>> gpr = gp.models.GPRegression(X, y, kernel)
There are two ways to train a Gaussian Process model:
Using an MCMC algorithm (in module pyro.infer.mcmc) on model() to get posterior samples for the Gaussian Process’s parameters. For example:
>>> hmc_kernel = HMC(gpr.model)
>>> mcmc = MCMC(hmc_kernel, num_samples=10)
>>> mcmc.run()
>>> ls_name = "kernel.lengthscale"
>>> posterior_ls = mcmc.get_samples()[ls_name]
Using variational inference on the pair model(), guide():
>>> optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
>>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
>>>
>>> for i in range(1000):
...     optimizer.zero_grad()
...     loss = loss_fn(gpr.model, gpr.guide)
...     loss.backward()
...     optimizer.step()
To make a prediction on a new dataset, simply call forward() like any PyTorch torch.nn.Module:
>>> Xnew = torch.tensor([[2., 3, 1]])
>>> f_loc, f_cov = gpr(Xnew, full_cov=True)
Reference:
[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
- model()[source]¶
A “model” stochastic function. If self.y is None, this method returns the mean and variance of the Gaussian Process prior.
- guide()[source]¶
A “guide” stochastic function to be used in variational inference methods. It also gives posterior information to the method
forward()
for prediction.
- forward(Xnew, full_cov=False)[source]¶
Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data \(X_{new}\):
\[p(f^* \mid X_{new}, X, y, k, \theta),\]where \(\theta\) are parameters of this model.
Note
Model’s parameters \(\theta\) together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as X.shape[1:].
full_cov (bool) – A flag to decide if we want to predict the full covariance matrix or just the variance.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
- set_data(X, y=None)[source]¶
Sets data for Gaussian Process models.
Some examples of how to use this method:
Batch training on a sparse variational model:
>>> Xu = torch.tensor([[1., 0, 2]])  # inducing input
>>> likelihood = gp.likelihoods.Gaussian()
>>> vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood)
>>> optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.01)
>>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
>>> batched_X, batched_y = X.split(split_size=10), y.split(split_size=10)
>>> for Xi, yi in zip(batched_X, batched_y):
...     optimizer.zero_grad()
...     vsgp.set_data(Xi, yi)
...     loss = loss_fn(vsgp.model, vsgp.guide)
...     loss.backward()
...     optimizer.step()
Making a two-layer Gaussian Process stochastic function:
>>> gpr1 = gp.models.GPRegression(X, None, kernel)
>>> Z, _ = gpr1.model()
>>> gpr2 = gp.models.GPRegression(Z, y, kernel)
>>> def two_layer_model():
...     Z, _ = gpr1.model()
...     gpr2.set_data(Z, y)
...     return gpr2.model()
References:
[1] Scalable Variational Gaussian Process Classification, James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani
[2] Deep Gaussian Processes, Andreas C. Damianou, Neil D. Lawrence
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
GPRegression¶
- class GPRegression(X, y, kernel, noise=None, mean_function=None, jitter=1e-06)[source]¶
Bases:
pyro.contrib.gp.models.model.GPModel
Gaussian Process Regression model.
The core of a Gaussian Process is a covariance function \(k\) which governs the similarity between input points. Given \(k\), we can establish a distribution over functions \(f\) by a multivariate normal distribution
\[p(f(X)) = \mathcal{N}(0, k(X, X)),\]where \(X\) is any set of input points and \(k(X, X)\) is a covariance matrix whose entries are outputs \(k(x, z)\) of \(k\) over input pairs \((x, z)\). This distribution is usually denoted by
\[f \sim \mathcal{GP}(0, k).\]Note
Generally, besides the covariance function \(k\), a Gaussian Process can also be specified by a mean function \(m\) (which is the zero function by default). In that case, its distribution will be
\[p(f(X)) = \mathcal{N}(m(X), k(X, X)).\]Given inputs \(X\) and their noisy observations \(y\), the Gaussian Process Regression model takes the form
\[\begin{split}f &\sim \mathcal{GP}(0, k(X, X)),\\ y & \sim f + \epsilon,\end{split}\]where \(\epsilon\) is Gaussian noise.
Note
This model has \(\mathcal{O}(N^3)\) complexity for training, \(\mathcal{O}(N^3)\) complexity for testing. Here, \(N\) is the number of train inputs.
Reference:
[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
noise (torch.Tensor) – Variance of Gaussian noise of this model.
mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
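A minimal end-to-end sketch (the data and hyperparameter values below are illustrative, not prescribed by this class; it assumes the same imports as the other examples on this page: torch, pyro, and pyro.contrib.gp as gp) of fitting a GPRegression model with gp.util.train() and predicting at new inputs:
>>> X = torch.linspace(0., 5., 20).unsqueeze(-1)  # 20 one-dimensional inputs
>>> y = torch.sin(X.squeeze(-1)) + 0.1 * torch.randn(20)
>>> kernel = gp.kernels.RBF(input_dim=1)
>>> gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(0.1))
>>> losses = gp.util.train(gpr, num_steps=500)  # SVI with Adam by default
>>> Xnew = torch.linspace(0., 5., 50).unsqueeze(-1)
>>> with torch.no_grad():
...     loc, var = gpr(Xnew, full_cov=False, noiseless=False)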
- forward(Xnew, full_cov=False, noiseless=True)[source]¶
Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data \(X_{new}\):
\[p(f^* \mid X_{new}, X, y, k, \epsilon) = \mathcal{N}(loc, cov).\]Note
The noise parameter noise (\(\epsilon\)) together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
full_cov (bool) – A flag to decide if we want to predict the full covariance matrix or just the variance.
noiseless (bool) – A flag to decide if we want to include noise in the prediction output or not.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
- iter_sample(noiseless=True)[source]¶
Iteratively constructs a sample from the Gaussian Process posterior.
Recall that at test input points \(X_{new}\), the posterior is multivariate Gaussian distributed with mean and covariance matrix given by forward().
This method samples lazily from this multivariate Gaussian. The advantage of this approach is that later query points can depend upon earlier ones. This is particularly useful when the querying is to be done by an optimisation routine.
Note
The noise parameter noise (\(\epsilon\)) together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
noiseless (bool) – A flag to decide if we want to add sampling noise to the samples beyond the noise inherent in the GP posterior.
- Returns
sampler
- Return type
function
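A hedged sketch of using the returned sampler, assuming gpr is a trained GPRegression model with 1-dimensional inputs (for example, the one sketched above) and that each query point is passed as a tensor of shape (1, input_dim):
>>> sampler = gpr.iter_sample(noiseless=True)
>>> xs = torch.linspace(0., 5., 10)
>>> fs = [sampler(x.reshape(1, 1)) for x in xs]  # each call conditions on the earlier ones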
SparseGPRegression¶
- class SparseGPRegression(X, y, kernel, Xu, noise=None, mean_function=None, approx=None, jitter=1e-06)[source]¶
Bases:
pyro.contrib.gp.models.model.GPModel
Sparse Gaussian Process Regression model.
In the GPRegression model, when the number of input data points \(X\) is large, the covariance matrix \(k(X, X)\) will require a lot of computational steps to compute its inverse (for log likelihood and for prediction). By introducing an additional inducing-input parameter \(X_u\), we can reduce the computational cost by approximating \(k(X, X)\) with a low-rank Nyström approximation \(Q\) (see reference [1]), where
\[Q = k(X, X_u) k(X_u,X_u)^{-1} k(X_u, X).\]
Given inputs \(X\), their noisy observations \(y\), and the inducing-input parameters \(X_u\), the model takes the form:
\[\begin{split}u & \sim \mathcal{GP}(0, k(X_u, X_u)),\\ f & \sim q(f \mid X, X_u) = \mathbb{E}_{p(u)}q(f\mid X, X_u, u),\\ y & \sim f + \epsilon,\end{split}\]where \(\epsilon\) is Gaussian noise and the conditional distribution \(q(f\mid X, X_u, u)\) is an approximation of
\[p(f\mid X, X_u, u) = \mathcal{N}(m, k(X, X) - Q),\]
whose terms \(m\) and \(k(X, X) - Q\) are derived from the joint multivariate normal distribution:
\[[f, u] \sim \mathcal{GP}(0, k([X, X_u], [X, X_u])).\]This class implements three approximation methods:
Deterministic Training Conditional (DTC):
\[q(f\mid X, X_u, u) = \mathcal{N}(m, 0),\]
which in turn implies
\[f \sim \mathcal{N}(0, Q).\]
Fully Independent Training Conditional (FITC):
\[q(f\mid X, X_u, u) = \mathcal{N}(m, diag(k(X, X) - Q)),\]
which in turn corrects the diagonal part of the approximation in DTC:
\[f \sim \mathcal{N}(0, Q + diag(k(X, X) - Q)).\]
Variational Free Energy (VFE), which is similar to DTC but has an additional trace_term in the model’s log likelihood. This additional term makes “VFE” equivalent to the variational approach in VariationalSparseGP (see reference [2]).
Note
This model has \(\mathcal{O}(NM^2)\) complexity for training, \(\mathcal{O}(NM^2)\) complexity for testing. Here, \(N\) is the number of train inputs, \(M\) is the number of inducing inputs.
References:
[1] A Unifying View of Sparse Approximate Gaussian Process Regression, Joaquin Quiñonero-Candela, Carl E. Rasmussen
[2] Variational learning of inducing variables in sparse Gaussian processes, Michalis Titsias
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
Xu (torch.Tensor) – Initial values for inducing points, which are parameters of our model.
noise (torch.Tensor) – Variance of Gaussian noise of this model.
mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
approx (str) – One of the approximation methods: “DTC”, “FITC”, and “VFE” (default).
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
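A minimal sketch (data sizes and inducing-point initialization are illustrative) of building and training a sparse model with the default VFE approximation:
>>> X = torch.randn(1000, 3)
>>> y = X.sum(-1) + 0.05 * torch.randn(1000)
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> Xu = X[::100].clone()  # 10 initial inducing inputs
>>> sgpr = gp.models.SparseGPRegression(X, y, kernel, Xu=Xu, approx="VFE")
>>> losses = gp.util.train(sgpr, num_steps=500)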
- forward(Xnew, full_cov=False, noiseless=True)[source]¶
Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data \(X_{new}\):
\[p(f^* \mid X_{new}, X, y, k, X_u, \epsilon) = \mathcal{N}(loc, cov).\]Note
The noise parameter noise (\(\epsilon\)), the inducing-point parameter Xu, together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
full_cov (bool) – A flag to decide if we want to predict the full covariance matrix or just the variance.
noiseless (bool) – A flag to decide if we want to include noise in the prediction output or not.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
VariationalGP¶
- class VariationalGP(X, y, kernel, likelihood, mean_function=None, latent_shape=None, whiten=False, jitter=1e-06)[source]¶
Bases:
pyro.contrib.gp.models.model.GPModel
Variational Gaussian Process model.
This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs \(X\) and their noisy observations \(y\), the model takes the form
\[\begin{split}f &\sim \mathcal{GP}(0, k(X, X)),\\ y & \sim p(y) = p(y \mid f) p(f),\end{split}\]where \(p(y \mid f)\) is the likelihood.
We will use a variational approach in this model, approximating the posterior \(p(f\mid y)\) by a distribution \(q(f)\). Precisely, \(q(f)\) will be a multivariate normal distribution with two parameters f_loc and f_scale_tril, which will be learned during a variational inference process.
Note
This model can be seen as a special version of the VariationalSparseGP model with \(X_u = X\).
Note
This model has \(\mathcal{O}(N^3)\) complexity for training, \(\mathcal{O}(N^3)\) complexity for testing. Here, \(N\) is the number of train inputs. Size of variational parameters is \(\mathcal{O}(N^2)\).
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
likelihood (Likelihood) – A likelihood object.
mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
latent_shape (torch.Size) – Shape for latent processes (batch_shape of \(q(f)\)). By default, it equals the output batch shape y.shape[:-1]. For multi-class classification problems, latent_shape[-1] should correspond to the number of classes.
whiten (bool) – A flag to tell if the variational parameters f_loc and f_scale_tril are transformed by the inverse of Lff, where Lff is the lower triangular decomposition of \(kernel(X, X)\). Enabling this flag will help optimization.
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
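A minimal sketch (synthetic data; values are illustrative) of binary classification with a non-Gaussian likelihood:
>>> X = torch.randn(100, 2)
>>> y = (X[:, 0] > 0).float()  # binary labels in {0., 1.}
>>> kernel = gp.kernels.RBF(input_dim=2)
>>> likelihood = gp.likelihoods.Binary()
>>> vgp = gp.models.VariationalGP(X, y, kernel, likelihood, whiten=True)
>>> losses = gp.util.train(vgp, num_steps=500)
>>> f_loc, f_var = vgp(torch.randn(5, 2), full_cov=False)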
- forward(Xnew, full_cov=False)[source]¶
Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data \(X_{new}\):
\[p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril}) = \mathcal{N}(loc, cov).\]Note
Variational parameters f_loc, f_scale_tril, together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
full_cov (bool) – A flag to decide if we want to predict the full covariance matrix or just the variance.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
VariationalSparseGP¶
- class VariationalSparseGP(X, y, kernel, Xu, likelihood, mean_function=None, latent_shape=None, num_data=None, whiten=False, jitter=1e-06)[source]¶
Bases:
pyro.contrib.gp.models.model.GPModel
Variational Sparse Gaussian Process model.
In the VariationalGP model, when the number of input data points \(X\) is large, the covariance matrix \(k(X, X)\) will require a lot of computational steps to compute its inverse (for log likelihood and for prediction). This model introduces an additional inducing-input parameter \(X_u\) to solve that problem. Given inputs \(X\), their noisy observations \(y\), and the inducing-input parameters \(X_u\), the model takes the form:
\[\begin{split}[f, u] &\sim \mathcal{GP}(0, k([X, X_u], [X, X_u])),\\ y & \sim p(y) = p(y \mid f) p(f),\end{split}\]
where \(p(y \mid f)\) is the likelihood.
We will use a variational approach in this model, approximating the posterior \(p(f,u \mid y)\) by a distribution \(q(f,u)\). Precisely, \(q(f) = p(f\mid u)q(u)\), where \(q(u)\) is a multivariate normal distribution with two parameters u_loc and u_scale_tril, which will be learned during a variational inference process.
Note
This model can be learned using an MCMC method, as in reference [2]. See also GPModel.
Note
This model has \(\mathcal{O}(NM^2)\) complexity for training, \(\mathcal{O}(M^3)\) complexity for testing. Here, \(N\) is the number of train inputs, \(M\) is the number of inducing inputs. Size of variational parameters is \(\mathcal{O}(M^2)\).
References:
[1] Scalable variational Gaussian process classification, James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani
[2] MCMC for Variationally Sparse Gaussian Processes, James Hensman, Alexander G. de G. Matthews, Maurizio Filippone, Zoubin Ghahramani
- Parameters
X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
Xu (torch.Tensor) – Initial values for inducing points, which are parameters of our model.
likelihood (Likelihood) – A likelihood object.
mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
latent_shape (torch.Size) – Shape for latent processes (batch_shape of \(q(u)\)). By default, it equals the output batch shape y.shape[:-1]. For multi-class classification problems, latent_shape[-1] should correspond to the number of classes.
num_data (int) – The size of the full training dataset. It is useful for training this model with mini-batches.
whiten (bool) – A flag to tell if the variational parameters u_loc and u_scale_tril are transformed by the inverse of Luu, where Luu is the lower triangular decomposition of \(kernel(X_u, X_u)\). Enabling this flag will help optimization.
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
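A minimal sketch (synthetic data; batch size and dataset size are illustrative; it assumes the page's usual imports of torch, pyro.infer, and pyro.contrib.gp as gp) of sparse variational GP classification trained on mini-batches, with num_data used to rescale the ELBO to the full dataset:
>>> X = torch.randn(2000, 2)
>>> y = (X[:, 0] + X[:, 1] > 0).float()
>>> kernel = gp.kernels.RBF(input_dim=2)
>>> Xu = X[::100].clone()  # 20 inducing inputs
>>> likelihood = gp.likelihoods.Binary()
>>> vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood,
...                                      num_data=X.size(0), whiten=True)
>>> optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.01)
>>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
>>> for Xi, yi in zip(X.split(256), y.split(256)):
...     vsgp.set_data(Xi, yi)
...     optimizer.zero_grad()
...     loss = loss_fn(vsgp.model, vsgp.guide)
...     loss.backward()
...     optimizer.step()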
- forward(Xnew, full_cov=False)[source]¶
Computes the mean and covariance matrix (or variance) of Gaussian Process posterior on a test input data \(X_{new}\):
\[p(f^* \mid X_{new}, X, y, k, X_u, u_{loc}, u_{scale\_tril}) = \mathcal{N}(loc, cov).\]Note
Variational parameters u_loc, u_scale_tril, the inducing-point parameter Xu, together with kernel’s parameters have been learned from a training procedure (MCMC or SVI).
- Parameters
Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
full_cov (bool) – A flag to decide if we want to predict the full covariance matrix or just the variance.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
GPLVM¶
- class GPLVM(base_model)[source]¶
Bases:
pyro.contrib.gp.parameterized.Parameterized
Gaussian Process Latent Variable Model (GPLVM) model.
GPLVM is a Gaussian Process model whose train input data is a latent variable. This model is useful for dimensionality reduction of high dimensional data. Assume that the mapping from the low dimensional latent variable to the high dimensional observed data is a Gaussian Process instance. Then the high dimensional data will play the role of the train output y and our target is to learn latent inputs which best explain y. For the purpose of dimensionality reduction, latent inputs should have lower dimensions than y.
We follow reference [1] to put a unit Gaussian prior on the input and approximate its posterior by a multivariate normal distribution with two variational parameters: X_loc and X_scale_tril.
For example, we can do dimensionality reduction on the Iris dataset as follows:
>>> # With y as the 2D Iris data of shape 150x4 and we want to reduce its dimension
>>> # to a tensor X of shape 150x2, we will use GPLVM.
>>> # First, define the initial values for X parameter:
>>> X_init = torch.zeros(150, 2)
>>> # Then, define a Gaussian Process model with input X_init and output y:
>>> kernel = gp.kernels.RBF(input_dim=2, lengthscale=torch.ones(2))
>>> Xu = torch.zeros(20, 2)  # initial inducing inputs of sparse model
>>> gpmodule = gp.models.SparseGPRegression(X_init, y, kernel, Xu)
>>> # Finally, wrap gpmodule by GPLVM, optimize, and get the "learned" mean of X:
>>> gplvm = gp.models.GPLVM(gpmodule)
>>> gp.util.train(gplvm)
>>> X = gplvm.X
Reference:
[1] Bayesian Gaussian Process Latent Variable Model, Michalis K. Titsias, Neil D. Lawrence
- Parameters
base_model (GPModel) – A Pyro Gaussian Process model object. Note that base_model.X will be the initial value for the variational parameter X_loc.
Kernels¶
Kernel¶
- class Kernel(input_dim, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.parameterized.Parameterized
Base class for kernels used in this Gaussian Process module.
Every inherited class should implement a forward() pass which takes inputs \(X\), \(Z\) and returns their covariance matrix.
To construct a new kernel from the old ones, we can use the methods add(), mul(), exp(), warp(), vertical_scale().
References:
[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams
- Parameters
input_dim (int) – Number of feature dimensions of inputs.
variance (torch.Tensor) – Variance parameter of this kernel.
active_dims (list) – List of feature dimensions of the input which the kernel acts on.
- forward(X, Z=None, diag=False)[source]¶
Calculates the covariance matrix of inputs on active dimensions.
- Parameters
X (torch.Tensor) – A 2D tensor with shape \(N \times input\_dim\).
Z (torch.Tensor) – An (optional) 2D tensor with shape \(M \times input\_dim\).
diag (bool) – A flag to decide if we want to return full covariance matrix or just its diagonal part.
- Returns
covariance matrix of \(X\) and \(Z\) with shape \(N \times M\)
- Return type
torch.Tensor
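A minimal sketch of evaluating a kernel directly (tensor shapes are illustrative); calling the kernel object invokes forward():
>>> kernel = gp.kernels.RBF(input_dim=3, variance=torch.tensor(2.),
...                         lengthscale=torch.tensor(1.5))
>>> X = torch.randn(5, 3)
>>> Z = torch.randn(4, 3)
>>> K_xz = kernel(X, Z)  # full covariance matrix, shape (5, 4)
>>> K_diag = kernel(X, diag=True)  # diagonal of kernel(X, X), shape (5,)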
Brownian¶
- class Brownian(input_dim, variance=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
This kernel corresponds to a two-sided Brownian motion (Wiener process):
\(k(x,z)=\begin{cases}\sigma^2\min(|x|,|z|),& \text{if } x\cdot z\ge 0\\ 0, & \text{otherwise}. \end{cases}\)
Note that the input dimension of this kernel must be 1.
Reference:
[1] Theory and Statistical Applications of Stochastic Processes, Yuliya Mishura, Georgiy Shevchenko
Combination¶
- class Combination(kern0, kern1)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
Base class for kernels derived from a combination of kernels.
- Parameters
kern0 (Kernel) – First kernel to combine.
kern1 (Kernel or numbers.Number) – Second kernel to combine.
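A minimal sketch of combining two kernels with the Sum subclass listed further below (the particular kernels chosen are illustrative):
>>> k_smooth = gp.kernels.RBF(input_dim=1)
>>> k_periodic = gp.kernels.Periodic(input_dim=1)
>>> k = gp.kernels.Sum(k_smooth, k_periodic)
>>> K = k(torch.randn(5, 1))  # 5 x 5 covariance matrix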
Constant¶
Coregionalize¶
- class Coregionalize(input_dim, rank=None, components=None, diagonal=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
A kernel for the linear model of coregionalization \(k(x,z) = x^T (W W^T + D) z\) where \(W\) is an input_dim-by-rank matrix and typically rank < input_dim, and D is a diagonal matrix.
This generalizes the Linear kernel to multiple features with a low-rank-plus-diagonal weight matrix. The typical use case is for modeling correlations among outputs of a multi-output GP, where outputs are coded as distinct data points with one-hot coded features denoting which output each datapoint represents.
If only rank is specified, the kernel (W W^T + D) will be randomly initialized to a matrix with expected value the identity matrix.
References:
- [1] Mauricio A. Alvarez, Lorenzo Rosasco, Neil D. Lawrence (2012)
Kernels for Vector-Valued Functions: a Review
- Parameters
input_dim (int) – Number of feature dimensions of inputs.
rank (int) – Optional rank. This is only used if components is unspecified. If neither rank nor components is specified, then rank defaults to input_dim.
components (torch.Tensor) – An optional (input_dim, rank) shaped matrix that maps features to rank-many components. If unspecified, this will be randomly initialized.
diagonal (torch.Tensor) – An optional vector of length input_dim. If unspecified, this will be set to the constant 0.5.
active_dims (list) – List of feature dimensions of the input which the kernel acts on.
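A hedged sketch of the multi-output use case described above, assuming 2 real input features and 3 one-hot coded outputs; combining with an RBF kernel through Product and splitting active_dims this way are assumptions of this example, not requirements of the class:
>>> data_kernel = gp.kernels.RBF(input_dim=2, active_dims=[0, 1])
>>> task_kernel = gp.kernels.Coregionalize(input_dim=3, rank=2, active_dims=[2, 3, 4])
>>> kernel = gp.kernels.Product(data_kernel, task_kernel)
>>> X = torch.cat([torch.randn(6, 2),  # real features
...                torch.eye(3).repeat(2, 1)], dim=-1)  # one-hot output codes
>>> K = kernel(X)  # shape (6, 6)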
Cosine¶
- class Cosine(input_dim, variance=None, lengthscale=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.isotropic.Isotropy
Implementation of Cosine kernel:
\(k(x,z) = \sigma^2 \cos\left(\frac{|x-z|}{l}\right).\)
- Parameters
lengthscale (torch.Tensor) – Length-scale parameter of this kernel.
DotProduct¶
- class DotProduct(input_dim, variance=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
Base class for kernels which are functions of \(x \cdot z\).
Exponent¶
Exponential¶
Isotropy¶
- class Isotropy(input_dim, variance=None, lengthscale=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
Base class for a family of isotropic covariance kernels which are functions of the distance \(|x-z|/l\), where \(l\) is the length-scale parameter.
By default, the parameter lengthscale has size 1. To use the anisotropic version (a different lengthscale for each dimension), make sure that lengthscale has size equal to input_dim.
- Parameters
lengthscale (torch.Tensor) – Length-scale parameter of this kernel.
Linear¶
- class Linear(input_dim, variance=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.dot_product.DotProduct
Implementation of Linear kernel:
\(k(x, z) = \sigma^2 x \cdot z.\)
Doing Gaussian Process regression with a linear kernel is equivalent to doing linear regression.
Note
Here we implement the homogeneous version. To use the inhomogeneous version, consider using the Polynomial kernel with degree=1 or making a Sum with a Constant kernel.
Matern32¶
- class Matern32(input_dim, variance=None, lengthscale=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.isotropic.Isotropy
Implementation of Matern32 kernel:
\(k(x, z) = \sigma^2\left(1 + \sqrt{3} \times \frac{|x-z|}{l}\right) \exp\left(-\sqrt{3} \times \frac{|x-z|}{l}\right).\)
Matern52¶
- class Matern52(input_dim, variance=None, lengthscale=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.isotropic.Isotropy
Implementation of Matern52 kernel:
\(k(x,z)=\sigma^2\left(1+\sqrt{5}\times\frac{|x-z|}{l}+\frac{5}{3}\times \frac{|x-z|^2}{l^2}\right)\exp\left(-\sqrt{5} \times \frac{|x-z|}{l}\right).\)
Periodic¶
- class Periodic(input_dim, variance=None, lengthscale=None, period=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Kernel
Implementation of Periodic kernel:
\(k(x,z)=\sigma^2\exp\left(-2\times\frac{\sin^2(\pi(x-z)/p)}{l^2}\right),\)
where \(p\) is the period parameter.
References:
[1] Introduction to Gaussian processes, David J.C. MacKay
- Parameters
lengthscale (torch.Tensor) – Length scale parameter of this kernel.
period (torch.Tensor) – Period parameter of this kernel.
Polynomial¶
- class Polynomial(input_dim, variance=None, bias=None, degree=1, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.dot_product.DotProduct
Implementation of Polynomial kernel:
\(k(x, z) = \sigma^2(\text{bias} + x \cdot z)^d.\)
- Parameters
bias (torch.Tensor) – Bias parameter of this kernel. Should be positive.
degree (int) – Degree \(d\) of the polynomial.
Product¶
RBF¶
- class RBF(input_dim, variance=None, lengthscale=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.isotropic.Isotropy
Implementation of Radial Basis Function kernel:
\(k(x,z) = \sigma^2\exp\left(-0.5 \times \frac{|x-z|^2}{l^2}\right).\)
Note
This kernel is also known as the Squared Exponential kernel in the literature.
RationalQuadratic¶
- class RationalQuadratic(input_dim, variance=None, lengthscale=None, scale_mixture=None, active_dims=None)[source]¶
Bases:
pyro.contrib.gp.kernels.isotropic.Isotropy
Implementation of RationalQuadratic kernel:
\(k(x, z) = \sigma^2 \left(1 + 0.5 \times \frac{|x-z|^2}{\alpha l^2} \right)^{-\alpha}.\)
- Parameters
scale_mixture (torch.Tensor) – Scale mixture (\(\alpha\)) parameter of this kernel. Should have size 1.
Sum¶
Transforming¶
VerticalScaling¶
- class VerticalScaling(kern, vscaling_fn)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Transforming
Creates a new kernel according to
\(k_{new}(x, z) = f(x)k(x, z)f(z),\)
where \(f\) is a function.
- Parameters
vscaling_fn (callable) – A vertical scaling function \(f\).
Warping¶
- class Warping(kern, iwarping_fn=None, owarping_coef=None)[source]¶
Bases:
pyro.contrib.gp.kernels.kernel.Transforming
Creates a new kernel according to
\(k_{new}(x, z) = q(k(f(x), f(z))),\)
where \(f\) is a function and \(q\) is a polynomial with non-negative coefficients owarping_coef.
We can take advantage of \(f\) to combine a Gaussian Process kernel with a deep learning architecture. For example:
>>> linear = torch.nn.Linear(10, 3)
>>> # register its parameters to Pyro's ParamStore and wrap it by lambda
>>> # to call the primitive pyro.module each time we use the linear function
>>> pyro_linear_fn = lambda x: pyro.module("linear", linear)(x)
>>> kernel = gp.kernels.Matern52(input_dim=3, lengthscale=torch.ones(3))
>>> warped_kernel = gp.kernels.Warping(kernel, pyro_linear_fn)
Reference:
[1] Deep Kernel Learning, Andrew G. Wilson, Zhiting Hu, Ruslan Salakhutdinov, Eric P. Xing
- Parameters
iwarping_fn (callable) – An input warping function \(f\).
owarping_coef (list) – A list of coefficients of the output warping polynomial. These coefficients must be non-negative.
WhiteNoise¶
Likelihoods¶
Likelihood¶
- class Likelihood[source]¶
Bases:
pyro.contrib.gp.parameterized.Parameterized
Base class for likelihoods used in Gaussian Process.
Every inherited class should implement a forward pass which takes an input \(f\) and returns a sample \(y\).
- forward(f_loc, f_var, y=None)[source]¶
Samples \(y\) given \(f_{loc}\), \(f_{var}\).
- Parameters
f_loc (torch.Tensor) – Mean of latent function output.
f_var (torch.Tensor) – Variance of latent function output.
y (torch.Tensor) – Training output tensor.
- Returns
a tensor sampled from likelihood
- Return type
torch.Tensor
Binary¶
- class Binary(response_function=None)[source]¶
Bases:
pyro.contrib.gp.likelihoods.likelihood.Likelihood
Implementation of Binary likelihood, which is used for binary classification problems.
Binary likelihood uses the Bernoulli distribution, so the output of response_function should be in the range \((0,1)\). By default, we use the sigmoid function.
- Parameters
response_function (callable) – A mapping to correct domain for Binary likelihood.
- forward(f_loc, f_var, y=None)[source]¶
Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to
\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Bernoulli}(f).\end{split}\]Note
The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).
- Parameters
f_loc (torch.Tensor) – Mean of latent function output.
f_var (torch.Tensor) – Variance of latent function output.
y (torch.Tensor) – Training output tensor.
- Returns
a tensor sampled from likelihood
- Return type
torch.Tensor
Gaussian¶
- class Gaussian(variance=None)[source]¶
Bases:
pyro.contrib.gp.likelihoods.likelihood.Likelihood
Implementation of Gaussian likelihood, which is used for regression problems.
Gaussian likelihood uses the Normal distribution.
- Parameters
variance (torch.Tensor) – A variance parameter, which plays the role of
noise
in regression problems.
- forward(f_loc, f_var, y=None)[source]¶
Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to
\[y \sim \mathbb{Normal}(f_{loc}, f_{var} + \epsilon),\]
where \(\epsilon\) is the variance parameter of this likelihood.
- Parameters
f_loc (torch.Tensor) – Mean of latent function output.
f_var (torch.Tensor) – Variance of latent function output.
y (torch.Tensor) – Training output tensor.
- Returns
a tensor sampled from likelihood
- Return type
torch.Tensor
MultiClass¶
- class MultiClass(num_classes, response_function=None)[source]¶
Bases:
pyro.contrib.gp.likelihoods.likelihood.Likelihood
Implementation of MultiClass likelihood, which is used for multi-class classification problems.
MultiClass likelihood uses the Categorical distribution, so response_function should normalize its input’s rightmost axis. By default, we use the softmax function.
- Parameters
num_classes (int) – Number of classes for prediction.
response_function (callable) – A mapping to correct domain for MultiClass likelihood.
- forward(f_loc, f_var, y=None)[source]¶
Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to
\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Categorical}(f).\end{split}\]Note
The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).
- Parameters
f_loc (torch.Tensor) – Mean of latent function output.
f_var (torch.Tensor) – Variance of latent function output.
y (torch.Tensor) – Training output tensor.
- Returns
a tensor sampled from likelihood
- Return type
torch.Tensor
Poisson¶
- class Poisson(response_function=None)[source]¶
Bases:
pyro.contrib.gp.likelihoods.likelihood.Likelihood
Implementation of Poisson likelihood, which is used for count data.
Poisson likelihood uses the Poisson distribution, so the output of response_function should be positive. By default, we use torch.exp() as the response function, corresponding to a log-Gaussian Cox process.
- Parameters
response_function (callable) – A mapping to positive real numbers.
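A minimal sketch (synthetic count data; values are illustrative) of pairing the Poisson likelihood with a variational GP model:
>>> X = 10. * torch.rand(50, 1)
>>> y = torch.poisson(torch.exp(torch.sin(X.squeeze(-1))))  # synthetic counts
>>> kernel = gp.kernels.RBF(input_dim=1)
>>> likelihood = gp.likelihoods.Poisson()
>>> model = gp.models.VariationalGP(X, y, kernel, likelihood, whiten=True)
>>> losses = gp.util.train(model, num_steps=500)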
- forward(f_loc, f_var, y=None)[source]¶
Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to
\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Poisson}(\exp(f)).\end{split}\]Note
The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).
- Parameters
f_loc (torch.Tensor) – Mean of latent function output.
f_var (torch.Tensor) – Variance of latent function output.
y (torch.Tensor) – Training output tensor.
- Returns
a tensor sampled from likelihood
- Return type
torch.Tensor
Parameterized¶
- class Parameterized[source]¶
Bases:
pyro.nn.module.PyroModule
A wrapper of PyroModule whose parameters can be given constraints and priors.
By default, when we set a prior on a parameter, an auto Delta guide will be created. We can use the method autoguide() to set up other auto guides.
Example:
>>> class Linear(Parameterized):
...     def __init__(self, a, b):
...         super().__init__()
...         self.a = Parameter(a)
...         self.b = Parameter(b)
...
...     def forward(self, x):
...         return self.a * x + self.b
...
>>> linear = Linear(torch.tensor(1.), torch.tensor(0.))
>>> linear.a = PyroParam(torch.tensor(1.), constraints.positive)
>>> linear.b = PyroSample(dist.Normal(0, 1))
>>> linear.autoguide("b", dist.Normal)
>>> assert "a_unconstrained" in dict(linear.named_parameters())
>>> assert "b_loc" in dict(linear.named_parameters())
>>> assert "b_scale_unconstrained" in dict(linear.named_parameters())
Note that by default, data of a parameter is a float torch.Tensor (unless we use torch.set_default_dtype() to change the default tensor type). To cast these parameters to a correct data type or GPU device, we can call methods such as double() or cuda(). See torch.nn.Module for more information.
- set_prior(name, prior)[source]¶
Sets prior for a parameter.
- Parameters
name (str) – Name of the parameter.
prior (Distribution) – A Pyro prior distribution.
- autoguide(name, dist_constructor)[source]¶
Sets an autoguide for an existing parameter with name name (mimicking the behavior of the module pyro.infer.autoguide).
Note
dist_constructor should be one of Delta, Normal, and MultivariateNormal. More distribution constructors will be supported in the future if needed.
- Parameters
name (str) – Name of the parameter.
dist_constructor – A
Distribution
constructor.
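A minimal sketch, assuming an RBF kernel from this module: put a LogNormal prior on the lengthscale and replace the default Delta guide with a Normal autoguide:
>>> from pyro.nn import PyroSample
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> kernel.lengthscale = PyroSample(dist.LogNormal(0., 1.))
>>> kernel.autoguide("lengthscale", dist.Normal)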
- set_mode(mode)[source]¶
Sets mode of this object so that its parameters can be used in stochastic functions. If mode="model", a parameter will get its value from its prior. If mode="guide", the value will be drawn from its guide.
Note
This method automatically sets mode for submodules which belong to the Parameterized class.
- Parameters
mode (str) – Either “model” or “guide”.
- property mode¶
Util¶
- conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=False, whiten=False, jitter=1e-06)[source]¶
Given \(X_{new}\), predicts loc and covariance matrix of the conditional multivariate normal distribution
\[p(f^*(X_{new}) \mid X, k, f_{loc}, f_{scale\_tril}).\]
Here f_loc and f_scale_tril are variational parameters of the variational distribution
\[q(f \mid f_{loc}, f_{scale\_tril}) \sim p(f | X, y),\]
where \(f\) is the function value of the Gaussian Process given input \(X\)
\[p(f(X)) \sim \mathcal{N}(0, k(X, X))\]
and \(y\) is computed from \(f\) by some likelihood function \(p(y|f)\).
In case f_scale_tril=None, we consider \(f = f_{loc}\) and compute
\[p(f^*(X_{new}) \mid X, k, f).\]
In case f_scale_tril is not None, we follow the derivation from reference [1]. For the case f_scale_tril=None, we follow the popular reference [2].
References:
[1] Sparse GPs: approximate the posterior, not the model
[2] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams
- Parameters
Xnew (torch.Tensor) – New input data.
X (torch.Tensor) – Input data to be conditioned on.
kernel (Kernel) – A Pyro kernel object.
f_loc (torch.Tensor) – Mean of \(q(f)\). In case f_scale_tril=None, \(f_{loc} = f\).
f_scale_tril (torch.Tensor) – Lower triangular decomposition of the covariance matrix of \(q(f)\).
Lff (torch.Tensor) – Lower triangular decomposition of \(kernel(X, X)\) (optional).
full_cov (bool) – A flag to decide if we want to return full covariance matrix or just variance.
whiten (bool) – A flag to tell if f_loc and f_scale_tril are already transformed by the inverse of Lff.
jitter (float) – A small positive term which is added into the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
- Returns
loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)
- Return type
tuple(torch.Tensor, torch.Tensor)
- train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000)[source]¶
A helper to optimize parameters for a GP module.
- Parameters
gpmodule (GPModel) – A GP module.
optimizer (Optimizer) – A PyTorch optimizer instance. By default, we use Adam with lr=0.01.
loss_fn (callable) – A loss function which takes as inputs gpmodule.model and gpmodule.guide and returns an ELBO loss. By default, loss_fn=TraceMeanField_ELBO().differentiable_loss.
retain_graph (bool) – An optional flag of torch.autograd.backward.
num_steps (int) – Number of steps to run SVI.
- Returns
a list of losses during the training procedure
- Return type
list
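A minimal sketch, assuming gpr is a GP model defined as in the examples above: run SVI with the default Adam optimizer and TraceMeanField ELBO loss and collect the losses:
>>> losses = gp.util.train(gpr, num_steps=1000)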