Source code for pyro.contrib.gp.kernels.periodic

from __future__ import absolute_import, division, print_function

import math

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .isotropic import Isotropy
from .kernel import Kernel


[docs]class Cosine(Isotropy): r""" Implementation of Cosine kernel: :math:`k(x,z) = \sigma^2 \cos\left(\frac{|x-z|}{l}\right).` :param torch.Tensor lengthscale: Length-scale parameter of this kernel. """ def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Cosine"): super(Cosine, self).__init__(input_dim, variance, lengthscale, active_dims, name)
[docs] def forward(self, X, Z=None, diag=False): if diag: return self._diag(X) variance = self.get_param("variance") r = self._scaled_dist(X, Z) return variance * torch.cos(r)
[docs]class Periodic(Kernel): r""" Implementation of Periodic kernel: :math:`k(x,z)=\sigma^2\exp\left(-2\times\frac{\sin^2(\pi(x-z)/p)}{l^2}\right),` where :math:`p` is the ``period`` parameter. References: [1] `Introduction to Gaussian processes`, David J.C. MacKay :param torch.Tensor lengthscale: Length scale parameter of this kernel. :param torch.Tensor period: Period parameter of this kernel. """ def __init__(self, input_dim, variance=None, lengthscale=None, period=None, active_dims=None, name="Periodic"): super(Periodic, self).__init__(input_dim, active_dims, name) if variance is None: variance = torch.tensor(1.) self.variance = Parameter(variance) self.set_constraint("variance", constraints.positive) if lengthscale is None: lengthscale = torch.tensor(1.) self.lengthscale = Parameter(lengthscale) self.set_constraint("lengthscale", constraints.positive) if period is None: period = torch.tensor(1.) self.period = Parameter(period) self.set_constraint("period", constraints.positive)
[docs] def forward(self, X, Z=None, diag=False): if diag: variance = self.get_param("variance") return variance.expand(X.shape[0]) if Z is None: Z = X X = self._slice_input(X) Z = self._slice_input(Z) if X.shape[1] != Z.shape[1]: raise ValueError("Inputs must have the same number of features.") variance = self.get_param("variance") lengthscale = self.get_param("lengthscale") period = self.get_param("period") d = X.unsqueeze(1) - Z.unsqueeze(0) scaled_sin = torch.sin(math.pi * d / period) / lengthscale return variance * torch.exp(-2 * (scaled_sin ** 2).sum(-1))