# Source code for pyro.distributions.hmm

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F

from pyro.ops.gamma_gaussian import (
GammaGaussian,
gamma_and_mvn_to_gamma_gaussian,
gamma_gaussian_tensordot,
matrix_and_mvn_to_gamma_gaussian,
)
from pyro.ops.gaussian import (
Gaussian,
gaussian_tensordot,
matrix_and_mvn_to_gaussian,
mvn_to_gaussian,
sequential_gaussian_filter_sample,
sequential_gaussian_tensordot,
)
from pyro.ops.indexing import Vindex
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky

from . import constraints
from .torch import Categorical, Gamma, Independent, MultivariateNormal
from .torch_distribution import TorchDistribution
from .util import broadcast_shape, torch_jit_script_if_tracing

@torch_jit_script_if_tracing
def _linear_integrate(init, trans, shift):
    """
    Integrate the inhomogeneous linear difference equation::

        x[0] = init
        x[t] = x[t-1] @ trans[t] + shift[t]

    :param init: Initial state; a time dimension of size one is inserted at
        position ``-2`` before stepping.
    :param trans: Transition matrices, indexed along dimension ``-3`` by time.
    :param shift: Additive offsets, indexed along dimension ``-2`` by time.
    :return: The state after each of the ``trans.size(-3)`` updates,
        concatenated along dimension ``-2``. ``init`` itself is not part of
        the result.
    """
    # states: List[Tensor]
    states = []
    # Give the running state a singleton time dimension so that the
    # per-step results can be concatenated along it at the end.
    state = init.unsqueeze(-2)
    # Mirror that singleton dimension on the offsets so that slicing out one
    # time step leaves a tensor aligned with ``state``.
    offsets = shift.unsqueeze(-3)
    num_steps = trans.size(-3)
    for t in range(num_steps):
        state = state @ trans[..., t, :, :] + offsets[..., t, :]
        states.append(state)
    return torch.cat(states, dim=-2)

