Gaussian Processes

See the Gaussian Processes tutorial for an introduction.


Models

GPModel

class GPModel(X, y, kernel, mean_function=None, jitter=1e-06)[source]

Bases: pyro.contrib.gp.parameterized.Parameterized

Base class for Gaussian Process models.

The core of a Gaussian Process is a covariance function \(k\) which governs the similarity between input points. Given \(k\), we can establish a distribution over functions \(f\) by a multivariate normal distribution

\[p(f(X)) = \mathcal{N}(0, k(X, X)),\]

where \(X\) is any set of input points and \(k(X, X)\) is a covariance matrix whose entries are outputs \(k(x, z)\) of \(k\) over input pairs \((x, z)\). This distribution is usually denoted by

\[f \sim \mathcal{GP}(0, k).\]

Note

Generally, besides a covariance function \(k\), a Gaussian Process can also be specified by a mean function \(m\) (which is the zero function by default). In that case, its distribution will be

\[p(f(X)) = \mathcal{N}(m(X), k(X, X)).\]

Gaussian Process models are Parameterized subclasses, so their parameters can be learned, given priors, or fixed using the corresponding methods from Parameterized. A typical way to define a Gaussian Process model is

>>> X = torch.tensor([[1., 5, 3], [4, 3, 7]])
>>> y = torch.tensor([2., 1])
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> kernel.variance = pyro.nn.PyroSample(dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
>>> kernel.lengthscale = pyro.nn.PyroSample(dist.Uniform(torch.tensor(1.0), torch.tensor(3.0)))
>>> gpr = gp.models.GPRegression(X, y, kernel)

There are two ways to train a Gaussian Process model:

  • Using an MCMC algorithm (in module pyro.infer.mcmc) on model() to get posterior samples for the Gaussian Process’s parameters. For example:

    >>> hmc_kernel = HMC(gpr.model)
    >>> mcmc = MCMC(hmc_kernel, num_samples=10)
    >>> mcmc.run()
    >>> ls_name = "kernel.lengthscale"
    >>> posterior_ls = mcmc.get_samples()[ls_name]
    
  • Using variational inference on the pair model(), guide():

    >>> optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)
    >>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
    >>>
    >>> for i in range(1000):
    ...     optimizer.zero_grad()
    ...     loss = loss_fn(gpr.model, gpr.guide)  # doctest: +SKIP
    ...     loss.backward()  # doctest: +SKIP
    ...     optimizer.step()
    

To make a prediction on a new dataset, simply use forward() like any PyTorch torch.nn.Module:

>>> Xnew = torch.tensor([[2., 3, 1]])
>>> f_loc, f_cov = gpr(Xnew, full_cov=True)

Reference:

[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
  • kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
  • mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
model()[source]

A “model” stochastic function. If self.y is None, this method returns mean and variance of the Gaussian Process prior.

guide()[source]

A “guide” stochastic function to be used in variational inference methods. It also gives posterior information to the method forward() for prediction.

forward(Xnew, full_cov=False)[source]

Computes the mean and covariance matrix (or variance) of the Gaussian Process posterior at the test input data \(X_{new}\):

\[p(f^* \mid X_{new}, X, y, k, \theta),\]

where \(\theta\) are parameters of this model.

Note

The model’s parameters \(\theta\), together with the kernel’s parameters, have been learned from a training procedure (MCMC or SVI).

Parameters:
  • Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as X.shape[1:].
  • full_cov (bool) – A flag to decide if we want to predict full covariance matrix or just variance.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)

set_data(X, y=None)[source]

Sets data for Gaussian Process models.

Some examples of using this method are:

  • Batch training on a sparse variational model:

    >>> Xu = torch.tensor([[1., 0, 2]])  # inducing input
    >>> likelihood = gp.likelihoods.Gaussian()
    >>> vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood)
    >>> optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.01)
    >>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
    >>> batched_X, batched_y = X.split(split_size=10), y.split(split_size=10)
    >>> for Xi, yi in zip(batched_X, batched_y):
    ...     optimizer.zero_grad()
    ...     vsgp.set_data(Xi, yi)
    ...     loss = loss_fn(vsgp.model, vsgp.guide)  # doctest: +SKIP
    ...     loss.backward()  # doctest: +SKIP
    ...     optimizer.step()
    
  • Making a two-layer Gaussian Process stochastic function:

    >>> gpr1 = gp.models.GPRegression(X, None, kernel)
    >>> Z, _ = gpr1.model()
    >>> gpr2 = gp.models.GPRegression(Z, y, kernel)
    >>> def two_layer_model():
    ...     Z, _ = gpr1.model()
    ...     gpr2.set_data(Z, y)
    ...     return gpr2.model()
    

References:

[1] Scalable Variational Gaussian Process Classification, James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani

[2] Deep Gaussian Processes, Andreas C. Damianou, Neil D. Lawrence

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.

GPRegression

class GPRegression(X, y, kernel, noise=None, mean_function=None, jitter=1e-06)[source]

Bases: pyro.contrib.gp.models.model.GPModel

Gaussian Process Regression model.

The core of a Gaussian Process is a covariance function \(k\) which governs the similarity between input points. Given \(k\), we can establish a distribution over functions \(f\) by a multivariate normal distribution

\[p(f(X)) = \mathcal{N}(0, k(X, X)),\]

where \(X\) is any set of input points and \(k(X, X)\) is a covariance matrix whose entries are outputs \(k(x, z)\) of \(k\) over input pairs \((x, z)\). This distribution is usually denoted by

\[f \sim \mathcal{GP}(0, k).\]

Note

Generally, besides a covariance function \(k\), a Gaussian Process can also be specified by a mean function \(m\) (which is the zero function by default). In that case, its distribution will be

\[p(f(X)) = \mathcal{N}(m(X), k(X, X)).\]

Given inputs \(X\) and their noisy observations \(y\), the Gaussian Process Regression model takes the form

\[\begin{split}f &\sim \mathcal{GP}(0, k(X, X)),\\ y & \sim f + \epsilon,\end{split}\]

where \(\epsilon\) is Gaussian noise.

Note

This model has \(\mathcal{O}(N^3)\) complexity for training, \(\mathcal{O}(N^3)\) complexity for testing. Here, \(N\) is the number of train inputs.

Reference:

[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
  • kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
  • noise (torch.Tensor) – Variance of Gaussian noise of this model.
  • mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
model()[source]
guide()[source]
forward(Xnew, full_cov=False, noiseless=True)[source]

Computes the mean and covariance matrix (or variance) of the Gaussian Process posterior at the test input data \(X_{new}\):

\[p(f^* \mid X_{new}, X, y, k, \epsilon) = \mathcal{N}(loc, cov).\]

Note

The noise parameter noise (\(\epsilon\)) and the kernel’s parameters have been learned from a training procedure (MCMC or SVI).

Parameters:
  • Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
  • full_cov (bool) – A flag to decide if we want to predict full covariance matrix or just variance.
  • noiseless (bool) – A flag to decide if we want to include noise in the prediction output or not.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)

iter_sample(noiseless=True)[source]

Iteratively constructs a sample from the Gaussian Process posterior.

Recall that at test input points \(X_{new}\), the posterior is multivariate Gaussian distributed with mean and covariance matrix given by forward().

This method samples lazily from this multivariate Gaussian. The advantage of this approach is that later query points can depend upon earlier ones. This is particularly useful when the querying is to be done by an optimization routine.

Note

The noise parameter noise (\(\epsilon\)) and the kernel’s parameters have been learned from a training procedure (MCMC or SVI).

Parameters:noiseless (bool) – A flag to decide if we want to add sampling noise to the samples beyond the noise inherent in the GP posterior.
Returns:sampler
Return type:function
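
For example, here is a small sketch of sequential posterior sampling, assuming a GPRegression instance gpr like the one constructed in the GPModel example above (the query points are illustrative; each call to the sampler conditions on the values drawn so far):

>>> sampler = gpr.iter_sample(noiseless=True)
>>> x1 = torch.tensor([[2., 3, 1]])
>>> f1 = sampler(x1)  # a posterior draw at x1
>>> x2 = torch.tensor([[1., 4, 2]])  # a later query, possibly chosen based on f1
>>> f2 = sampler(x2)  # consistent with the earlier draw at x1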

SparseGPRegression

class SparseGPRegression(X, y, kernel, Xu, noise=None, mean_function=None, approx=None, jitter=1e-06)[source]

Bases: pyro.contrib.gp.models.model.GPModel

Sparse Gaussian Process Regression model.

In the GPRegression model, when the number of input data points \(X\) is large, computing the inverse of the covariance matrix \(k(X, X)\) (needed for the log likelihood and for prediction) requires a lot of computation. By introducing an additional inducing-input parameter \(X_u\), we can reduce the computational cost by approximating \(k(X, X)\) with a low-rank Nyström approximation \(Q\) (see reference [1]), where

\[Q = k(X, X_u) k(X_u, X_u)^{-1} k(X_u, X).\]

Given inputs \(X\), their noisy observations \(y\), and the inducing-input parameters \(X_u\), the model takes the form:

\[\begin{split}u & \sim \mathcal{GP}(0, k(X_u, X_u)),\\ f & \sim q(f \mid X, X_u) = \mathbb{E}_{p(u)}q(f\mid X, X_u, u),\\ y & \sim f + \epsilon,\end{split}\]

where \(\epsilon\) is Gaussian noise and the conditional distribution \(q(f\mid X, X_u, u)\) is an approximation of

\[p(f\mid X, X_u, u) = \mathcal{N}(m, k(X, X) - Q),\]

whose terms \(m\) and \(k(X, X) - Q\) are derived from the joint multivariate normal distribution:

\[[f, u] \sim \mathcal{GP}(0, k([X, X_u], [X, X_u])).\]

This class implements three approximation methods:

  • Deterministic Training Conditional (DTC):

    \[q(f\mid X, X_u, u) = \mathcal{N}(m, 0),\]

    which in turn implies

    \[f \sim \mathcal{N}(0, Q).\]
  • Fully Independent Training Conditional (FITC):

    \[q(f\mid X, X_u, u) = \mathcal{N}(m, diag(k(X, X) - Q)),\]

    which in turn corrects the diagonal part of the approximation in DTC:

    \[f \sim \mathcal{N}(0, Q + diag(k(X, X) - Q)).\]
  • Variational Free Energy (VFE), which is similar to DTC but has an additional trace_term in the model’s log likelihood. This additional term makes “VFE” equivalent to the variational approach in VariationalSparseGP (see reference [2]).

Note

This model has \(\mathcal{O}(NM^2)\) complexity for training, \(\mathcal{O}(NM^2)\) complexity for testing. Here, \(N\) is the number of train inputs, \(M\) is the number of inducing inputs.

References:

[1] A Unifying View of Sparse Approximate Gaussian Process Regression, Joaquin Quiñonero-Candela, Carl E. Rasmussen

[2] Variational learning of inducing variables in sparse Gaussian processes, Michalis Titsias

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
  • kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
  • Xu (torch.Tensor) – Initial values for inducing points, which are parameters of our model.
  • noise (torch.Tensor) – Variance of Gaussian noise of this model.
  • mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
  • approx (str) – One of approximation methods: “DTC”, “FITC”, and “VFE” (default).
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
  • name (str) – Name of this model.
model()[source]
guide()[source]
forward(Xnew, full_cov=False, noiseless=True)[source]

Computes the mean and covariance matrix (or variance) of the Gaussian Process posterior at the test input data \(X_{new}\):

\[p(f^* \mid X_{new}, X, y, k, X_u, \epsilon) = \mathcal{N}(loc, cov).\]

Note

The noise parameter noise (\(\epsilon\)), the inducing-point parameter Xu, and the kernel’s parameters have been learned from a training procedure (MCMC or SVI).

Parameters:
  • Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
  • full_cov (bool) – A flag to decide if we want to predict full covariance matrix or just variance.
  • noiseless (bool) – A flag to decide if we want to include noise in the prediction output or not.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)
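
As a small sketch (the data, the number of inducing points, and the approximation choice below are illustrative), a sparse regression model can be constructed and trained with the train() helper from the Util section:

>>> X = torch.randn(1000, 3)
>>> y = torch.sin(X[:, 0])
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> Xu = X[:20].clone()  # initialize inducing inputs from a subset of the data
>>> sgpr = gp.models.SparseGPRegression(X, y, kernel, Xu, approx="VFE")
>>> losses = gp.util.train(sgpr, num_steps=1000)  # doctest: +SKIP
>>> f_loc, f_var = sgpr(torch.randn(5, 3), full_cov=False)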

VariationalGP

class VariationalGP(X, y, kernel, likelihood, mean_function=None, latent_shape=None, whiten=False, jitter=1e-06)[source]

Bases: pyro.contrib.gp.models.model.GPModel

Variational Gaussian Process model.

This model deals with both Gaussian and non-Gaussian likelihoods. Given inputs \(X\) and their noisy observations \(y\), the model takes the form

\[\begin{split}f &\sim \mathcal{GP}(0, k(X, X)),\\ y & \sim p(y) = p(y \mid f) p(f),\end{split}\]

where \(p(y \mid f)\) is the likelihood.

We will use a variational approach in this model by approximating the posterior \(p(f\mid y)\) with a distribution \(q(f)\). Precisely, \(q(f)\) will be a multivariate normal distribution with two parameters f_loc and f_scale_tril, which will be learned during a variational inference process.

Note

This model can be seen as a special version of the VariationalSparseGP model with \(X_u = X\).

Note

This model has \(\mathcal{O}(N^3)\) complexity for training, \(\mathcal{O}(N^3)\) complexity for testing. Here, \(N\) is the number of train inputs. Size of variational parameters is \(\mathcal{O}(N^2)\).

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
  • kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
  • likelihood (Likelihood) – A likelihood object.
  • mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
  • latent_shape (torch.Size) – Shape for latent processes (batch_shape of \(q(f)\)). By default, it equals the output batch shape y.shape[:-1]. For multi-class classification problems, latent_shape[-1] should correspond to the number of classes.
  • whiten (bool) – A flag to tell if the variational parameters f_loc and f_scale_tril are transformed by the inverse of Lff, where Lff is the lower triangular decomposition of \(kernel(X, X)\). Enabling this flag will help optimization.
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
model()[source]
guide()[source]
forward(Xnew, full_cov=False)[source]

Computes the mean and covariance matrix (or variance) of the Gaussian Process posterior at the test input data \(X_{new}\):

\[p(f^* \mid X_{new}, X, y, k, f_{loc}, f_{scale\_tril}) = \mathcal{N}(loc, cov).\]

Note

The variational parameters f_loc and f_scale_tril, together with the kernel’s parameters, have been learned from a training procedure (MCMC or SVI).

Parameters:
  • Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
  • full_cov (bool) – A flag to decide if we want to predict full covariance matrix or just variance.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)
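
For non-Gaussian observations, the model is paired with a likelihood from the Likelihoods section. A small binary-classification sketch on synthetic, purely illustrative data:

>>> X = torch.randn(100, 2)
>>> y = (X[:, 0] + X[:, 1] > 0).float()  # binary labels in {0, 1}
>>> kernel = gp.kernels.RBF(input_dim=2)
>>> likelihood = gp.likelihoods.Binary()
>>> vgp = gp.models.VariationalGP(X, y, kernel, likelihood, whiten=True)
>>> losses = gp.util.train(vgp, num_steps=1000)  # doctest: +SKIP
>>> f_loc, f_var = vgp(torch.randn(5, 2), full_cov=False)
>>> probs = torch.sigmoid(f_loc)  # plug-in estimate of class probabilities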

VariationalSparseGP

class VariationalSparseGP(X, y, kernel, Xu, likelihood, mean_function=None, latent_shape=None, num_data=None, whiten=False, jitter=1e-06)[source]

Bases: pyro.contrib.gp.models.model.GPModel

Variational Sparse Gaussian Process model.

In the VariationalGP model, when the number of input data points \(X\) is large, computing the inverse of the covariance matrix \(k(X, X)\) (needed for the log likelihood and for prediction) requires a lot of computation. This model introduces an additional inducing-input parameter \(X_u\) to solve that problem. Given inputs \(X\), their noisy observations \(y\), and the inducing-input parameters \(X_u\), the model takes the form:

\[\begin{split}[f, u] &\sim \mathcal{GP}(0, k([X, X_u], [X, X_u])),\\ y & \sim p(y) = p(y \mid f) p(f),\end{split}\]

where \(p(y \mid f)\) is the likelihood.

We will use a variational approach in this model by approximating the posterior \(p(f,u \mid y)\) with a distribution \(q(f,u)\). Precisely, \(q(f,u) = p(f\mid u)q(u)\), where \(q(u)\) is a multivariate normal distribution with two parameters u_loc and u_scale_tril, which will be learned during a variational inference process.

Note

This model can be learned using an MCMC method, as in reference [2]. See also GPModel.

Note

This model has \(\mathcal{O}(NM^2)\) complexity for training, \(\mathcal{O}(M^3)\) complexity for testing. Here, \(N\) is the number of train inputs, \(M\) is the number of inducing inputs. Size of variational parameters is \(\mathcal{O}(M^2)\).

References:

[1] Scalable variational Gaussian process classification, James Hensman, Alexander G. de G. Matthews, Zoubin Ghahramani

[2] MCMC for Variationally Sparse Gaussian Processes, James Hensman, Alexander G. de G. Matthews, Maurizio Filippone, Zoubin Ghahramani

Parameters:
  • X (torch.Tensor) – Input data for training. Its first dimension is the number of data points.
  • y (torch.Tensor) – Output data for training. Its last dimension is the number of data points.
  • kernel (Kernel) – A Pyro kernel object, which is the covariance function \(k\).
  • Xu (torch.Tensor) – Initial values for inducing points, which are parameters of our model.
  • likelihood (Likelihood) – A likelihood object.
  • mean_function (callable) – An optional mean function \(m\) of this Gaussian process. By default, we use zero mean.
  • latent_shape (torch.Size) – Shape for latent processes (batch_shape of \(q(u)\)). By default, it equals the output batch shape y.shape[:-1]. For multi-class classification problems, latent_shape[-1] should correspond to the number of classes.
  • num_data (int) – The size of full training dataset. It is useful for training this model with mini-batch.
  • whiten (bool) – A flag to tell if the variational parameters u_loc and u_scale_tril are transformed by the inverse of Luu, where Luu is the lower triangular decomposition of \(kernel(X_u, X_u)\). Enabling this flag will help optimization.
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
model()[source]
guide()[source]
forward(Xnew, full_cov=False)[source]

Computes the mean and covariance matrix (or variance) of the Gaussian Process posterior at the test input data \(X_{new}\):

\[p(f^* \mid X_{new}, X, y, k, X_u, u_{loc}, u_{scale\_tril}) = \mathcal{N}(loc, cov).\]

Note

The variational parameters u_loc and u_scale_tril, the inducing-point parameter Xu, and the kernel’s parameters have been learned from a training procedure (MCMC or SVI).

Parameters:
  • Xnew (torch.Tensor) – Input data for testing. Note that Xnew.shape[1:] must be the same as self.X.shape[1:].
  • full_cov (bool) – A flag to decide if we want to predict full covariance matrix or just variance.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)
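
Because num_data rescales the likelihood term to account for the full dataset, this model supports mini-batch training through set_data(). A small sketch on synthetic, illustrative data:

>>> X = torch.randn(10000, 3)
>>> y = torch.sin(X[:, 0])
>>> kernel = gp.kernels.RBF(input_dim=3)
>>> Xu = X[:50].clone()  # inducing inputs
>>> likelihood = gp.likelihoods.Gaussian()
>>> vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood, num_data=X.size(0), whiten=True)
>>> optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.01)
>>> loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
>>> for Xi, yi in zip(X.split(500), y.split(500)):
...     vsgp.set_data(Xi, yi)
...     optimizer.zero_grad()
...     loss = loss_fn(vsgp.model, vsgp.guide)  # doctest: +SKIP
...     loss.backward()  # doctest: +SKIP
...     optimizer.step()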

GPLVM

class GPLVM(base_model)[source]

Bases: pyro.contrib.gp.parameterized.Parameterized

Gaussian Process Latent Variable Model (GPLVM).

GPLVM is a Gaussian Process model whose training input data is a latent variable. This model is useful for dimensionality reduction of high dimensional data. We assume that the mapping from the low dimensional latent variable to the high dimensional data is a Gaussian Process instance. The high dimensional data then plays the role of the training output y, and our target is to learn latent inputs which best explain y. For the purpose of dimensionality reduction, latent inputs should have lower dimension than y.

We follow reference [1] and put a unit Gaussian prior on the input, approximating its posterior by a multivariate normal distribution with two variational parameters: X_loc and X_scale_tril.

For example, we can do dimensionality reduction on the Iris dataset as follows:

>>> # Given y as the 2D Iris data of shape 150x4, we want to reduce its dimension
>>> # to a tensor X of shape 150x2. We will use GPLVM for this.
>>> # First, define the initial values for X parameter:
>>> X_init = torch.zeros(150, 2)
>>> # Then, define a Gaussian Process model with input X_init and output y:
>>> kernel = gp.kernels.RBF(input_dim=2, lengthscale=torch.ones(2))
>>> Xu = torch.zeros(20, 2)  # initial inducing inputs of sparse model
>>> gpmodule = gp.models.SparseGPRegression(X_init, y, kernel, Xu)
>>> # Finally, wrap gpmodule by GPLVM, optimize, and get the "learned" mean of X:
>>> gplvm = gp.models.GPLVM(gpmodule)
>>> gp.util.train(gplvm)  # doctest: +SKIP
>>> X = gplvm.X

Reference:

[1] Bayesian Gaussian Process Latent Variable Model Michalis K. Titsias, Neil D. Lawrence

Parameters:base_model (GPModel) – A Pyro Gaussian Process model object. Note that base_model.X will be the initial value for the variational parameter X_loc.
model()[source]
guide()[source]
forward(**kwargs)[source]

The forward method has the same signature as its base_model. Note that the training input data of base_model is sampled from the GPLVM.

Kernels

Kernel

class Kernel(input_dim, active_dims=None)[source]

Bases: pyro.contrib.gp.parameterized.Parameterized

Base class for kernels used in this Gaussian Process module.

Every inherited class should implement a forward() pass which takes inputs \(X\), \(Z\) and returns their covariance matrix.

To construct a new kernel from the old ones, we can use methods add(), mul(), exp(), warp(), vertical_scale().

References:

[1] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams

Parameters:
  • input_dim (int) – Number of feature dimensions of inputs.
  • variance (torch.Tensor) – Variance parameter of this kernel.
  • active_dims (list) – List of feature dimensions of the input which the kernel acts on.
forward(X, Z=None, diag=False)[source]

Calculates the covariance matrix of inputs on the active dimensions.

Parameters:
  • X (torch.Tensor) – A 2D tensor with shape \(N \times input\_dim\).
  • Z (torch.Tensor) – An (optional) 2D tensor with shape \(M \times input\_dim\).
  • diag (bool) – A flag to decide if we want to return full covariance matrix or just its diagonal part.
Returns:

covariance matrix of \(X\) and \(Z\) with shape \(N \times M\)

Return type:

torch.Tensor
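
For example, a kernel instance can be called directly to evaluate covariances (a small sketch; the hyperparameter values are arbitrary):

>>> k = gp.kernels.RBF(input_dim=3, variance=torch.tensor(2.), lengthscale=torch.tensor(0.5))
>>> X = torch.randn(10, 3)
>>> Z = torch.randn(5, 3)
>>> Kxz = k(X, Z)  # full covariance matrix of shape 10 x 5
>>> Kxx_diag = k(X, diag=True)  # diagonal of k(X, X), shape 10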

Brownian

class Brownian(input_dim, variance=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

This kernel corresponds to a two-sided Brownian motion (Wiener process):

\(k(x,z)=\begin{cases}\sigma^2\min(|x|,|z|),& \text{if } x\cdot z\ge 0\\ 0, & \text{otherwise}. \end{cases}\)

Note that the input dimension of this kernel must be 1.

Reference:

[1] Theory and Statistical Applications of Stochastic Processes, Yuliya Mishura, Georgiy Shevchenko

forward(X, Z=None, diag=False)[source]

Combination

class Combination(kern0, kern1)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Base class for kernels derived from a combination of kernels.

Parameters:
  • kern0 (Kernel) – First kernel to combine.
  • kern1 (Kernel or numbers.Number) – Second kernel to combine.

Constant

class Constant(input_dim, variance=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Implementation of Constant kernel:

\(k(x, z) = \sigma^2.\)
forward(X, Z=None, diag=False)[source]

Coregionalize

class Coregionalize(input_dim, rank=None, components=None, diagonal=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

A kernel for the linear model of coregionalization, \(k(x,z) = x^T (W W^T + D) z\), where \(W\) is an input_dim-by-rank matrix, typically with rank < input_dim, and \(D\) is a diagonal matrix.

This generalizes the Linear kernel to multiple features with a low-rank-plus-diagonal weight matrix. The typical use case is for modeling correlations among outputs of a multi-output GP, where outputs are coded as distinct data points with one-hot coded features denoting which output each datapoint represents.

If only rank is specified, the kernel \(W W^T + D\) will be randomly initialized to a matrix whose expected value is the identity matrix.

References:

[1] Mauricio A. Alvarez, Lorenzo Rosasco, Neil D. Lawrence (2012)
Kernels for Vector-Valued Functions: a Review
Parameters:
  • input_dim (int) – Number of feature dimensions of inputs.
  • rank (int) – Optional rank. This is only used if components is unspecified. If neither rank nor components is specified, then rank defaults to input_dim.
  • components (torch.Tensor) – An optional (input_dim, rank) shaped matrix that maps features to rank-many components. If unspecified, this will be randomly initialized.
  • diagonal (torch.Tensor) – An optional vector of length input_dim. If unspecified, this will be set to constant 0.5.
  • active_dims (list) – List of feature dimensions of the input which the kernel acts on.
  • name (str) – Name of the kernel.
forward(X, Z=None, diag=False)[source]

Cosine

class Cosine(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of Cosine kernel:

\(k(x,z) = \sigma^2 \cos\left(\frac{|x-z|}{l}\right).\)
Parameters:lengthscale (torch.Tensor) – Length-scale parameter of this kernel.
forward(X, Z=None, diag=False)[source]

DotProduct

class DotProduct(input_dim, variance=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Base class for kernels which are functions of \(x \cdot z\).

Exponent

class Exponent(kern)[source]

Bases: pyro.contrib.gp.kernels.kernel.Transforming

Creates a new kernel according to

\(k_{new}(x, z) = \exp(k(x, z)).\)
forward(X, Z=None, diag=False)[source]

Exponential

class Exponential(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of Exponential kernel:

\(k(x, z) = \sigma^2\exp\left(-\frac{|x-z|}{l}\right).\)
forward(X, Z=None, diag=False)[source]

Isotropy

class Isotropy(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Base class for a family of isotropic covariance kernels which are functions of the distance \(|x-z|/l\), where \(l\) is the length-scale parameter.

By default, the parameter lengthscale has size 1. To use the anisotropic version (a different lengthscale for each dimension), make sure that lengthscale has size equal to input_dim.

Parameters:lengthscale (torch.Tensor) – Length-scale parameter of this kernel.

Linear

class Linear(input_dim, variance=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.dot_product.DotProduct

Implementation of Linear kernel:

\(k(x, z) = \sigma^2 x \cdot z.\)

Doing Gaussian Process regression with a linear kernel is equivalent to doing a linear regression.

Note

Here we implement the homogeneous version. To use the inhomogeneous version, consider using Polynomial kernel with degree=1 or making a Sum with a Constant kernel.

forward(X, Z=None, diag=False)[source]

Matern32

class Matern32(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of Matern32 kernel:

\(k(x, z) = \sigma^2\left(1 + \sqrt{3} \times \frac{|x-z|}{l}\right) \exp\left(-\sqrt{3} \times \frac{|x-z|}{l}\right).\)
forward(X, Z=None, diag=False)[source]

Matern52

class Matern52(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of Matern52 kernel:

\(k(x,z)=\sigma^2\left(1+\sqrt{5}\times\frac{|x-z|}{l}+\frac{5}{3}\times \frac{|x-z|^2}{l^2}\right)\exp\left(-\sqrt{5} \times \frac{|x-z|}{l}\right).\)
forward(X, Z=None, diag=False)[source]

Periodic

class Periodic(input_dim, variance=None, lengthscale=None, period=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Implementation of Periodic kernel:

\(k(x,z)=\sigma^2\exp\left(-2\times\frac{\sin^2(\pi(x-z)/p)}{l^2}\right),\)

where \(p\) is the period parameter.

References:

[1] Introduction to Gaussian processes, David J.C. MacKay

Parameters:
  • lengthscale (torch.Tensor) – Length scale parameter of this kernel.
  • period (torch.Tensor) – Period parameter of this kernel.
forward(X, Z=None, diag=False)[source]

Polynomial

class Polynomial(input_dim, variance=None, bias=None, degree=1, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.dot_product.DotProduct

Implementation of Polynomial kernel:

\(k(x, z) = \sigma^2(\text{bias} + x \cdot z)^d.\)
Parameters:
  • bias (torch.Tensor) – Bias parameter of this kernel. Should be positive.
  • degree (int) – Degree \(d\) of the polynomial.
forward(X, Z=None, diag=False)[source]

Product

class Product(kern0, kern1)[source]

Bases: pyro.contrib.gp.kernels.kernel.Combination

Returns a new kernel which acts like a product/tensor product of two kernels. The second kernel can be a constant.

forward(X, Z=None, diag=False)[source]

RBF

class RBF(input_dim, variance=None, lengthscale=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of Radial Basis Function kernel:

\(k(x,z) = \sigma^2\exp\left(-0.5 \times \frac{|x-z|^2}{l^2}\right).\)

Note

This kernel is also known as Squared Exponential in the literature.

forward(X, Z=None, diag=False)[source]

RationalQuadratic

class RationalQuadratic(input_dim, variance=None, lengthscale=None, scale_mixture=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.isotropic.Isotropy

Implementation of RationalQuadratic kernel:

\(k(x, z) = \sigma^2 \left(1 + 0.5 \times \frac{|x-z|^2}{\alpha l^2} \right)^{-\alpha}.\)
Parameters:scale_mixture (torch.Tensor) – Scale mixture (\(\alpha\)) parameter of this kernel. Should have size 1.
forward(X, Z=None, diag=False)[source]

Sum

class Sum(kern0, kern1)[source]

Bases: pyro.contrib.gp.kernels.kernel.Combination

Returns a new kernel which acts like a sum/direct sum of two kernels. The second kernel can be a constant.

forward(X, Z=None, diag=False)[source]
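
For example, Sum and Product can be composed to build structured kernels. A small sketch of a trend-plus-seasonality kernel for one-dimensional inputs:

>>> k_trend = gp.kernels.Linear(input_dim=1)
>>> k_season = gp.kernels.Product(gp.kernels.Periodic(input_dim=1), gp.kernels.RBF(input_dim=1))
>>> k = gp.kernels.Sum(k_trend, k_season)
>>> X = torch.arange(5.).unsqueeze(-1)
>>> K = k(X)  # 5 x 5 covariance matrix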

Transforming

class Transforming(kern)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Base class for kernels derived from a kernel by some transforms such as warping, exponent, vertical scaling.

Parameters:kern (Kernel) – The original kernel.

VerticalScaling

class VerticalScaling(kern, vscaling_fn)[source]

Bases: pyro.contrib.gp.kernels.kernel.Transforming

Creates a new kernel according to

\(k_{new}(x, z) = f(x)k(x, z)f(z),\)

where \(f\) is a function.

Parameters:vscaling_fn (callable) – A vertical scaling function \(f\).
forward(X, Z=None, diag=False)[source]

Warping

class Warping(kern, iwarping_fn=None, owarping_coef=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Transforming

Creates a new kernel according to

\(k_{new}(x, z) = q(k(f(x), f(z))),\)

where \(f\) is a function and \(q\) is a polynomial with non-negative coefficients owarping_coef.

We can take advantage of \(f\) to combine a Gaussian Process kernel with a deep learning architecture. For example:

>>> linear = torch.nn.Linear(10, 3)
>>> # register its parameters to Pyro's ParamStore and wrap it by lambda
>>> # to call the primitive pyro.module each time we use the linear function
>>> pyro_linear_fn = lambda x: pyro.module("linear", linear)(x)
>>> kernel = gp.kernels.Matern52(input_dim=3, lengthscale=torch.ones(3))
>>> warped_kernel = gp.kernels.Warping(kernel, pyro_linear_fn)

Reference:

[1] Deep Kernel Learning, Andrew G. Wilson, Zhiting Hu, Ruslan Salakhutdinov, Eric P. Xing

Parameters:
  • iwarping_fn (callable) – An input warping function \(f\).
  • owarping_coef (list) – A list of coefficients of the output warping polynomial. These coefficients must be non-negative.
forward(X, Z=None, diag=False)[source]

WhiteNoise

class WhiteNoise(input_dim, variance=None, active_dims=None)[source]

Bases: pyro.contrib.gp.kernels.kernel.Kernel

Implementation of WhiteNoise kernel:

\(k(x, z) = \sigma^2 \delta(x, z),\)

where \(\delta\) is a Dirac delta function.

forward(X, Z=None, diag=False)[source]

Likelihoods

Likelihood

class Likelihood[source]

Bases: pyro.contrib.gp.parameterized.Parameterized

Base class for likelihoods used in Gaussian Process models.

Every inherited class should implement a forward pass which takes an input \(f\) and returns a sample \(y\).

forward(f_loc, f_var, y=None)[source]

Samples \(y\) given \(f_{loc}\), \(f_{var}\).

Parameters:
  • f_loc (torch.Tensor) – Mean of latent function output.
  • f_var (torch.Tensor) – Variance of latent function output.
  • y (torch.Tensor) – Training output tensor.
Returns:

a tensor sampled from likelihood

Return type:

torch.Tensor

Binary

class Binary(response_function=None)[source]

Bases: pyro.contrib.gp.likelihoods.likelihood.Likelihood

Implementation of Binary likelihood, which is used for binary classification problems.

Binary likelihood uses the Bernoulli distribution, so the output of response_function should be in the range \((0,1)\). By default, we use the sigmoid function.

Parameters:response_function (callable) – A mapping to correct domain for Binary likelihood.
forward(f_loc, f_var, y=None)[source]

Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to

\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Bernoulli}(f).\end{split}\]

Note

The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).

Parameters:
  • f_loc (torch.Tensor) – Mean of latent function output.
  • f_var (torch.Tensor) – Variance of latent function output.
  • y (torch.Tensor) – Training output tensor.
Returns:

a tensor sampled from likelihood

Return type:

torch.Tensor

Gaussian

class Gaussian(variance=None)[source]

Bases: pyro.contrib.gp.likelihoods.likelihood.Likelihood

Implementation of Gaussian likelihood, which is used for regression problems.

Gaussian likelihood uses Normal distribution.

Parameters:variance (torch.Tensor) – A variance parameter, which plays the role of noise in regression problems.
forward(f_loc, f_var, y=None)[source]

Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to

\[y \sim \mathbb{Normal}(f_{loc}, f_{var} + \epsilon),\]

where \(\epsilon\) is the variance parameter of this likelihood.

Parameters:
  • f_loc (torch.Tensor) – Mean of latent function output.
  • f_var (torch.Tensor) – Variance of latent function output.
  • y (torch.Tensor) – Training output tensor.
Returns:

a tensor sampled from likelihood

Return type:

torch.Tensor

MultiClass

class MultiClass(num_classes, response_function=None)[source]

Bases: pyro.contrib.gp.likelihoods.likelihood.Likelihood

Implementation of MultiClass likelihood, which is used for multi-class classification problems.

MultiClass likelihood uses the Categorical distribution, so response_function should normalize its input’s rightmost axis. By default, we use the softmax function.

Parameters:
  • num_classes (int) – Number of classes for prediction.
  • response_function (callable) – A mapping to correct domain for MultiClass likelihood.
forward(f_loc, f_var, y=None)[source]

Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to

\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Categorical}(f).\end{split}\]

Note

The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).

Parameters:
  • f_loc (torch.Tensor) – Mean of latent function output.
  • f_var (torch.Tensor) – Variance of latent function output.
  • y (torch.Tensor) – Training output tensor.
Returns:

a tensor sampled from likelihood

Return type:

torch.Tensor

Poisson

class Poisson(response_function=None)[source]

Bases: pyro.contrib.gp.likelihoods.likelihood.Likelihood

Implementation of Poisson likelihood, which is used for count data.

Poisson likelihood uses the Poisson distribution, so the output of response_function should be positive. By default, we use torch.exp() as response function, corresponding to a log-Gaussian Cox process.

Parameters:response_function (callable) – A mapping to positive real numbers.
forward(f_loc, f_var, y=None)[source]

Samples \(y\) given \(f_{loc}\), \(f_{var}\) according to

\[\begin{split}f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Poisson}(\exp(f)).\end{split}\]

Note

The log likelihood is estimated using Monte Carlo with 1 sample of \(f\).

Parameters:
  • f_loc (torch.Tensor) – Mean of latent function output.
  • f_var (torch.Tensor) – Variance of latent function output.
  • y (torch.Tensor) – Training output tensor.
Returns:

a tensor sampled from likelihood

Return type:

torch.Tensor
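
A small log-Gaussian Cox style sketch on synthetic, purely illustrative count data:

>>> X = torch.linspace(0, 5, 50).unsqueeze(-1)
>>> y = torch.poisson(torch.exp(torch.sin(X.squeeze(-1))))  # hypothetical counts
>>> likelihood = gp.likelihoods.Poisson()
>>> vgp = gp.models.VariationalGP(X, y, gp.kernels.RBF(input_dim=1), likelihood, whiten=True)
>>> losses = gp.util.train(vgp, num_steps=1000)  # doctest: +SKIP
>>> f_loc, f_var = vgp(X, full_cov=False)
>>> rate = torch.exp(f_loc)  # plug-in estimate of the Poisson rate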

Parameterized

class Parameterized[source]

Bases: pyro.nn.module.PyroModule

A wrapper of PyroModule whose parameters can be given constraints and priors.

By default, when we set a prior on a parameter, an auto Delta guide will be created. We can use the method autoguide() to set up other auto guides.

Example:

>>> class Linear(Parameterized):
...     def __init__(self, a, b):
...         super().__init__()
...         self.a = Parameter(a)
...         self.b = Parameter(b)
...
...     def forward(self, x):
...         return self.a * x + self.b
...
>>> linear = Linear(torch.tensor(1.), torch.tensor(0.))
>>> linear.a = PyroParam(torch.tensor(1.), constraints.positive)
>>> linear.b = PyroSample(dist.Normal(0, 1))
>>> linear.autoguide("b", dist.Normal)
>>> assert "a_unconstrained" in dict(linear.named_parameters())
>>> assert "b_loc" in dict(linear.named_parameters())
>>> assert "b_scale_unconstrained" in dict(linear.named_parameters())

Note that by default, the data of a parameter is a float torch.Tensor (unless we use torch.set_default_tensor_type() to change the default tensor type). To cast these parameters to the correct data type or GPU device, we can call methods such as double() or cuda(). See torch.nn.Module for more information.

set_prior(name, prior)[source]

Sets prior for a parameter.

Parameters:
  • name (str) – Name of the parameter.
  • prior (Distribution) – A Pyro prior distribution.
autoguide(name, dist_constructor)[source]

Sets an autoguide for an existing parameter with name name (mimicking the behavior of the pyro.infer.autoguide module).

Note

dist_constructor should be one of Delta, Normal, and MultivariateNormal. More distribution constructors will be supported in the future if needed.

Parameters:
  • name (str) – Name of the parameter.
  • dist_constructor – A Distribution constructor.
set_mode(mode)[source]

Sets the mode of this object so that its parameters can be used in stochastic functions. If mode="model", a parameter will get its value from its prior. If mode="guide", the value will be drawn from its guide.

Note

This method automatically sets the mode for submodules which belong to the Parameterized class.

Parameters:mode (str) – Either “model” or “guide”.
mode
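
For example, a prior and an autoguide can be attached to an existing kernel parameter and then used by switching modes (a small sketch, assuming the same imports as the examples above):

>>> kernel = gp.kernels.RBF(input_dim=1)
>>> kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0))
>>> kernel.autoguide("lengthscale", dist.Normal)
>>> kernel.set_mode("guide")
>>> ls = kernel.lengthscale  # a draw from the (untrained) guide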

Util

conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=False, whiten=False, jitter=1e-06)[source]

Given \(X_{new}\), predicts loc and covariance matrix of the conditional multivariate normal distribution

\[p(f^*(X_{new}) \mid X, k, f_{loc}, f_{scale\_tril}).\]

Here f_loc and f_scale_tril are variational parameters of the variational distribution

\[q(f \mid f_{loc}, f_{scale\_tril}) \approx p(f \mid X, y),\]

where \(f\) is the function value of the Gaussian Process given input \(X\)

\[p(f(X)) = \mathcal{N}(0, k(X, X))\]

and \(y\) is computed from \(f\) by some likelihood function \(p(y|f)\).

In case f_scale_tril=None, we consider \(f = f_{loc}\) and compute

\[p(f^*(X_{new}) \mid X, k, f).\]

In case f_scale_tril is not None, we follow the derivation from reference [1]. For the case f_scale_tril=None, we follow the popular reference [2].

References:

[1] Sparse GPs: approximate the posterior, not the model

[2] Gaussian Processes for Machine Learning, Carl E. Rasmussen, Christopher K. I. Williams

Parameters:
  • Xnew (torch.Tensor) – New input data.
  • X (torch.Tensor) – Input data to be conditioned on.
  • kernel (Kernel) – A Pyro kernel object.
  • f_loc (torch.Tensor) – Mean of \(q(f)\). In case f_scale_tril=None, \(f_{loc} = f\).
  • f_scale_tril (torch.Tensor) – Lower triangular decomposition of the covariance matrix of \(q(f)\).
  • Lff (torch.Tensor) – Lower triangular decomposition of \(kernel(X, X)\) (optional).
  • full_cov (bool) – A flag to decide if we want to return full covariance matrix or just variance.
  • whiten (bool) – A flag to tell if f_loc and f_scale_tril are already transformed by the inverse of Lff.
  • jitter (float) – A small positive term which is added to the diagonal part of a covariance matrix to help stabilize its Cholesky decomposition.
Returns:

loc and covariance matrix (or variance) of \(p(f^*(X_{new}))\)

Return type:

tuple(torch.Tensor, torch.Tensor)
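
A small sketch of noiseless interpolation with conditional(), treating the training targets as exact function values (f_scale_tril=None); the data is illustrative:

>>> X = torch.linspace(0, 5, 20).unsqueeze(-1)
>>> f = torch.sin(X).squeeze(-1)
>>> kernel = gp.kernels.RBF(input_dim=1, lengthscale=torch.tensor(0.5))
>>> Xnew = torch.rand(5, 1) * 5
>>> loc, var = gp.util.conditional(Xnew, X, kernel, f_loc=f, full_cov=False)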

train(gpmodule, optimizer=None, loss_fn=None, retain_graph=None, num_steps=1000)[source]

A helper to optimize parameters for a GP module.

Parameters:
  • gpmodule (GPModel) – A GP module.
  • optimizer (Optimizer) – A PyTorch optimizer instance. By default, we use Adam with lr=0.01.
  • loss_fn (callable) – A loss function which takes as inputs gpmodule.model and gpmodule.guide, and returns an ELBO loss. By default, loss_fn=TraceMeanField_ELBO().differentiable_loss.
  • retain_graph (bool) – An optional flag of torch.autograd.backward.
  • num_steps (int) – Number of steps to run SVI.
Returns:

a list of losses during the training procedure

Return type:

list
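
A small sketch (the data and the learning rate are illustrative):

>>> X = torch.linspace(0, 5, 20).unsqueeze(-1)
>>> y = torch.sin(X).squeeze(-1)
>>> gpr = gp.models.GPRegression(X, y, gp.kernels.RBF(input_dim=1))
>>> optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
>>> losses = gp.util.train(gpr, optimizer=optimizer, num_steps=2000)  # doctest: +SKIP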