# Source code for pyro.distributions.hmm

import torch
from torch.distributions import constraints

from pyro.distributions.torch import Categorical, MultivariateNormal
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape
from pyro.ops.gaussian import Gaussian, gaussian_tensordot, matrix_and_mvn_to_gaussian, mvn_to_gaussian

def _logmatmulexp(x, y):
"""
Numerically stable version of (x.log() @ y.log()).exp().
"""
x_shift = x.max(-1, keepdim=True)
y_shift = y.max(-2, keepdim=True)
xy = torch.matmul((x - x_shift).exp(), (y - y_shift).exp()).log()
return xy + x_shift + y_shift

def _sequential_logmatmulexp(logits):
"""
For a tensor x whose time dimension is -3, computes::

x[..., 0, :, :] @ x[..., 1, :, :] @ ... @ x[..., T-1, :, :]

but does so numerically stably in log space.
"""
batch_shape = logits.shape[:-3]
state_dim = logits.size(-1)
while logits.size(-3) > 1:
time = logits.size(-3)
even_time = time // 2 * 2
even_part = logits[..., :even_time, :, :]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2, state_dim, state_dim))
x, y = x_y.unbind(-3)
contracted = _logmatmulexp(x, y)
if time > even_time:
contracted = torch.cat((contracted, logits[..., -1:, :, :]), dim=-3)
logits = contracted
return logits.squeeze(-3)

def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computes::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    The contraction is a tree reduction: every pass pairs adjacent time steps
    and integrates out the shared half of their event dimension with
    :func:`gaussian_tensordot`, roughly halving the time dimension. When the
    current length is odd, the last step is carried into the next pass
    unchanged.

    :param gaussian: A ``Gaussian`` whose event dimension is even; it is split
        into two halves of size ``state_dim = gaussian.dim() // 2``.
    :return: A ``Gaussian`` with the time batch dimension removed.
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        num_steps = gaussian.batch_shape[-1]
        num_pairs = num_steps // 2
        pairs = gaussian[..., :2 * num_pairs].reshape(batch_shape + (num_pairs, 2))
        reduced = gaussian_tensordot(pairs[..., 0], pairs[..., 1], state_dim)
        if num_steps % 2:
            # Odd length: carry the unpaired final step forward.
            reduced = Gaussian.cat((reduced, gaussian[..., -1:]), dim=-1)
        gaussian = reduced
    return gaussian[..., 0]