def _logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.

    The row-wise maximum of ``x`` and the column-wise maximum of ``y`` are
    subtracted before exponentiating and added back after taking the log.
    The maxima are detached, so they do not contribute to gradients.

    :param torch.Tensor x: Left operand in log space.
    :param torch.Tensor y: Right operand in log space.
    :return: The log-space matrix product of ``x`` and ``y``.
    :rtype: torch.Tensor
    """
    # Clamp the shifts to the most negative finite value so that a row or
    # column consisting only of -inf does not yield nan via (-inf) - (-inf).
    lowest = torch.finfo(x.dtype).min
    row_max = torch.clamp(x.detach().max(-1, keepdim=True).values, min=lowest)
    col_max = torch.clamp(y.detach().max(-2, keepdim=True).values, min=lowest)
    product = torch.matmul(torch.exp(x - row_max), torch.exp(y - col_max))
    return safe_log(product) + row_max + col_max

# TODO re-enable jitting once _SafeLog is supported by the jit.
# See https://discuss.pytorch.org/t/does-torch-jit-script-support-custom-operators/65759/4
# @torch_jit_script_if_tracing
def _sequential_logmatmulexp(logits):
    """
    For a tensor ``x`` whose time dimension is -3, computes::

        x[..., 0, :, :] @ x[..., 1, :, :] @ ... @ x[..., T-1, :, :]

    but does so numerically stably in log space.

    Adjacent time steps are contracted pairwise, roughly halving the length
    of the time dimension on every pass until a single factor remains.

    :param torch.Tensor logits: Log-space factors with time at dimension -3
        and square matrices in the two rightmost dimensions.
    :return: The contracted factor with the time dimension squeezed out.
    :rtype: torch.Tensor
    """
    leading_shape = logits.shape[:-3]
    num_states = logits.size(-1)
    while logits.size(-3) > 1:
        num_steps = logits.size(-3)
        num_pairs = num_steps // 2
        # Group the first 2 * num_pairs steps into (left, right) neighbours.
        paired = logits[..., : 2 * num_pairs, :, :].reshape(
            leading_shape + (num_pairs, 2, num_states, num_states)
        )
        left, right = paired.unbind(-3)
        merged = _logmatmulexp(left, right)
        if num_steps % 2 == 1:
            # An unpaired trailing step is carried over unchanged.
            merged = torch.cat((merged, logits[..., -1:, :, :]), dim=-3)
        logits = merged
    return logits.squeeze(-3)

def _markov_index(x, y):
    """
    Join ends of two Markov paths.

    For every path in ``x``, the state reached at its last step (dimension
    -2) selects, along the rightmost dimension of ``y``, the path in ``y``
    that continues it. The selected continuation is appended to ``x`` along
    dimension -2.

    :param torch.Tensor x: Integer paths covering the earlier time span.
    :param torch.Tensor y: Integer paths covering the later time span.
    :return: The concatenated paths.
    :rtype: torch.Tensor
    """
    final_states = x[..., -1:, :]
    continuation = Vindex(y.unsqueeze(-2))[..., final_states]
    return torch.cat((x, continuation), dim=-2)

def _sequential_index(samples):
    r"""
    For a tensor ``samples`` whose time dimension is -2 and state dimension
    is -1, compute Markov paths by sequential indexing.

    ``samples[..., t, s]`` is read as the state entered at step ``t`` when the
    previous state was ``s``. Column ``s`` of the result is the path obtained
    by starting from state ``s`` and following those transitions.

    For example, for ``samples`` with 3 states and time duration 5::

        tensor([[0, 1, 1],
                [1, 0, 2],
                [2, 1, 0],
                [0, 2, 1],
                [1, 1, 0]])

    computed paths are::

        tensor([[0, 1, 1],
                [1, 0, 0],
                [1, 2, 2],
                [2, 1, 1],
                [0, 1, 1]])

    The path for the 0th state reads entries at columns 0, 0, 1, 1, 2 of
    successive rows, giving ``0, 1, 1, 2, 0``. The paths for the 1st and 2nd
    states both start with ``1`` and therefore coincide from then on:
    ``1, 0, 2, 1, 1``.

    :param torch.Tensor samples: Integer transition samples.
    :return: Paths with the same time and state layout as ``samples``.
    :rtype: torch.Tensor
    """
    # Insert a new Markov-path dimension at -2; time moves to -3.
    samples = samples.unsqueeze(-2)
    duration = samples.size(-3)
    num_states = samples.size(-1)
    leading_shape = samples.shape[:-3]
    while samples.size(-3) > 1:
        num_steps = samples.size(-3)
        num_pairs = num_steps // 2
        # Group the first 2 * num_pairs segments into (head, tail) neighbours.
        paired = samples[..., : 2 * num_pairs, :, :].reshape(
            leading_shape + (num_pairs, 2, -1, num_states)
        )
        head, tail = paired.unbind(-3)
        joined = _markov_index(head, tail)
        if num_steps % 2 == 1:
            # The unpaired trailing segment is half as long as the joined
            # ones; zero-pad its path dimension so it can be concatenated.
            leftover = F.pad(
                samples[..., -1:, :, :], (0, 0, 0, joined.size(-2) // 2)
            )
            joined = torch.cat((joined, leftover), dim=-3)
        samples = joined
    # Drop the collapsed time dimension and any padding past the true length.
    return samples.squeeze(-3)[..., :duration, :]

def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
    """
    Integrates a GammaGaussian ``x`` whose rightmost batch dimension is time,
    computes::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    Adjacent time steps are contracted pairwise, roughly halving the length
    of the time dimension on every pass until a single factor remains.

    :param GammaGaussian gamma_gaussian: Factors with an even event dimension
        and time as the rightmost batch dimension.
    :return: The contracted factor with the time dimension indexed away.
    :rtype: GammaGaussian
    """
    assert isinstance(gamma_gaussian, GammaGaussian)
    assert gamma_gaussian.dim() % 2 == 0, "dim is not even"
    leading_shape = gamma_gaussian.batch_shape[:-1]
    state_dim = gamma_gaussian.dim() // 2
    while gamma_gaussian.batch_shape[-1] > 1:
        num_steps = gamma_gaussian.batch_shape[-1]
        num_pairs = num_steps // 2
        # Group the first 2 * num_pairs steps into (left, right) neighbours.
        paired = gamma_gaussian[..., : 2 * num_pairs].reshape(
            leading_shape + (num_pairs, 2)
        )
        left = paired[..., 0]
        right = paired[..., 1]
        merged = gamma_gaussian_tensordot(left, right, state_dim)
        if num_steps % 2 == 1:
            # An unpaired trailing step is carried over unchanged.
            merged = GammaGaussian.cat((merged, gamma_gaussian[..., -1:]), dim=-1)
        gamma_gaussian = merged
    return gamma_gaussian[..., 0]

class HiddenMarkovModel(TorchDistribution):
    """
    Abstract base class for Hidden Markov Models.

    The purpose of this class is to handle duration logic for homogeneous HMMs.

    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    :param batch_shape: Batch shape forwarded to the base distribution.
    :param event_shape: Event shape whose leftmost entry is the time axis.
    :param validate_args: Forwarded to the base distribution.
    :raises ValueError: If ``duration`` disagrees with a time axis whose size
        is not 1.
    """

    def __init__(self, duration, batch_shape, event_shape, validate_args=None):
        time_size = event_shape[0]
        if duration is None:
            # Infer duration from event_shape, unless the time axis has the
            # broadcastable size 1, in which case duration stays unknown.
            duration = None if time_size == 1 else time_size
        elif duration != time_size:
            if time_size != 1:
                raise ValueError(
                    "duration, event_shape mismatch: {} vs {}".format(
                        duration, event_shape
                    )
                )
            # Infer event_shape from duration.
            event_shape = torch.Size((duration,) + event_shape[1:])
        self._duration = duration
        super().__init__(batch_shape, event_shape, validate_args)

    @property
    def duration(self):
        """
        Returns the size of the time axis, or None if unknown.
        """
        return self._duration

    def _validate_sample(self, value):
        """
        Validate ``value``, tolerating an unknown duration.

        When the duration is unknown, the time size of ``value`` is swapped
        into the event shape for the duration of the base-class check and the
        original event shape is restored afterwards.

        :raises ValueError: If ``value`` has fewer dimensions than the event.
        """
        if value.dim() < self.event_dim:
            raise ValueError("value has too few dimensions: {}".format(value.shape))

        if self.duration is None:
            # Temporarily infer duration from value.shape.
            inferred = value.size(-self.event_dim)
            saved_shape = self._event_shape
            patched_shape = torch.Size((inferred,)) + self._event_shape[1:]
            try:
                self._event_shape = patched_shape
                super()._validate_sample(value)
            finally:
                self._event_shape = saved_shape
        else:
            super()._validate_sample(value)

class DiscreteHMM(HiddenMarkovModel):
    """
    Hidden Markov Model with discrete latent state and arbitrary observation
    distribution.

    This uses [1] to parallelize over time, achieving O(log(time)) parallel
    complexity for computing :meth:`log_prob`, :meth:`filter`, and
    :meth:`sample`.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_logits`` and ``observation_dist``.
    However, because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :param ~torch.Tensor initial_logits: A logits tensor for an initial
        categorical distribution over latent states. Should have rightmost size
        ``state_dim`` and be broadcastable to ``batch_shape + (state_dim,)``.
    :param ~torch.Tensor transition_logits: A logits tensor for transition
        conditional distributions between latent states. Should have rightmost
        shape ``(state_dim, state_dim)`` (old, new), and be broadcastable to
        ``batch_shape + (num_steps, state_dim, state_dim)``.
    :param ~torch.distributions.Distribution observation_dist: A conditional
        distribution of observed data conditioned on latent state. The
        ``.batch_shape`` should have rightmost size ``state_dim`` and be
        broadcastable to ``batch_shape + (num_steps, state_dim)``. The
        ``.event_shape`` may be arbitrary.
    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """

    arg_constraints = {
        "initial_logits": constraints.real,
        "transition_logits": constraints.real,
    }

    def __init__(
        self,
        initial_logits,
        transition_logits,
        observation_dist,
        validate_args=None,
        duration=None,
    ):
        if initial_logits.dim() < 1:
            raise ValueError(
                "expected initial_logits to have at least one dim, "
                "actual shape = {}".format(initial_logits.shape)
            )
        if transition_logits.dim() < 2:
            raise ValueError(
                "expected transition_logits to have at least two dims, "
                "actual shape = {}".format(transition_logits.shape)
            )
        if len(observation_dist.batch_shape) < 1:
            raise ValueError(
                "expected observation_dist to have at least one batch dim, "
                "actual .batch_shape = {}".format(observation_dist.batch_shape)
            )

        # Broadcast everything up to and including the time axis;
        # initial_logits has no time axis, so a singleton one is appended.
        full_shape = broadcast_shape(
            initial_logits.shape[:-1] + (1,),
            transition_logits.shape[:-2],
            observation_dist.batch_shape[:-1],
        )
        batch_shape = full_shape[:-1]
        time_shape = full_shape[-1:]
        event_shape = time_shape + observation_dist.event_shape

        # Store normalized logits. Attributes are set before calling the base
        # constructor, as in the original ordering.
        self.initial_logits = initial_logits - initial_logits.logsumexp(
            -1, keepdim=True
        )
        self.transition_logits = transition_logits - transition_logits.logsumexp(
            -1, keepdim=True
        )
        self.observation_dist = observation_dist
        super().__init__(
            duration, batch_shape, event_shape, validate_args=validate_args
        )

    @constraints.dependent_property(event_dim=2)
    def support(self):
        # The support of the observations, with time reinterpreted as an
        # additional event dimension.
        return constraints.independent(self.observation_dist.support, 1)

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy of this distribution broadcast to ``batch_shape``.

        :param batch_shape: Desired batch shape; it is broadcast against the
            current batch shape.
        :param _instance: Optional pre-allocated instance to populate.
        :return: The expanded distribution.
        :rtype: DiscreteHMM
        """
        expanded = self._get_checked_instance(DiscreteHMM, _instance)
        target_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        # We only need to expand one of the inputs, since batch_shape is
        # determined by broadcasting all three. To save computation in
        # _sequential_logmatmulexp(), we expand only initial_logits, which is
        # applied only after the logmatmulexp. This is similar to the
        # ._unbroadcasted_* pattern used elsewhere in distributions.
        expanded.initial_logits = self.initial_logits.expand(target_shape + (-1,))
        expanded.transition_logits = self.transition_logits
        expanded.observation_dist = self.observation_dist
        super(DiscreteHMM, expanded).__init__(
            self.duration, target_shape, self.event_shape, validate_args=False
        )
        expanded._validate_args = self.__dict__.get("_validate_args")
        return expanded

    def _contract_time(self, value):
        """
        Combine observation and transition factors for ``value`` and
        eliminate the time dimension.

        :param ~torch.Tensor value: A sequence of observations.
        :return: Log-space factors whose two rightmost dimensions are the
            state dimensions (old, new) of ``transition_logits``.
        :rtype: ~torch.Tensor
        """
        # Insert a state dimension just left of the observation event dims so
        # the value is scored against every latent state.
        per_state_value = value.unsqueeze(-1 - self.observation_dist.event_dim)
        observation_logits = self.observation_dist.log_prob(per_state_value)
        factors = self.transition_logits + observation_logits.unsqueeze(-2)
        return _sequential_logmatmulexp(factors)

    def log_prob(self, value):
        """
        Compute the log probability of a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: The log probability, one entry per batch element.
        :rtype: ~torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)

        contracted = self._contract_time(value)

        # Sum out the rightmost (new) state, then combine the initial factor.
        joint = self.initial_logits + contracted.logsumexp(-1)

        # Sum out the remaining state dimension.
        return joint.logsumexp(-1)

    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final
            time step. ``result.logits`` can then be used as
            ``initial_logits`` in a sequential Pyro model for prediction.
        :rtype: ~pyro.distributions.Categorical
        """
        if self._validate_args:
            self._validate_sample(value)

        contracted = self._contract_time(value)

        # Combine the initial factor and sum out the old state (dim -2),
        # keeping the final state.
        final_logits = (self.initial_logits.unsqueeze(-1) + contracted).logsumexp(-2)

        # Convert to a distribution.
        return Categorical(logits=final_logits, validate_args=self._validate_args)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw observation sequences. Requires a known :attr:`duration`.

        :param sample_shape: Shape of the leading sample dimensions.
        :return: Samples of shape
            ``sample_shape + batch_shape + event_shape``.
        :rtype: ~torch.Tensor
        """
        assert self.duration is not None

        num_states = self.initial_logits.size(-1)  # state space size
        prefix_shape = torch.Size(sample_shape) + self.batch_shape

        # Sample initial state.
        init_logits = self.initial_logits.expand(prefix_shape + (num_states,))
        initial_state = Categorical(logits=init_logits).sample()

        # Sample hidden states over time: draw a transition for every
        # possible previous state, chain them into paths, then pick the path
        # that starts from the sampled initial state.
        trans_logits = self.transition_logits.expand(
            prefix_shape + (self.duration, num_states, num_states)
        )
        transitions = Categorical(logits=trans_logits).sample()
        paths = _sequential_index(transitions)
        states = Vindex(paths)[..., :, initial_state]

        # Sample observations conditioned on hidden states.
        # Note the simple sample-then-slice approach here generalizes to all
        # distributions, but is inefficient. To implement a general optimal
        # slice-then-sample strategy would require distributions to support
        # slicing https://github.com/pyro-ppl/pyro/issues/3052. A simpler
        # implementation might register a few slicing operators as is done
        # with pyro.contrib.forecast.util.reshape_batch(). If you as a user
        # need this function to be cheaper, feel free to submit a PR
        # implementing one of these approaches.
        obs_dist = self.observation_dist.expand(
            self.batch_shape + (self.duration, num_states)
        )
        per_state_obs = obs_dist.sample(sample_shape)
        event_slices = (slice(None),) * obs_dist.event_dim
        return Vindex(per_state_obs)[(Ellipsis, states) + event_slices]

class GaussianHMM(HiddenMarkovModel):
    """
    Hidden Markov Model with Gaussians for initial, transition, and observation
    distributions. This adapts [1] to parallelize over time to achieve
    O(log(time)) parallel complexity, however it differs in that it tracks the
    log normalizer to ensure :meth:`log_prob` is differentiable.

    This corresponds to the generative model::

        z = initial_distribution.sample()
        x = []
        for t in range(num_events):
            z = z @ transition_matrix + transition_dist.sample()
            x.append(z @ observation_matrix + observation_dist.sample())

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``.  This should have event_shape ``(hidden_dim,)``.
    :param ~torch.Tensor transition_matrix: A linear transformation of hidden
        state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, hidden_dim)`` where the
        rightmost dims are ordered ``(old, new)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A process
        noise distribution. This should have batch_shape broadcastable to
        ``self.batch_shape + (num_steps,)``.  This should have event_shape
        ``(hidden_dim,)``.
    :param ~torch.Tensor observation_matrix: A linear transformation from hidden
        to observed state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, obs_dim)``.
    :param observation_dist: An observation noise distribution. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(obs_dim,)``.
    :type observation_dist: ~torch.distributions.MultivariateNormal or
        ~torch.distributions.Independent of ~torch.distributions.Normal
    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """

    has_rsample = True
    arg_constraints = {}
    support = constraints.independent(constraints.real, 2)

    @staticmethod
    def _is_gaussian_noise(dist):
        """
        Return whether ``dist`` is a ``MultivariateNormal`` or an
        ``Independent`` wrapping a ``Normal``.
        """
        if isinstance(dist, torch.distributions.MultivariateNormal):
            return True
        return isinstance(dist, torch.distributions.Independent) and isinstance(
            dist.base_dist, torch.distributions.Normal
        )

    def __init__(
        self,
        initial_dist,
        transition_matrix,
        transition_dist,
        observation_matrix,
        observation_dist,
        validate_args=None,
        duration=None,
    ):
        assert self._is_gaussian_noise(initial_dist)
        assert isinstance(transition_matrix, torch.Tensor)
        assert self._is_gaussian_noise(transition_dist)
        assert isinstance(observation_matrix, torch.Tensor)
        assert self._is_gaussian_noise(observation_dist)

        # The observation matrix fixes both dimensions; everything else must
        # agree with it.
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert initial_dist.event_shape == (hidden_dim,)
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim,)
        assert observation_dist.event_shape == (obs_dim,)

        # Broadcast everything up to and including the time axis;
        # initial_dist has no time axis, so a singleton one is appended.
        full_shape = broadcast_shape(
            initial_dist.batch_shape + (1,),
            transition_matrix.shape[:-2],
            transition_dist.batch_shape,
            observation_matrix.shape[:-2],
            observation_dist.batch_shape,
        )
        batch_shape = full_shape[:-1]
        time_shape = full_shape[-1:]
        super().__init__(
            duration,
            batch_shape,
            time_shape + (obs_dim,),
            validate_args=validate_args,
        )

        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = mvn_to_gaussian(initial_dist).expand(self.batch_shape)
        self._trans = matrix_and_mvn_to_gaussian(transition_matrix, transition_dist)
        self._obs = matrix_and_mvn_to_gaussian(observation_matrix, observation_dist)

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy of this distribution broadcast to ``batch_shape``.

        :param batch_shape: Desired batch shape; it is broadcast against the
            current batch shape.
        :param _instance: Optional pre-allocated instance to populate.
        :return: The expanded distribution.
        :rtype: GaussianHMM
        """
        expanded = self._get_checked_instance(GaussianHMM, _instance)
        expanded.hidden_dim = self.hidden_dim
        expanded.obs_dim = self.obs_dim
        expanded._obs = self._obs
        expanded._trans = self._trans

        # To save computation in sequential_gaussian_tensordot(), we expand
        # only _init, which is applied only after
        # sequential_gaussian_tensordot().
        target_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        expanded._init = self._init.expand(target_shape)

        super(GaussianHMM, expanded).__init__(
            self.duration, target_shape, self.event_shape, validate_args=False
        )
        expanded._validate_args = self.__dict__.get("_validate_args")
        return expanded

    def _conditioned_transitions(self, value):
        """
        Combine the transition factor with the observation factor conditioned
        on ``value``.

        :param ~torch.Tensor value: A sequence of observations.
        :return: The combined, not yet expanded, per-step factor.
        """
        conditioned_obs = self._obs.condition(value)
        return self._trans + conditioned_obs.event_pad(left=self.hidden_dim)

    def log_prob(self, value):
        """
        Compute the log probability of a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: The log probability.
        :rtype: ~torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)

        # Combine observation and transition factors.
        factors = self._conditioned_transitions(value)

        # Eliminate time dimension.
        contracted = sequential_gaussian_tensordot(factors.expand(factors.batch_shape))

        # Combine initial factor.
        joint = gaussian_tensordot(self._init, contracted, dims=self.hidden_dim)

        # Marginalize out final state.
        return joint.event_logsumexp()

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparameterized observation sequences. Requires a known
        :attr:`duration`.

        :param sample_shape: Shape of the leading sample dimensions.
        :return: Sampled observations.
        :rtype: ~torch.Tensor
        """
        assert self.duration is not None
        sample_shape = torch.Size(sample_shape)

        # Build the per-step factor with the observations marginalized out,
        # then give its time axis the full duration.
        marginal_obs = self._obs.marginalize(right=self.obs_dim)
        factors = self._trans + marginal_obs.event_pad(left=self.hidden_dim)
        factors = factors.expand(factors.batch_shape[:-1] + (self.duration,))

        hidden = sequential_gaussian_filter_sample(self._init, factors, sample_shape)
        hidden = hidden[..., 1:, :]  # drop the initial hidden state
        return self._obs.left_condition(hidden).rsample()

    def rsample_posterior(self, value, sample_shape=torch.Size()):
        """
        EXPERIMENTAL Sample from the latent state conditioned on observation.

        :param ~torch.Tensor value: A sequence of observations.
        :param sample_shape: Shape of the leading sample dimensions.
        :return: Sampled hidden states, without the initial hidden state.
        :rtype: ~torch.Tensor
        """
        factors = self._conditioned_transitions(value)
        factors = factors.expand(factors.batch_shape)
        hidden = sequential_gaussian_filter_sample(self._init, factors, sample_shape)
        return hidden[..., 1:, :]  # drop the initial hidden state

    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final
            time step. ``result`` can then be used as ``initial_dist`` in a
            sequential Pyro model for prediction.
        :rtype: ~pyro.distributions.MultivariateNormal
        """
        if self._validate_args:
            self._validate_sample(value)

        # Combine observation and transition factors.
        factors = self._conditioned_transitions(value)

        # Eliminate time dimension.
        contracted = sequential_gaussian_tensordot(factors.expand(factors.batch_shape))

        # Combine initial factor.
        posterior = gaussian_tensordot(self._init, contracted, dims=self.hidden_dim)

        # Convert to a distribution: solve precision @ loc = info_vec using
        # the Cholesky factor of the precision.
        precision = posterior.precision
        precision_chol = safe_cholesky(precision)
        loc = cholesky_solve(posterior.info_vec.unsqueeze(-1), precision_chol)
        return MultivariateNormal(
            loc.squeeze(-1),
            precision_matrix=precision,
            validate_args=self._validate_args,
        )

    def conjugate_update(self, other):
        """
        EXPERIMENTAL Creates an updated :class:`GaussianHMM` fusing information
        from another compatible distribution.

        This should satisfy::

            fg, log_normalizer = f.conjugate_update(g)
            assert f.log_prob(x) + g.log_prob(x) == fg.log_prob(x) + log_normalizer

        :param other: A distribution representing ``p(data|self.probs)`` but
            normalized over ``self.probs`` rather than ``data``.
        :type other: ~torch.distributions.Independent of
            ~torch.distributions.MultivariateNormal or ~torch.distributions.Normal
        :return: a pair ``(updated, log_normalizer)`` where ``updated`` is an
            updated :class:`GaussianHMM` , and ``log_normalizer`` is a
            :class:`~torch.Tensor` representing the normalization factor.
        """
        assert isinstance(other, torch.distributions.Independent) and (
            isinstance(other.base_dist, torch.distributions.Normal)
            or isinstance(other.base_dist, torch.distributions.MultivariateNormal)
        )
        if self.duration is None:
            duration = other.event_shape[0]
        else:
            duration = self.duration
        event_shape = torch.Size((duration, self.obs_dim))
        assert other.event_shape == event_shape

        updated = self._get_checked_instance(GaussianHMM)
        updated.hidden_dim = self.hidden_dim
        updated.obs_dim = self.obs_dim
        updated._init = self._init
        updated._trans = self._trans
        # Fold the other distribution into the observation factor, with its
        # leftmost event dimension (time) turned back into a batch dimension.
        other_factor = mvn_to_gaussian(other.to_event(-1))
        updated._obs = self._obs + other_factor.event_pad(left=self.hidden_dim)

        # Normalize.
        # TODO cache this computation for the forward pass of .rsample().
        marginal_obs = updated._obs.marginalize(right=updated.obs_dim)
        factors = updated._trans + marginal_obs.event_pad(left=updated.hidden_dim)
        contracted = sequential_gaussian_tensordot(factors.expand(factors.batch_shape))
        joint = gaussian_tensordot(updated._init, contracted, dims=updated.hidden_dim)
        log_normalizer = joint.event_logsumexp()
        updated._init = updated._init - log_normalizer

        super(GaussianHMM, updated).__init__(
            duration, log_normalizer.shape, event_shape, validate_args=False
        )
        updated._validate_args = self.__dict__.get("_validate_args")
        return updated, log_normalizer

    def prefix_condition(self, data):
        """
        EXPERIMENTAL Given self has ``event_shape == (t+f, d)`` and data ``x``
        of shape ``batch_shape + (t, d)``, compute a conditional distribution
        of event_shape ``(f, d)``. Typically ``t`` is the number of training
        time steps, ``f`` is the number of forecast time steps, and ``d`` is
        the data dimension.

        :param data: data of dimension at least 2.
        :type data: ~torch.Tensor
        :return: The conditional distribution over the remaining ``f`` steps.
        :rtype: GaussianHMM
        """
        assert data.dim() >= 2
        assert data.size(-1) == self.event_shape[-1]
        assert data.size(-2) < self.duration
        t = data.size(-2)
        f = self.duration - t

        def is_homogeneous(factor):
            # A factor without a time batch dim, or with one of size 1, is
            # shared by all time steps.
            return factor.batch_shape == () or factor.batch_shape[-1] == 1

        left = self._get_checked_instance(GaussianHMM)
        left.hidden_dim = self.hidden_dim
        left.obs_dim = self.obs_dim
        left._init = self._init

        right = self._get_checked_instance(GaussianHMM)
        right.hidden_dim = self.hidden_dim
        right.obs_dim = self.obs_dim

        # Homogeneous factors are shared; heterogeneous ones are split at t.
        if is_homogeneous(self._obs):
            left._obs = self._obs
            right._obs = self._obs
        else:
            left._obs = self._obs[..., :t]
            right._obs = self._obs[..., t:]

        if is_homogeneous(self._trans):
            left._trans = self._trans
            right._trans = self._trans
        else:
            left._trans = self._trans[..., :t]
            right._trans = self._trans[..., t:]

        super(GaussianHMM, left).__init__(
            t, self.batch_shape, (t, self.obs_dim), validate_args=self._validate_args
        )
        # The filtered state after the prefix becomes the initial state of
        # the forecast part.
        right._init = mvn_to_gaussian(left.filter(data))
        right_batch_shape = broadcast_shape(right._init.batch_shape, self.batch_shape)
        super(GaussianHMM, right).__init__(
            f, right_batch_shape, (f, self.obs_dim), validate_args=self._validate_args
        )
        return right

class GammaGaussianHMM(HiddenMarkovModel):
    """
    Hidden Markov Model with the joint distribution of initial state, hidden
    state, and observed state is a
    :class:`~pyro.distributions.MultivariateStudentT` distribution along the
    line of references [2] and [3]. This adapts [1] to parallelize over time
    to achieve O(log(time)) parallel complexity.

    This GammaGaussianHMM class corresponds to the generative model::

        s = Gamma(df/2, df/2).sample()
        z = scale(initial_dist, s).sample()
        x = []
        for t in range(num_events):
            z = z @ transition_matrix + scale(transition_dist, s).sample()
            x.append(z @ observation_matrix + scale(observation_dist, s).sample())

    where ``scale(mvn(loc, precision), s) := mvn(loc, s * precision)``.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    [2] F. J. Giron and J. C. Rojano (1994)
        "Bayesian Kalman filtering with elliptically contoured errors"

    [3] Filip Tronarp, Toni Karvonen, and Simo Sarkka (2019)
        "Student's t-filters for noise scale estimation"
        https://users.aalto.fi/~ssarkka/pub/SPL2019.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param Gamma scale_dist: Prior of the mixing distribution.
    :param MultivariateNormal initial_dist: A distribution with unit scale mixing
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``.  This should have event_shape ``(hidden_dim,)``.
    :param ~torch.Tensor transition_matrix: A linear transformation of hidden
        state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, hidden_dim)`` where the
        rightmost dims are ordered ``(old, new)``.
    :param MultivariateNormal transition_dist: A process noise distribution
        with unit scale mixing. This should have batch_shape broadcastable to
        ``self.batch_shape + (num_steps,)``. This should have event_shape
        ``(hidden_dim,)``.
    :param ~torch.Tensor observation_matrix: A linear transformation from hidden
        to observed state. This should have shape broadcastable to
        ``self.batch_shape + (num_steps, hidden_dim, obs_dim)``.
    :param MultivariateNormal observation_dist: An observation noise distribution
        with unit scale mixing. This should have batch_shape broadcastable to
        ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(obs_dim,)``.
    :param int duration: Optional size of the time axis ``event_shape[0]``.
        This is required when sampling from homogeneous HMMs whose parameters
        are not expanded along the time axis.
    """

    arg_constraints = {}
    support = constraints.independent(constraints.real, 2)

    def __init__(
        self,
        scale_dist,
        initial_dist,
        transition_matrix,
        transition_dist,
        observation_matrix,
        observation_dist,
        validate_args=None,
        duration=None,
    ):
        assert isinstance(scale_dist, Gamma)
        assert isinstance(initial_dist, MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist, MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist, MultivariateNormal)

        # The observation matrix fixes both dimensions; everything else must
        # agree with it.
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert initial_dist.event_shape == (hidden_dim,)
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim,)
        assert observation_dist.event_shape == (obs_dim,)

        # Broadcast everything up to and including the time axis; scale_dist
        # and initial_dist have no time axis, so a singleton one is appended.
        shape = broadcast_shape(
            scale_dist.batch_shape + (1,),
            initial_dist.batch_shape + (1,),
            transition_matrix.shape[:-2],
            transition_dist.batch_shape,
            observation_matrix.shape[:-2],
            observation_dist.batch_shape,
        )
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)
        super().__init__(
            duration, batch_shape, event_shape, validate_args=validate_args
        )
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = gamma_and_mvn_to_gamma_gaussian(scale_dist, initial_dist)
        self._trans = matrix_and_mvn_to_gamma_gaussian(
            transition_matrix, transition_dist
        )
        self._obs = matrix_and_mvn_to_gamma_gaussian(
            observation_matrix, observation_dist
        )

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy of this distribution broadcast to ``batch_shape``.

        :param batch_shape: Desired batch shape; it is broadcast against the
            current batch shape.
        :param _instance: Optional pre-allocated instance to populate.
        :return: The expanded distribution.
        :rtype: GammaGaussianHMM
        """
        new = self._get_checked_instance(GammaGaussianHMM, _instance)
        batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        new.hidden_dim = self.hidden_dim
        new.obs_dim = self.obs_dim
        # We only need to expand one of the inputs, since batch_shape is
        # determined by broadcasting all three. To save computation in
        # _sequential_gamma_gaussian_tensordot(), we expand only _init, which
        # is applied only after _sequential_gamma_gaussian_tensordot().
        new._init = self._init.expand(batch_shape)
        new._trans = self._trans
        new._obs = self._obs
        super(GammaGaussianHMM, new).__init__(
            self.duration, batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        return new

    def log_prob(self, value):
        """
        Compute the log probability of a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: The log probability.
        :rtype: ~torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)

        # Combine observation and transition factors.
        result = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim
        )

        # Eliminate time dimension.
        result = _sequential_gamma_gaussian_tensordot(result.expand(result.batch_shape))

        # Combine initial factor.
        result = gamma_gaussian_tensordot(self._init, result, dims=self.hidden_dim)

        # Marginalize out final state.
        result = result.event_logsumexp()

        # Marginalize out multiplier.
        result = result.logsumexp()
        return result

    def filter(self, value):
        """
        Compute posteriors over the multiplier and the final state
        given a sequence of observations. The posterior is a pair of
        Gamma and MultivariateNormal distributions (i.e. a GammaGaussian
        instance).

        :param ~torch.Tensor value: A sequence of observations.
        :return: A pair of posterior distributions over the mixing and the latent
            state at the final time step.
        :rtype: a tuple of ~pyro.distributions.Gamma and ~pyro.distributions.MultivariateNormal
        """
        if self._validate_args:
            self._validate_sample(value)

        # Combine observation and transition factors.
        logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

        # Eliminate time dimension.
        logp = _sequential_gamma_gaussian_tensordot(logp.expand(logp.batch_shape))

        # Combine initial factor.
        logp = gamma_gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Posterior of the scale
        gamma_dist = logp.event_logsumexp()
        scale_post = Gamma(
            gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args
        )

        # Conditional of last state on unit scale.
        # The Cholesky factor computed here belongs to the *precision* matrix.
        # It is the right thing to pass to cholesky_solve() to obtain
        # loc = precision^{-1} @ info_vec, but it must not be handed to
        # MultivariateNormal as ``scale_tril``, which is a factor of the
        # *covariance*. The distribution is therefore parameterized by
        # ``precision_matrix``, matching GaussianHMM.filter().
        precision = logp.precision
        precision_chol = safe_cholesky(precision)
        loc = cholesky_solve(logp.info_vec.unsqueeze(-1), precision_chol).squeeze(-1)
        mvn = MultivariateNormal(
            loc, precision_matrix=precision, validate_args=self._validate_args
        )
        return scale_post, mvn

[docs]class LinearHMM(HiddenMarkovModel):
r"""
Hidden Markov Model with linear dynamics and observations and arbitrary
noise for initial, transition, and observation distributions.  Each of
those distributions can be e.g.
:class:~pyro.distributions.MultivariateNormal or
:class:~pyro.distributions.Independent of
:class:~pyro.distributions.Normal,
:class:~pyro.distributions.StudentT, or :class:~pyro.distributions.Stable .
Additionally the observation distribution may be constrained, e.g.
:class:~pyro.distributions.LogNormal

This corresponds to the generative model::

z = initial_distribution.sample()
x = []
for t in range(num_events):
z = z @ transition_matrix + transition_dist.sample()
y = z @ observation_matrix + obs_base_dist.sample()
x.append(obs_transform(y))

where observation_dist is split into obs_base_dist and an optional
obs_transform (defaulting to the identity).

This implements a reparameterized :meth:rsample method but does not
implement a :meth:log_prob method. Derived classes may implement
:meth:log_prob .

Inference without :meth:log_prob can be performed using either
reparameterization with :class:~pyro.infer.reparam.hmm.LinearHMMReparam
or likelihood-free algorithms such as
:class:~pyro.infer.energy_distance.EnergyDistance .  Note that while
stable processes generally require a common shared stability parameter
:math:\alpha , this distribution and the above inference algorithms allow
heterogeneous stability parameters.

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_dist and observation_dist. However at
least one of the distributions or matrices must be expanded to contain the
time dimension.

:ivar int hidden_dim: The dimension of the hidden state.
:ivar int obs_dim: The dimension of the observed state.
:param initial_dist: A distribution over initial states. This should have
batch_shape broadcastable to self.batch_shape.  This should have
event_shape (hidden_dim,).
:param ~torch.Tensor transition_matrix: A linear transformation of hidden
state. This should have shape broadcastable to self.batch_shape +
(num_steps, hidden_dim, hidden_dim) where the rightmost dims are
ordered (old, new).
:param transition_dist: A distribution over process noise. This should have
batch_shape broadcastable to self.batch_shape + (num_steps,).  This
should have event_shape (hidden_dim,).
:param ~torch.Tensor observation_matrix: A linear transformation from hidden
to observed state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, obs_dim).
:param observation_dist: A observation noise distribution. This should have
batch_shape broadcastable to self.batch_shape + (num_steps,).  This
should have event_shape (obs_dim,).
:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
"""

arg_constraints = {}
support = constraints.independent(constraints.real, 2)
has_rsample = True

def __init__(
    self,
    initial_dist,
    transition_matrix,
    transition_dist,
    observation_matrix,
    observation_dist,
    validate_args=None,
    duration=None,
):
    """
    Check the five model components against each other, broadcast them to a
    common ``batch_shape + time_shape``, and strip any transforms wrapped
    around the observation noise so they can be re-applied when sampling.

    See the class docstring for the meaning and expected shape of each
    argument.
    """
    # Each noise source must be reparameterizable with vector-valued events,
    # and each matrix must be a tensor with at least two dimensions.
    assert initial_dist.has_rsample
    assert initial_dist.event_dim == 1
    assert isinstance(transition_matrix, torch.Tensor)
    assert transition_matrix.dim() >= 2
    assert transition_dist.has_rsample
    assert transition_dist.event_dim == 1
    assert isinstance(observation_matrix, torch.Tensor)
    assert observation_matrix.dim() >= 2
    assert observation_dist.has_rsample
    assert observation_dist.event_dim == 1

    # The observation matrix fixes both sizes; every other piece must agree.
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)

    # The initial distribution has no time axis, so it is padded with a
    # singleton before broadcasting against the time-indexed components.
    full_shape = broadcast_shape(
        initial_dist.batch_shape + (1,),
        transition_matrix.shape[:-2],
        transition_dist.batch_shape,
        observation_matrix.shape[:-2],
        observation_dist.batch_shape,
    )
    batch_shape = full_shape[:-1]
    time_shape = full_shape[-1:]
    super().__init__(
        duration,
        batch_shape,
        time_shape + (obs_dim,),
        validate_args=validate_args,
    )

    # Expand eagerly, skipping components that already have the target shape.
    batch_time_shape = batch_shape + time_shape
    if initial_dist.batch_shape != batch_shape:
        initial_dist = initial_dist.expand(batch_shape)
    if transition_matrix.shape[:-2] != batch_time_shape:
        transition_matrix = transition_matrix.expand(
            batch_time_shape + (hidden_dim, hidden_dim)
        )
    if transition_dist.batch_shape != batch_time_shape:
        transition_dist = transition_dist.expand(batch_time_shape)
    if observation_matrix.shape[:-2] != batch_time_shape:
        observation_matrix = observation_matrix.expand(
            batch_time_shape + (hidden_dim, obs_dim)
        )
    if observation_dist.batch_shape != batch_time_shape:
        observation_dist = observation_dist.expand(batch_time_shape)

    # Peel Independent / TransformedDistribution wrappers off the observation
    # noise, outermost first. Transforms of an outer wrapper are appended
    # after those of the inner ones, so the list stays in application order.
    wrapper_types = (
        torch.distributions.Independent,
        torch.distributions.TransformedDistribution,
    )
    transforms = []
    while isinstance(observation_dist, wrapper_types):
        if not isinstance(observation_dist, torch.distributions.Independent):
            transforms = observation_dist.transforms + transforms
        observation_dist = observation_dist.base_dist
    # Unwrapping may have exposed a scalar-event base; restore a vector event.
    if not observation_dist.event_shape:
        observation_dist = Independent(observation_dist, 1)

    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim
    self.initial_dist = initial_dist
    self.transition_matrix = transition_matrix
    self.transition_dist = transition_dist
    self.observation_matrix = observation_matrix
    self.observation_dist = observation_dist
    self.transforms = transforms

@constraints.dependent_property(event_dim=2)
def support(self):
    """
    Constraint built from the observation noise's support, with one extra
    dimension reinterpreted as part of the event.
    """
    per_step_support = self.observation_dist.support
    return constraints.independent(per_step_support, 1)

def expand(self, batch_shape, _instance=None):
    """
    Return a :class:`LinearHMM` whose components are expanded to
    ``batch_shape``.

    :param batch_shape: The desired batch shape (any iterable of ints
        accepted by ``torch.Size``).
    :param _instance: Optional pre-allocated instance to populate; forwarded
        to ``_get_checked_instance``.
    :return: The expanded distribution.
    """
    new = self._get_checked_instance(LinearHMM, _instance)
    batch_shape = torch.Size(batch_shape)
    # The time axis is kept as-is; only the dims to its left are expanded.
    time_shape = self.transition_dist.batch_shape[-1:]
    batch_time_shape = batch_shape + time_shape

    new.hidden_dim = self.hidden_dim
    new.obs_dim = self.obs_dim
    new.initial_dist = self.initial_dist.expand(batch_shape)
    new.transition_matrix = self.transition_matrix.expand(
        batch_time_shape + (self.hidden_dim, self.hidden_dim)
    )
    new.transition_dist = self.transition_dist.expand(batch_time_shape)
    new.observation_matrix = self.observation_matrix.expand(
        batch_time_shape + (self.hidden_dim, self.obs_dim)
    )
    new.observation_dist = self.observation_dist.expand(batch_time_shape)
    new.transforms = self.transforms

    # Initialize the base class directly with validation off, then copy the
    # validation flag over from this instance.
    super(LinearHMM, new).__init__(
        self.duration, batch_shape, self.event_shape, validate_args=False
    )
    new._validate_args = self.__dict__.get("_validate_args")
    return new

def log_prob(self, value):
    """
    Not supported for this distribution.

    :param value: Ignored.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("LinearHMM.log_prob() is not implemented")

