Source code for pyro.contrib.mue.missingdatahmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import Categorical, OneHotCategorical

from pyro.distributions import constraints
from pyro.distributions.hmm import _sequential_logmatmulexp
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class MissingDataDiscreteHMM(TorchDistribution):
    """
    HMM with discrete latent states and discrete observations, allowing for
    missing data or variable length sequences. Observations are assumed
    to be one-hot encoded; rows with all zeros indicate missing data.

    .. warning:: Unlike pyro's pyro.distributions.DiscreteHMM, which
        computes the probability of the first state as
        ``initial.T @ transition @ emission``, this distribution uses the
        standard HMM convention, ``initial.T @ emission``.

    :param ~torch.Tensor initial_logits: A logits tensor for an initial
        categorical distribution over latent states. Should have rightmost
        size ``state_dim`` and be broadcastable to
        ``(batch_size, state_dim)``.
    :param ~torch.Tensor transition_logits: A logits tensor for transition
        conditional distributions between latent states. Should have
        rightmost shape ``(state_dim, state_dim)`` (old, new), and be
        broadcastable to ``(batch_size, state_dim, state_dim)``.
    :param ~torch.Tensor observation_logits: A logits tensor for observation
        distributions from latent states. Should have rightmost shape
        ``(state_dim, categorical_size)``, where ``categorical_size`` is the
        dimension of the categorical output, and be broadcastable
        to ``(batch_size, state_dim, categorical_size)``.

    Illustrative usage sketches for this class appear at the end of this
    module.
    """

    arg_constraints = {
        "initial_logits": constraints.real_vector,
        "transition_logits": constraints.independent(constraints.real, 2),
        "observation_logits": constraints.independent(constraints.real, 2),
    }
    support = constraints.independent(constraints.nonnegative_integer, 2)

    def __init__(
        self, initial_logits, transition_logits, observation_logits, validate_args=None
    ):
        if initial_logits.dim() < 1:
            raise ValueError(
                "expected initial_logits to have at least one dim, "
                "actual shape = {}".format(initial_logits.shape)
            )
        if transition_logits.dim() < 2:
            raise ValueError(
                "expected transition_logits to have at least two dims, "
                "actual shape = {}".format(transition_logits.shape)
            )
        if observation_logits.dim() < 2:
            raise ValueError(
                "expected observation_logits to have at least two dims, "
                "actual shape = {}".format(observation_logits.shape)
            )
        shape = broadcast_shape(
            initial_logits.shape[:-1],
            transition_logits.shape[:-2],
            observation_logits.shape[:-2],
        )
        if len(shape) == 0:
            shape = torch.Size([1])
        batch_shape = shape
        event_shape = (1, observation_logits.shape[-1])
        # Normalize the logits so each factor is a proper log-probability table.
        self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        self.transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True
        )
        self.observation_logits = observation_logits - observation_logits.logsumexp(
            -1, True
        )
        super(MissingDataDiscreteHMM, self).__init__(
            batch_shape, event_shape, validate_args=validate_args
        )
    def log_prob(self, value):
        """
        :param ~torch.Tensor value: One-hot encoded observation. Must be
            real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
            Missing data is represented by zeros, i.e.
            ``value[batch, step, :] == tensor([0, ..., 0])``. Variable length
            observation sequences can be handled by padding the sequence with
            zeros at the end.
        """
        assert value.shape[-1] == self.event_shape[1]

        # Combine observation and transition factors.
        value_logits = torch.matmul(
            value, torch.transpose(self.observation_logits, -2, -1)
        )
        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

        # Eliminate time dimension.
        result = _sequential_logmatmulexp(result)

        # Combine initial factor.
        result = self.initial_logits + value_logits[..., 0, :] + result.logsumexp(-1)

        # Marginalize out final state.
        result = result.logsumexp(-1)
        return result
    def sample(self, sample_shape=torch.Size([])):
        """
        :param ~torch.Size sample_shape: Sample shape; the last dimension must
            be ``num_steps`` and the shape must be broadcastable to
            ``(batch_size, num_steps)``. ``batch_size`` must be an int, not a
            tuple.
        """
        # shape: batch_size x num_steps x categorical_size
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(sample_shape) + [1]),
            torch.Size((1, 1, self.event_shape[-1])),
        )
        # state: batch_size x state_dim
        state = OneHotCategorical(logits=self.initial_logits).sample()
        # sample: batch_size x num_steps x categorical_size
        sample = torch.zeros(shape)
        for i in range(shape[-2]):
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x categorical_size
            obs_logits = torch.matmul(
                state.unsqueeze(-2), self.observation_logits
            ).squeeze(-2)
            sample[:, i, :] = OneHotCategorical(logits=obs_logits).sample()
            # batch_size x 1 x state_dim @
            # batch_size x state_dim x state_dim
            trans_logits = torch.matmul(
                state.unsqueeze(-2), self.transition_logits
            ).squeeze(-2)
            state = OneHotCategorical(logits=trans_logits).sample()
        return sample
    def filter(self, value):
        """
        Compute the marginal probability of the state variable at each step,
        conditional on the observations up to and including that step
        (forward filtering).

        :param ~torch.Tensor value: One-hot encoded observation.
            Must be real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
        """
        # batch_size x num_steps x state_dim
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(value.shape[:-1]) + [1]),
            torch.Size((1, 1, self.initial_logits.shape[-1])),
        )
        filter = torch.zeros(shape)

        # Combine observation and transition factors.
        # batch_size x num_steps x state_dim
        value_logits = torch.matmul(
            value, torch.transpose(self.observation_logits, -2, -1)
        )
        # batch_size x num_steps-1 x state_dim x state_dim
        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

        # Forward pass. (This could be parallelized using the
        # Sarkka & Garcia-Fernandez method.)
        filter[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
        filter[..., 0, :] = filter[..., 0, :] - torch.logsumexp(
            filter[..., 0, :], -1, True
        )
        for i in range(1, shape[-2]):
            filter[..., i, :] = torch.logsumexp(
                filter[..., i - 1, :, None] + result[..., i - 1, :, :], -2
            )
            filter[..., i, :] = filter[..., i, :] - torch.logsumexp(
                filter[..., i, :], -1, True
            )
        return filter
    def smooth(self, value):
        """
        Compute the posterior marginal distribution of the state at each
        position (smoothing), returned as normalized log probabilities.

        :param ~torch.Tensor value: One-hot encoded observation.
            Must be real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
        """
        # Compute filter and initialize.
        filter = self.filter(value)
        shape = filter.shape
        backfilter = torch.zeros(shape)

        # Combine observation and transition factors.
        # batch_size x num_steps x state_dim
        value_logits = torch.matmul(
            value, torch.transpose(self.observation_logits, -2, -1)
        )
        # batch_size x num_steps-1 x state_dim x state_dim
        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

        # Construct backwards filter.
        for i in range(shape[-2] - 1, 0, -1):
            backfilter[..., i - 1, :] = torch.logsumexp(
                backfilter[..., i, None, :] + result[..., i - 1, :, :], -1
            )

        # Compute smoothed version.
        smooth = filter + backfilter
        smooth = smooth - torch.logsumexp(smooth, -1, True)
        return smooth
    def sample_states(self, value):
        """
        Sample states with the forward filtering-backward sampling algorithm.

        :param ~torch.Tensor value: One-hot encoded observation.
            Must be real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
        """
        filter = self.filter(value)
        shape = filter.shape
        joint = filter.unsqueeze(-1) + self.transition_logits.unsqueeze(-3)
        states = torch.zeros(shape[:-1], dtype=torch.long)
        states[..., -1] = Categorical(logits=filter[..., -1, :]).sample()
        for i in range(shape[-2] - 1, 0, -1):
            logits = torch.gather(
                joint[..., i - 1, :, :],
                -1,
                states[..., i, None, None]
                * torch.ones([shape[-1], 1], dtype=torch.long),
            ).squeeze(-1)
            states[..., i - 1] = Categorical(logits=logits).sample()
        return states
    def map_states(self, value):
        """
        Compute the maximum a posteriori (MAP) estimate of the state variable
        with the Viterbi algorithm.

        :param ~torch.Tensor value: One-hot encoded observation.
            Must be real-valued (float) and broadcastable to
            ``(batch_size, num_steps, categorical_size)`` where
            ``categorical_size`` is the dimension of the categorical output.
        """
        # Setup for Viterbi.
        # batch_size x num_steps x state_dim
        shape = broadcast_shape(
            torch.Size(list(self.batch_shape) + [1, 1]),
            torch.Size(list(value.shape[:-1]) + [1]),
            torch.Size((1, 1, self.initial_logits.shape[-1])),
        )
        state_logits = torch.zeros(shape)
        state_traceback = torch.zeros(shape, dtype=torch.long)

        # Combine observation and transition factors.
        # batch_size x num_steps x state_dim
        value_logits = torch.matmul(
            value, torch.transpose(self.observation_logits, -2, -1)
        )
        # batch_size x num_steps-1 x state_dim x state_dim
        result = self.transition_logits.unsqueeze(-3) + value_logits[..., 1:, None, :]

        # Forward pass.
        state_logits[..., 0, :] = self.initial_logits + value_logits[..., 0, :]
        for i in range(1, shape[-2]):
            transit_weights = (
                state_logits[..., i - 1, :, None] + result[..., i - 1, :, :]
            )
            state_logits[..., i, :], state_traceback[..., i, :] = torch.max(
                transit_weights, -2
            )
        # Traceback.
        map_states = torch.zeros(shape[:-1], dtype=torch.long)
        map_states[..., -1] = torch.argmax(state_logits[..., -1, :], -1)
        for i in range(shape[-2] - 1, 0, -1):
            map_states[..., i - 1] = torch.gather(
                state_traceback[..., i, :], -1, map_states[..., i].unsqueeze(-1)
            ).squeeze(-1)
        return map_states
    def given_states(self, states):
        """
        Distribution over observations, conditional on the state variable.

        :param ~torch.Tensor states: State trajectory. Must be
            integer-valued (long) and broadcastable to
            ``(batch_size, num_steps)``.
        """
        shape = broadcast_shape(
            list(self.batch_shape) + [1, 1],
            list(states.shape[:-1]) + [1, 1],
            [1, 1, self.observation_logits.shape[-1]],
        )
        states_index = states.unsqueeze(-1) * torch.ones(shape, dtype=torch.long)
        obs_logits = self.observation_logits * torch.ones(shape)
        logits = torch.gather(obs_logits, -2, states_index)
        return OneHotCategorical(logits=logits)
    def sample_given_states(self, states):
        """
        Sample an observation conditional on the state variable.

        :param ~torch.Tensor states: State trajectory. Must be
            integer-valued (long) and broadcastable to
            ``(batch_size, num_steps)``.
        """
        conditional = self.given_states(states)
        return conditional.sample()
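

# ---------------------------------------------------------------------------
# The three functions below are illustrative usage sketches added to this
# listing; they are not part of the original module. Function names, shapes,
# and logits values are arbitrary choices made for demonstration only.
# ---------------------------------------------------------------------------


def _example_log_prob_with_missing_data():
    """Sketch: construct a MissingDataDiscreteHMM and score a one-hot
    observation sequence in which one row is all zeros (missing data)."""
    state_dim, categorical_size, num_steps = 3, 4, 5
    hmm = MissingDataDiscreteHMM(
        initial_logits=torch.randn(state_dim),
        transition_logits=torch.randn(state_dim, state_dim),
        observation_logits=torch.randn(state_dim, categorical_size),
    )
    # One-hot observations: (batch_size, num_steps, categorical_size).
    value = torch.zeros(1, num_steps, categorical_size)
    obs = torch.randint(categorical_size, (num_steps,))
    value[0, torch.arange(num_steps), obs] = 1.0
    # Mark the third step as missing by zeroing its row; its emission term is
    # then zero in log space for every state, so that observation is
    # effectively marginalized out of the likelihood.
    value[0, 2, :] = 0.0
    return hmm.log_prob(value)  # shape: (1,)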
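

# Sketch (not part of the original module): run the forward filter, the
# smoother, forward filtering-backward sampling, and Viterbi decoding on the
# same observation sequence, noting the returned shapes.
def _example_state_inference():
    state_dim, categorical_size, num_steps = 3, 4, 6
    hmm = MissingDataDiscreteHMM(
        initial_logits=torch.randn(state_dim),
        transition_logits=torch.randn(state_dim, state_dim),
        observation_logits=torch.randn(state_dim, categorical_size),
    )
    # Uniform random one-hot observations: (1, num_steps, categorical_size).
    value = OneHotCategorical(logits=torch.zeros(categorical_size)).sample(
        (1, num_steps)
    )
    filt = hmm.filter(value)  # log p(state_t | obs[0..t]): (1, num_steps, state_dim)
    smoothed = hmm.smooth(value)  # log p(state_t | all obs): (1, num_steps, state_dim)
    sampled = hmm.sample_states(value)  # one posterior trajectory: (1, num_steps), long
    viterbi = hmm.map_states(value)  # MAP trajectory: (1, num_steps), long
    return filt, smoothed, sampled, viterbi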
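

# Sketch (not part of the original module): draw a sequence from the joint
# model, decode a state trajectory, and then sample emissions conditioned on
# that trajectory via given_states / sample_given_states.
def _example_sampling_and_conditioning():
    state_dim, categorical_size, num_steps = 3, 4, 6
    hmm = MissingDataDiscreteHMM(
        initial_logits=torch.randn(state_dim),
        transition_logits=torch.randn(state_dim, state_dim),
        observation_logits=torch.randn(state_dim, categorical_size),
    )
    # Joint sample of one-hot observations: (1, num_steps, categorical_size).
    value = hmm.sample(torch.Size([num_steps]))
    states = hmm.map_states(value)  # (1, num_steps), long
    conditional = hmm.given_states(states)  # OneHotCategorical over emissions
    resampled = hmm.sample_given_states(states)  # (1, num_steps, categorical_size)
    return value, conditional, resampled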