# Source code for pyro.contrib.mue.missingdatahmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

from pyro.distributions import constraints
from pyro.distributions.hmm import _sequential_logmatmulexp
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class MissingDataDiscreteHMM(TorchDistribution):
    """
    HMM with discrete latent states and discrete observations, allowing for
    missing data or variable length sequences. Observations are assumed
    to be one hot encoded; rows with all zeros indicate missing data.

    .. warning:: Unlike in pyro's pyro.distributions.DiscreteHMM, which
        computes the probability of the first state as
        initial.T @ transition @ emission
        this distribution uses the standard HMM convention,
        initial.T @ emission

    :param ~torch.Tensor initial_logits: A logits tensor for an initial
        categorical distribution over latent states. Should have rightmost
        size ``state_dim`` and be broadcastable to
        ``(batch_size, state_dim)``.
    :param ~torch.Tensor transition_logits: A logits tensor for transition
        conditional distributions between latent states. Should have rightmost
        shape ``(state_dim, state_dim)`` (old, new), and be broadcastable
        to ``(batch_size, state_dim, state_dim)``.
    :param ~torch.Tensor observation_logits: A logits tensor for observation
        distributions from latent states. Should have rightmost shape
        ``(state_dim, categorical_size)``, where ``categorical_size`` is the
        dimension of the categorical output, and be broadcastable
        to ``(batch_size, state_dim, categorical_size)``.
    :param validate_args: Passed unchanged to the base distribution
        constructor.
    :raises ValueError: If ``initial_logits`` has fewer than one dimension,
        or ``transition_logits`` / ``observation_logits`` have fewer than two.
    """

    arg_constraints = {
        "initial_logits": constraints.real_vector,
        "transition_logits": constraints.independent(constraints.real, 2),
        "observation_logits": constraints.independent(constraints.real, 2),
    }
    support = constraints.independent(constraints.nonnegative_integer, 2)

    def __init__(
        self, initial_logits, transition_logits, observation_logits, validate_args=None
    ):
        if initial_logits.dim() < 1:
            raise ValueError(
                "expected initial_logits to have at least one dim, "
                "actual shape = {}".format(initial_logits.shape)
            )
        if transition_logits.dim() < 2:
            raise ValueError(
                "expected transition_logits to have at least two dims, "
                "actual shape = {}".format(transition_logits.shape)
            )
        if observation_logits.dim() < 2:
            # Report the shape of the tensor that actually failed the check.
            raise ValueError(
                "expected observation_logits to have at least two dims, "
                "actual shape = {}".format(observation_logits.shape)
            )
        # Batch shape is the broadcast of every leading (non-event) dim.
        shape = broadcast_shape(
            initial_logits.shape[:-1],
            transition_logits.shape[:-2],
            observation_logits.shape[:-2],
        )
        if len(shape) == 0:
            # Always expose at least one batch dimension.
            shape = torch.Size([1])
        batch_shape = shape
        event_shape = (1, observation_logits.shape[-1])
        # Normalize over the last dim so every row holds log-probabilities.
        self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        self.transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True
        )
        self.observation_logits = observation_logits - observation_logits.logsumexp(
            -1, True
        )
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, value):
        """
        :param ~torch.Tensor value: One-hot encoded observation. Must be
            real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
            Missing data is represented by zeros, i.e.
            ``value[batch, step, :] == tensor([0, ..., 0])``.
            Variable length observation sequences can be handled by padding
            the sequence with zeros at the end.
        :return: Log-probability of each sequence. The category, time and
            latent-state dimensions are all reduced away.
        :rtype: ~torch.Tensor
        """
        assert value.shape[-1] == self.event_shape[1]

        # Emission log-likelihood of each step under each latent state.
        # An all-zero (missing/padding) row contributes exactly 0 here.
        value_logits = torch.matmul(
            value, torch.transpose(self.observation_logits, -2, -1)
        )

        # Combine initial factor with the first emission.
        result = self.initial_logits + value_logits[..., 0, :]

        # A single-step sequence has no transitions. Slicing ``1:`` would give
        # an empty time axis that cannot be reduced or broadcast correctly,
        # so the transition term is only added when there are >= 2 steps.
        if value_logits.size(-2) > 1:
            # Combine observation and transition factors.
            transitions = (
                self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]
            )
            # Eliminate time dimension.
            transitions = _sequential_logmatmulexp(transitions)
            # Sum over the final state reached from each starting state.
            # Transition rows are normalized, so trailing zero-padded steps
            # contribute a factor of one.
            result = result + transitions.logsumexp(-1)

        # Marginalize out the remaining latent state.
        return result.logsumexp(-1)