def rsample(self, sample_shape=torch.Size()):
    """
    Draw a reparameterized sample of the observed sequence.

    Initial state, process noise and observation noise are sampled (in that
    order), the hidden trajectory is obtained by integrating the linear
    recursion with ``_linear_integrate``, projected through
    ``observation_matrix``, perturbed by the observation noise, and finally
    passed through the stored ``transforms``.

    :param sample_shape: Leading shape of the sample batch.
    :return: The sampled observations.
    :raises AssertionError: If ``self.duration`` is ``None``.
    """
    assert (
        self.duration is not None
    ), "LinearHMM.rsample() requires a known duration"
    batch_time_shape = self.batch_shape + (self.duration,)

    # Keep the draw order (init, transition, observation) fixed so seeded
    # runs stay reproducible.
    init = self.initial_dist.rsample(sample_shape)
    trans = self.transition_dist.expand(batch_time_shape).rsample(sample_shape)
    obs = self.observation_dist.expand(batch_time_shape).rsample(sample_shape)

    # -1 leaves the trailing two (matrix) dims unchanged.
    trans_matrix = self.transition_matrix.expand(batch_time_shape + (-1, -1))
    z = _linear_integrate(init, trans_matrix, trans)

    # Treat each hidden state as a 1 x hidden_dim row vector for the matmul.
    x = (z.unsqueeze(-2) @ self.observation_matrix).squeeze(-2) + obs
    for t in self.transforms:
        x = t(x)
    return x

