Source code for pyro.contrib.gp.kernels.static

import torch
from torch.distributions import constraints

from pyro.contrib.gp.kernels.kernel import Kernel
from pyro.nn.module import PyroParam


class Constant(Kernel):
    r"""
    Constant kernel, whose value does not depend on its inputs:

        :math:`k(x, z) = \sigma^2.`

    :param input_dim: forwarded unchanged to the ``Kernel`` constructor.
    :param variance: initial value of :math:`\sigma^2`. When ``None``, the
        scalar ``torch.tensor(1.)`` is used. It is stored as a ``PyroParam``
        under the ``constraints.positive`` constraint.
    :param active_dims: forwarded unchanged to the ``Kernel`` constructor.
    """

    def __init__(self, input_dim, variance=None, active_dims=None):
        super().__init__(input_dim, active_dims)

        # Fall back to unit variance when the caller does not supply one.
        if variance is None:
            variance = torch.tensor(1.)
        self.variance = PyroParam(variance, constraints.positive)

    def forward(self, X, Z=None, diag=False):
        """
        Evaluate the kernel; only the first-dimension sizes of the inputs
        are used.

        :param X: first input; ``X.size(0)`` gives the number of rows.
        :param Z: optional second input; ``X`` is used when it is ``None``.
        :param diag: if true, return only a vector of length ``X.size(0)``.
        :returns: ``self.variance`` expanded to shape ``(X.size(0),)`` when
            ``diag`` is true, otherwise to ``(X.size(0), Z.size(0))``.
        """
        sigma2 = self.variance
        n_rows = X.size(0)
        if diag:
            return sigma2.expand(n_rows)

        # With no second input the kernel is evaluated on X against itself.
        n_cols = n_rows if Z is None else Z.size(0)
        return sigma2.expand(n_rows, n_cols)
class WhiteNoise(Kernel):
    r"""
    White-noise kernel:

        :math:`k(x, z) = \sigma^2 \delta(x, z),`

    where :math:`\delta` is a Dirac delta function.

    :param input_dim: forwarded unchanged to the ``Kernel`` constructor.
    :param variance: initial value of :math:`\sigma^2`. When ``None``, the
        scalar ``torch.tensor(1.)`` is used. It is stored as a ``PyroParam``
        under the ``constraints.positive`` constraint.
    :param active_dims: forwarded unchanged to the ``Kernel`` constructor.
    """

    def __init__(self, input_dim, variance=None, active_dims=None):
        super().__init__(input_dim, active_dims)

        # Fall back to unit variance when the caller does not supply one.
        if variance is None:
            variance = torch.tensor(1.)
        self.variance = PyroParam(variance, constraints.positive)

    def forward(self, X, Z=None, diag=False):
        """
        Evaluate the kernel; only the first-dimension sizes of the inputs
        are used.

        :param X: first input; ``X.size(0)`` gives the number of rows.
        :param Z: optional second input.
        :param diag: if true, return only a vector of length ``X.size(0)``.
        :returns: with ``diag`` true, ``self.variance`` expanded to shape
            ``(X.size(0),)``; with ``Z`` omitted, the square matrix having
            that vector on its diagonal; otherwise an all-zero tensor of
            shape ``(X.size(0), Z.size(0))`` created from ``X.data``.
        """
        if diag or Z is None:
            diagonal = self.variance.expand(X.size(0))
            return diagonal if diag else diagonal.diag()

        # An explicit Z always yields zeros: the values of X and Z are
        # never compared, only their leading sizes are read.
        return X.data.new_zeros(X.size(0), Z.size(0))