Source code for pyro.contrib.gp.likelihoods.multi_class

from __future__ import absolute_import, division, print_function

import torch.nn.functional as F

import pyro
import pyro.distributions as dist

from .likelihood import Likelihood


def _softmax(x):
    return F.softmax(x, dim=-1)


[docs]class MultiClass(Likelihood): """ Implementation of MultiClass likelihood, which is used for multi-class classification problems. MultiClass likelihood uses :class:`~pyro.distributions.Categorical` distribution, so ``response_function`` should normalize its input's rightmost axis. By default, we use `softmax` function. :param int num_classes: Number of classes for prediction. :param callable response_function: A mapping to correct domain for MultiClass likelihood. """ def __init__(self, num_classes, response_function=None, name="MultiClass"): super(MultiClass, self).__init__(name) self.num_classes = num_classes self.response_function = (response_function if response_function is not None else _softmax)
[docs] def forward(self, f_loc, f_var, y=None): r""" Samples :math:`y` given :math:`f_{loc}`, :math:`f_{var}` according to .. math:: f & \sim \mathbb{Normal}(f_{loc}, f_{var}),\\ y & \sim \mathbb{Categorical}(f). .. note:: The log likelihood is estimated using Monte Carlo with 1 sample of :math:`f`. :param torch.Tensor f_loc: Mean of latent function output. :param torch.Tensor f_var: Variance of latent function output. :param torch.Tensor y: Training output tensor. :returns: a tensor sampled from likelihood :rtype: torch.Tensor """ # calculates Monte Carlo estimate for E_q(f) [logp(y | f)] f = dist.Normal(f_loc, f_var)() if f.dim() < 2: raise ValueError("Latent function output should have at least 2 " "dimensions: one for number of classes and one for " "number of data.") # swap class dimension and data dimension f_swap = f.transpose(-2, -1) # -> num_data x num_classes if f_swap.shape[-1] != self.num_classes: raise ValueError("Number of Gaussian processes should be equal to the " "number of classes. Expected {} but got {}." .format(self.num_classes, f_swap.shape[-1])) f_res = self.response_function(f_swap) y_dist = dist.Categorical(f_res) if y is not None: y_dist = y_dist.expand_by(y.shape[:-f_res.dim() + 1]).independent(y.dim()) return pyro.sample(self.y_name, y_dist, obs=y)