class IndependentHMM(TorchDistribution):
    """
    Wrapper class to treat a batch of independent univariate HMMs as a single
    multivariate distribution. This converts distribution shapes as follows:

    +-----------+--------------------+---------------------+
    |           | ``.batch_shape``   | ``.event_shape``    |
    +===========+====================+=====================+
    | base_dist | shape + (obs_dim,) | (duration, 1)       |
    +-----------+--------------------+---------------------+
    | result    | shape              | (duration, obs_dim) |
    +-----------+--------------------+---------------------+

    :param HiddenMarkovModel base_dist: A base hidden Markov model instance.
    """

    arg_constraints = {}

    def __init__(self, base_dist):
        # The base must have a batch dim to absorb and a (time, 1) event.
        assert base_dist.batch_shape
        assert base_dist.event_dim == 2
        assert base_dist.event_shape[-1] == 1
        # Move the rightmost batch dim of the base into the rightmost event
        # slot, replacing the base's trailing singleton.
        batch_shape = base_dist.batch_shape[:-1]
        event_shape = base_dist.event_shape[:-1] + base_dist.batch_shape[-1:]
        super().__init__(batch_shape, event_shape)
        self.base_dist = base_dist

    @constraints.dependent_property(event_dim=2)
    def support(self):
        """The support of the wrapped ``base_dist``."""
        return self.base_dist.support

    @property
    def has_rsample(self):
        """Whether the wrapped ``base_dist`` supports reparameterized sampling."""
        return self.base_dist.has_rsample

    @property
    def duration(self):
        """The ``duration`` of the wrapped ``base_dist``."""
        return self.base_dist.duration

    def expand(self, batch_shape, _instance=None):
        """
        Return an :class:`IndependentHMM` expanded to ``batch_shape``.

        :param batch_shape: The desired batch shape.
        :param _instance: Optional pre-allocated instance to populate;
            forwarded to ``_get_checked_instance``.
        :return: The expanded distribution.
        """
        batch_shape = torch.Size(batch_shape)
        new = self._get_checked_instance(IndependentHMM, _instance)
        # The base keeps its rightmost batch dim, which this wrapper exposes
        # as the last event dim.
        new.base_dist = self.base_dist.expand(
            batch_shape + self.base_dist.batch_shape[-1:]
        )
        super(IndependentHMM, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        return new

    def rsample(self, sample_shape=torch.Size()):
        """
        Sample from ``base_dist``, drop its trailing singleton dim and swap
        the last two remaining dims.

        :param sample_shape: Leading shape of the sample batch.
        :return: The rearranged sample.
        """
        base_value = self.base_dist.rsample(sample_shape)
        return base_value.squeeze(-1).transpose(-1, -2)

    def log_prob(self, value):
        """
        Evaluate the log density of ``value``.

        The last two dims of ``value`` are swapped and a trailing singleton
        is appended (the inverse of the rearrangement in :meth:`rsample`);
        the base log-probabilities are then summed over their last dim.

        :param value: The observations to score.
        :return: The summed log-probability.
        """
        base_value = value.transpose(-1, -2).unsqueeze(-1)
        return self.base_dist.log_prob(base_value).sum(-1)

class GaussianMRF(TorchDistribution):
    """
    Temporal Markov Random Field with Gaussian factors for initial, transition,
    and observation distributions. This adapts [1] to parallelize over time to
    achieve O(log(time)) parallel complexity, however it differs in that it
    tracks the log normalizer to ensure :meth:`log_prob` is differentiable.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_dist`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        event_shape = (1, obs_dim)  # homogeneous + homogeneous case

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

    :ivar int hidden_dim: The dimension of the hidden state.
    :ivar int obs_dim: The dimension of the observed state.
    :param ~torch.distributions.MultivariateNormal initial_dist: A distribution
        over initial states. This should have batch_shape broadcastable to
        ``self.batch_shape``.  This should have event_shape ``(hidden_dim,)``.
    :param ~torch.distributions.MultivariateNormal transition_dist: A joint
        distribution factor over a pair of successive time steps. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + hidden_dim,)`` (old+new).
    :param ~torch.distributions.MultivariateNormal observation_dist: A joint
        distribution factor over a hidden and an observed state. This should
        have batch_shape broadcastable to ``self.batch_shape + (num_steps,)``.
        This should have event_shape ``(hidden_dim + obs_dim,)``.
    """

    arg_constraints = {}

    def __init__(
        self, initial_dist, transition_dist, observation_dist, validate_args=None
    ):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
        assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
        # hidden_dim comes from the initial factor; obs_dim is whatever
        # remains of the observation factor's event after the hidden part.
        hidden_dim = initial_dist.event_shape[0]
        assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
        obs_dim = observation_dist.event_shape[0] - hidden_dim

        # The initial factor has no time axis, so pad it with a singleton
        # before broadcasting against the time-indexed factors.
        shape = broadcast_shape(
            initial_dist.batch_shape + (1,),
            transition_dist.batch_shape,
            observation_dist.batch_shape,
        )
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim,)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
        self._init = mvn_to_gaussian(initial_dist)
        self._trans = mvn_to_gaussian(transition_dist)
        self._obs = mvn_to_gaussian(observation_dist)
        self._support = constraints.independent(observation_dist.support, 1)

    @constraints.dependent_property(event_dim=2)
    def support(self):
        """The constraint computed from ``observation_dist`` at construction."""
        return self._support

    def expand(self, batch_shape, _instance=None):
        """
        Return a :class:`GaussianMRF` broadcast to ``batch_shape``.

        :param batch_shape: The desired batch shape; it is broadcast against
            the current ``self.batch_shape``.
        :param _instance: Optional pre-allocated instance to populate;
            forwarded to ``_get_checked_instance``.
        :return: The expanded distribution.
        """
        new = self._get_checked_instance(GaussianMRF, _instance)
        batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
        new.hidden_dim = self.hidden_dim
        new.obs_dim = self.obs_dim
        # We only need to expand one of the inputs, since batch_shape is determined
        # by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
        # we expand only _init, which is applied only after sequential_gaussian_tensordot().
        new._init = self._init.expand(batch_shape)
        new._trans = self._trans
        new._obs = self._obs
        new._support = self._support
        super(GaussianMRF, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        return new

    def log_prob(self, value):
        """
        Evaluate the log density of an observed sequence.

        The result is computed as ``log(p(obs, hidden) / p(hidden))``: both
        terms are built from the same transition factor, stacked into one
        Gaussian, reduced over time with ``sequential_gaussian_tensordot``,
        combined with the initial factor, and marginalized over the final
        hidden state.

        :param value: The observations to score.
        :return: The log-probability.
        """
        # We compute a normalized distribution as p(obs,hidden) / p(hidden).
        # Combine observation and transition factors. Fresh sums are built
        # explicitly rather than with ``+=`` on an alias of ``self._trans``,
        # so the stored factor can never be modified in place.
        logp_oh = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim
        )
        logp_h = self._trans + self._obs.marginalize(right=self.obs_dim).event_pad(
            left=self.hidden_dim
        )

        # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
        # Left-pad the batch shape with singleton dims until it has
        # ``batch_dim`` dims before stacking the two factors.
        batch_dim = 1 + max(len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
        batch_shape = (1,) * (
            batch_dim - len(logp_oh.batch_shape)
        ) + logp_oh.batch_shape
        logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

        # Eliminate time dimension.
        logp = sequential_gaussian_tensordot(logp)

        # Combine initial factor.
        logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Marginalize out final state.
        logp_oh, logp_h = logp.event_logsumexp()
        return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )