# Source code for pyro.contrib.gp.kernels.static

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


class Constant(Kernel):
    r"""
    Implementation of the Constant kernel:

        :math:`k(x, z) = \sigma^2.`

    :param int input_dim: Number of feature dimensions of inputs.
    :param torch.Tensor variance: Variance parameter :math:`\sigma^2` of this
        kernel. Defaults to a scalar tensor ``1.``; constrained positive.
    :param list active_dims: List of feature dimensions of the input which the
        kernel acts on.
    :param str name: Name of the kernel.
    """
    # NOTE: the docstring is a raw string so that ``\sigma`` is not parsed as
    # an (invalid) escape sequence.

    def __init__(self, input_dim, variance=None, active_dims=None, name="Constant"):
        super(Constant, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.tensor(1.)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        """
        Computes the covariance matrix of this kernel on ``X`` and ``Z``.

        :param torch.Tensor X: A 2D tensor of inputs (first dim is batch).
        :param torch.Tensor Z: An optional 2D tensor of second inputs;
            defaults to ``X``.
        :param bool diag: If ``True``, only return the diagonal
            (a 1D tensor of length ``X.shape[0]``).
        :returns: the constant covariance, broadcast to the requested shape.
        :rtype: torch.Tensor
        """
        variance = self.get_param("variance")
        if diag:
            return variance.expand(X.shape[0])

        if Z is None:
            Z = X
        # Every entry of the covariance matrix is the same constant variance.
        return variance.expand(X.shape[0], Z.shape[0])
class WhiteNoise(Kernel):
    r"""
    Implementation of the WhiteNoise kernel:

        :math:`k(x, z) = \sigma^2 \delta(x, z),`

    where :math:`\delta` is a Dirac delta function.

    :param int input_dim: Number of feature dimensions of inputs.
    :param torch.Tensor variance: Variance parameter :math:`\sigma^2` of this
        kernel. Defaults to a scalar tensor ``1.``; constrained positive.
    :param list active_dims: List of feature dimensions of the input which the
        kernel acts on.
    :param str name: Name of the kernel.
    """
    # NOTE: the docstring is a raw string so that ``\sigma`` / ``\delta`` are
    # not parsed as (invalid) escape sequences.

    def __init__(self, input_dim, variance=None, active_dims=None, name="WhiteNoise"):
        super(WhiteNoise, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.tensor(1.)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        """
        Computes the covariance matrix of this kernel on ``X`` and ``Z``.

        :param torch.Tensor X: A 2D tensor of inputs (first dim is batch).
        :param torch.Tensor Z: An optional 2D tensor of second inputs. When
            ``Z`` is given it is treated as distinct from ``X``, so the delta
            never fires and the result is all zeros.
        :param bool diag: If ``True``, only return the diagonal
            (a 1D tensor of length ``X.shape[0]``).
        :returns: the white-noise covariance.
        :rtype: torch.Tensor
        """
        variance = self.get_param("variance")
        if diag:
            return variance.expand(X.shape[0])

        if Z is None:
            # k(x, x') is nonzero only on the diagonal (x == x').
            return variance.expand(X.shape[0]).diag()
        else:
            # Was ``X.data.new_zeros(...)``: Tensor.new_zeros already returns
            # a fresh gradient-free tensor, so the deprecated ``.data`` access
            # is unnecessary.
            return X.new_zeros(X.shape[0], Z.shape[0])