Source code for pyro.distributions.ordered_logistic

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from pyro.distributions import constraints
from pyro.distributions.torch import Categorical


class OrderedLogistic(Categorical):
    """
    Categorical distribution parametrized for ordered outcomes.

    Rather than taking the probability mass ``p`` of each category directly,
    this distribution is specified by an ordered vector of ``cutpoints``
    (the baseline cumulative log-odds of the categories) together with a
    ``predictor`` that shifts those baselines for every sample individually.

    With ``K`` cutpoints ``c_1, ..., c_K`` along the last dimension and a
    predictor value ``x``, the cumulative probabilities are
    ``q_k = sigmoid(c_k - x)``. Differencing them gives the ``K + 1``
    category probabilities ``q_1``, ``q_k - q_{k-1}`` (for ``k = 2..K``) and
    ``1 - q_K``, which are handed to the ``Categorical`` constructor.

    :param Tensor predictor: A tensor of predictor variables of arbitrary
        shape. Non-batched samples from this distribution have the same
        shape as ``predictor``.
    :param Tensor cutpoints: A tensor of cutpoints used to determine the
        cumulative probability of each entry in ``predictor`` belonging to
        a given category. Its first ``cutpoints.ndim - 1`` dimensions must
        be broadcastable to ``predictor``, and its last dimension must be
        monotonically increasing.
    :param validate_args: Passed through unchanged to the ``Categorical``
        constructor.
    """

    arg_constraints = {
        "predictor": constraints.real,
        "cutpoints": constraints.ordered_vector,
    }

    def __init__(self, predictor, cutpoints, validate_args=None):
        # Cumulative probability of every cutpoint for every predictor entry;
        # the trailing singleton axis lets predictor broadcast over cutpoints.
        cumulative = torch.sigmoid(cutpoints - predictor.unsqueeze(-1))
        batch_dims = cumulative.shape[:-1]
        num_cutpoints = cumulative.shape[-1]

        # Store both parameters expanded to the broadcast batch shape.
        self.predictor = predictor.expand(batch_dims)
        self.cutpoints = cutpoints.expand(cumulative.shape)

        # Difference the cumulative probabilities into per-category mass;
        # K cutpoints delimit K + 1 categories.
        probs = torch.zeros(
            batch_dims + (num_cutpoints + 1,),
            dtype=cumulative.dtype,
            device=cumulative.device,
        )
        lower = cumulative[..., :-1]
        upper = cumulative[..., 1:]
        probs[..., 0] = cumulative[..., 0]
        probs[..., 1:-1] = upper - lower
        probs[..., -1] = 1 - cumulative[..., -1]

        # The parent class takes over from the probability mass.
        super().__init__(probs, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy of this distribution expanded to ``batch_shape``.

        ``predictor`` and ``cutpoints`` are expanded here; the remaining
        state is expanded by the parent class's ``expand``.

        :param batch_shape: The desired batch shape.
        :param _instance: Optional pre-allocated instance to fill in, as
            used by subclasses that override ``expand``.
        """
        target_shape = torch.Size(batch_shape)
        expanded = self._get_checked_instance(OrderedLogistic, _instance)
        # cutpoints keeps its trailing (per-category) dimension.
        cutpoint_shape = target_shape + (self.cutpoints.shape[-1],)
        expanded.predictor = self.predictor.expand(target_shape)
        expanded.cutpoints = self.cutpoints.expand(cutpoint_shape)
        return super().expand(target_shape, expanded)