class DiscreteHMM(TorchDistribution):
    """
    Hidden Markov Model with discrete latent state and arbitrary observation
    distribution. This uses [1] to parallelize over time, achieving
    O(log(time)) parallel complexity.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_logits`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :param ~torch.Tensor initial_logits: A logits tensor for an initial
        categorical distribution over latent states. Should have rightmost size
        ``state_dim`` and be broadcastable to ``batch_shape + (state_dim,)``.
    :param ~torch.Tensor transition_logits: A logits tensor for transition
        conditional distributions between latent states. Should have rightmost
        shape ``(state_dim, state_dim)`` (old, new), and be broadcastable to
        ``batch_shape + (num_steps, state_dim, state_dim)``.
    :param ~torch.distributions.Distribution observation_dist: A conditional
        distribution of observed data conditioned on latent state. The
        ``.batch_shape`` should have rightmost size ``state_dim`` and be
        broadcastable to ``batch_shape + (num_steps, state_dim)``. The
        ``.event_shape`` may be arbitrary.
    """
    arg_constraints = {"initial_logits": constraints.real,
                       "transition_logits": constraints.real}

    def __init__(self, initial_logits, transition_logits, observation_dist, validate_args=None):
        if initial_logits.dim() < 1:
            raise ValueError("expected initial_logits to have at least one dim, "
                             "actual shape = {}".format(initial_logits.shape))
        if transition_logits.dim() < 2:
            raise ValueError("expected transition_logits to have at least two dims, "
                             "actual shape = {}".format(transition_logits.shape))
        if len(observation_dist.batch_shape) < 1:
            raise ValueError("expected observation_dist to have at least one batch dim, "
                             "actual .batch_shape = {}".format(observation_dist.batch_shape))
        # initial_logits has no time dim, so a (1,) placeholder is appended before
        # broadcasting; the rightmost broadcast dim is then time.
        shape = broadcast_shape(initial_logits.shape[:-1] + (1,),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        # Normalize along the rightmost (new-state) dim.
        self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True)
        self.observation_dist = observation_dist
        super(DiscreteHMM, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new instance whose batch_shape is ``self.batch_shape``
        broadcast with ``batch_shape``.
        """
        new = self._get_checked_instance(DiscreteHMM, _instance)
        batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        # We only need to expand one of the inputs, since batch_shape is determined
        # by broadcasting all three. To save computation in _sequential_logmatmulexp(),
        # we expand only initial_logits, which is applied only after the logmatmulexp.
        # This is similar to the ._unbroadcasted_* pattern used elsewhere in distributions.
        new.initial_logits = self.initial_logits.expand(batch_shape + (-1,))
        new.transition_logits = self.transition_logits
        new.observation_dist = self.observation_dist
        super(DiscreteHMM, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self.__dict__.get('_validate_args')
        return new

    def _contract_time(self, value):
        """
        Combine observation and transition factors for ``value`` and eliminate
        the time dimension. This is shared by :meth:`log_prob` and :meth:`filter`.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A log-space tensor whose two rightmost dims are indexed by
            (state before the first step, state at the final step).
        :rtype: ~torch.Tensor
        """
        # Insert a singleton dim so observation_dist scores each observation
        # under every latent state.
        value = value.unsqueeze(-1 - self.observation_dist.event_dim)
        observation_logits = self.observation_dist.log_prob(value)
        # Observations attach to the *new* state, i.e. the rightmost dim.
        logp = self.transition_logits + observation_logits.unsqueeze(-2)
        return _sequential_logmatmulexp(logp)

    def log_prob(self, value):
        """
        Compute the log marginal likelihood of a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :rtype: ~torch.Tensor
        """
        logp = self._contract_time(value)
        # Marginalize out the final state, then combine the initial factor.
        logp = self.initial_logits + logp.logsumexp(-1)
        # Marginalize out the initial state.
        return logp.logsumexp(-1)

    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final
            time step. ``result.logits`` can then be used as
            ``initial_logits`` in a sequential Pyro model for prediction.
        :rtype: ~pyro.distributions.Categorical
        """
        logp = self._contract_time(value)
        # Combine the initial factor and marginalize out the initial state.
        logp = (self.initial_logits.unsqueeze(-1) + logp).logsumexp(-2)
        # Convert to a distribution (Categorical normalizes the logits).
        return Categorical(logits=logp, validate_args=self._validate_args)
class GaussianHMM(TorchDistribution):
    """
    Hidden Markov Model with Gaussians for initial, transition, and observation
    distributions. This adapts [1] to parallelize over time to achieve
    O(log(time)) parallel complexity, however it differs in that it tracks the
    log normalizer to ensure :meth:`log_prob` is differentiable.

    This corresponds to the generative model::

        z = initial_dist.sample()
        x = []
        for t in range(num_steps):
            z = z @ transition_matrix + transition_dist.sample()
            x.append(z @ observation_matrix + observation_dist.sample())

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``. This should have event_shape ``(hidden_dim,)``.
    :param ~torch.Tensor transition_matrix: A linear transformation of hidden
        state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, hidden_dim)`` where the
        rightmost dims are ordered ``(old, new)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A process
        noise distribution. This should have batch_shape broadcastable to
        ``self.batch_shape + (num_steps,)``. This should have event_shape
        ``(hidden_dim,)``.
    :param ~torch.Tensor observation_matrix: A linear transformation from hidden
        to observed state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, obs_dim)``.
    :param observation_dist: An observation noise distribution. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(obs_dim,)``.
    :type observation_dist: ~torch.distributions.MultivariateNormal or
        ~torch.distributions.Independent of ~torch.distributions.Normal
    """
    arg_constraints = {}

    def __init__(self, initial_dist, transition_matrix, transition_dist,
                 observation_matrix, observation_dist, validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert (isinstance(observation_dist, torch.distributions.MultivariateNormal) or
                (isinstance(observation_dist, torch.distributions.Independent) and
                 isinstance(observation_dist.base_dist, torch.distributions.Normal)))
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert initial_dist.event_shape == (hidden_dim,)
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim,)
        assert observation_dist.event_shape == (obs_dim,)
        # initial_dist has no time dim, so a (1,) placeholder is appended before
        # broadcasting; the rightmost broadcast dim is then time.
        shape = broadcast_shape(initial_dist.batch_shape + (1,),
                                transition_matrix.shape[:-2],
                                transition_dist.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)
        super(GaussianHMM, self).__init__(batch_shape, event_shape, validate_args=validate_args)

        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = mvn_to_gaussian(initial_dist)
        self._trans = matrix_and_mvn_to_gaussian(transition_matrix, transition_dist)
        self._obs = matrix_and_mvn_to_gaussian(observation_matrix, observation_dist)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new instance whose batch_shape is ``self.batch_shape``
        broadcast with ``batch_shape``.
        """
        new = self._get_checked_instance(GaussianHMM, _instance)
        batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        new.hidden_dim = self.hidden_dim
        new.obs_dim = self.obs_dim
        # We only need to expand one of the inputs, since batch_shape is determined
        # by broadcasting all three factors. To save computation in
        # _sequential_gaussian_tensordot(), we expand only _init, which is applied
        # only after _sequential_gaussian_tensordot().
        new._init = self._init.expand(batch_shape)
        new._trans = self._trans
        new._obs = self._obs
        super(GaussianHMM, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self.__dict__.get('_validate_args')
        return new

    def _contract_time(self, value):
        """
        Condition on ``value``, eliminate the time dimension, and combine the
        initial factor. This is shared by :meth:`log_prob` and :meth:`filter`.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A Gaussian over the final hidden state whose
            ``event_logsumexp()`` is the log-likelihood of ``value``.
        """
        # Combine observation and transition factors.
        logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
        # Eliminate time dimension. NOTE(review): expanding to its own batch_shape
        # presumably materializes a consistent broadcast shape so the reshape in
        # _sequential_gaussian_tensordot() succeeds — confirm against Gaussian.expand.
        logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
        # Combine initial factor.
        return gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

    def log_prob(self, value):
        """
        Compute the log marginal likelihood of a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :rtype: ~torch.Tensor
        """
        # Marginalize out the final state.
        return self._contract_time(value).event_logsumexp()

    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final
            time step. ``result`` can then be used as ``initial_dist`` in a
            sequential Pyro model for prediction.
        :rtype: ~pyro.distributions.MultivariateNormal
        """
        logp = self._contract_time(value)
        # Convert to a distribution: loc = precision^{-1} @ info_vec, solved
        # via the Cholesky factor of the precision matrix.
        precision = logp.precision
        loc = logp.info_vec.unsqueeze(-1).cholesky_solve(precision.cholesky()).squeeze(-1)
        return MultivariateNormal(loc, precision_matrix=precision,
                                  validate_args=self._validate_args)
class GaussianMRF(TorchDistribution):
    """
    Temporal Markov Random Field with Gaussian factors for initial, transition,
    and observation distributions. This adapts [1] to parallelize over time to
    achieve O(log(time)) parallel complexity, however it differs in that it
    tracks the log normalizer to ensure :meth:`log_prob` is differentiable.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``. This should have event_shape ``(hidden_dim,)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A joint
        distribution factor over a pair of successive time steps. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + hidden_dim,)`` (old+new).
    :param ~torch.distributions.MultivariateNormal observation_dist: A joint
        distribution factor over a hidden and an observed state. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + obs_dim,)``.
    """
    arg_constraints = {}

    def __init__(self, initial_dist, transition_dist, observation_dist, validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
        # event_shape is a torch.Size; take the integer size of the event dim.
        # (Using the Size itself would make the assert compare Size concatenation
        # and make obs_dim a Size - Size TypeError.)
        hidden_dim = initial_dist.event_shape[-1]
        assert transition_dist.event_shape[-1] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[-1] - hidden_dim
        # initial_dist has no time dim, so a (1,) placeholder is appended before
        # broadcasting; the rightmost broadcast dim is then time.
        shape = broadcast_shape(initial_dist.batch_shape + (1,),
                                transition_dist.batch_shape,
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)
        super(GaussianMRF, self).__init__(batch_shape, event_shape, validate_args=validate_args)

        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = mvn_to_gaussian(initial_dist)
        self._trans = mvn_to_gaussian(transition_dist)
        self._obs = mvn_to_gaussian(observation_dist)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new instance whose batch_shape is ``self.batch_shape``
        broadcast with ``batch_shape``.
        """
        new = self._get_checked_instance(GaussianMRF, _instance)
        batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        new.hidden_dim = self.hidden_dim
        new.obs_dim = self.obs_dim
        # We only need to expand one of the inputs, since batch_shape is determined
        # by broadcasting all three. To save computation in _sequential_gaussian_tensordot(),
        # we expand only _init, which is applied only after _sequential_gaussian_tensordot().
        new._init = self._init.expand(batch_shape)
        new._trans = self._trans
        new._obs = self._obs
        super(GaussianMRF, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self.__dict__.get('_validate_args')
        return new

    def log_prob(self, value):
        """
        Compute ``log(p(obs, hidden) / p(hidden))`` for a sequence of
        observations, with hidden states integrated out.

        :param ~torch.Tensor value: A sequence of observations.
        :rtype: ~torch.Tensor
        """
        # We compute a normalized distribution as p(obs,hidden) / p(hidden).
        # Build fresh sums rather than using += on two aliases of self._trans, so
        # the stored transition factor can never be mutated by an in-place add.
        logp_oh = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
        logp_h = self._trans + self._obs.marginalize(right=self.obs_dim).event_pad(
            left=self.hidden_dim)

        # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian. Leading
        # singleton dims are padded so there is a free leading batch dim to
        # concatenate along (presumably Gaussian.cat's default dim=0), with enough
        # dims to also broadcast against self._init once time is eliminated.
        batch_dim = 1 + max(len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
        batch_shape = (1,) * (batch_dim - len(logp_oh.batch_shape)) + logp_oh.batch_shape
        logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

        # Eliminate time dimension.
        logp = _sequential_gaussian_tensordot(logp)

        # Combine initial factor.
        logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Marginalize out final state; unpack the two concatenated factors.
        logp_oh, logp_h = logp.event_logsumexp()
        return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )