# Source code for pyro.distributions.hmm

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import torch
import torch.nn.functional as F

from pyro.distributions.util import broadcast_shape

from pyro.ops.gamma_gaussian import (
GammaGaussian,
gamma_and_mvn_to_gamma_gaussian,
gamma_gaussian_tensordot,
matrix_and_mvn_to_gamma_gaussian,
)
from pyro.ops.gaussian import (
Gaussian,
gaussian_tensordot,
matrix_and_mvn_to_gaussian,
mvn_to_gaussian,
sequential_gaussian_filter_sample,
sequential_gaussian_tensordot,
)
from pyro.ops.indexing import Vindex
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import cholesky_solve, safe_cholesky

from . import constraints
from .torch import Categorical, Gamma, Independent, MultivariateNormal
from .torch_distribution import TorchDistribution

# Assumption: the decorator below is treated as an alias for
# torch.jit.script_if_tracing; its defining import is not shown in this listing.
torch_jit_script_if_tracing = torch.jit.script_if_tracing


@torch_jit_script_if_tracing
def _linear_integrate(init, trans, shift):
"""
Integrate the inhomogeneous linear difference equation::

x = init
x[t] = x[t-1] @ trans[t] + shift[t]

:return: An integrated tensor x[:, :].
"""
# xs: List[Tensor]
xs = []
x = init.unsqueeze(-2)
shift = shift.unsqueeze(-3)
for t in range(trans.size(-3)):
x = x @ trans[..., t, :, :] + shift[..., t, :]
xs.append(x)
return torch.cat(xs, dim=-2)

def _logmatmulexp(x, y):
"""
Numerically stable version of (x.exp() @ y.exp()).log().
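
For example, the following equivalence holds up to numerical precision
(a sketch with arbitrary shapes)::

    x, y = torch.randn(3, 4), torch.randn(4, 5)
    expected = (x.exp() @ y.exp()).log()
    assert torch.allclose(_logmatmulexp(x, y), expected, atol=1e-5)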
"""
finfo = torch.finfo(x.dtype)  # avoid nan due to -inf - -inf
x_shift = x.detach().max(-1, keepdim=True).values.clamp_(min=finfo.min)
y_shift = y.detach().max(-2, keepdim=True).values.clamp_(min=finfo.min)
xy = safe_log(torch.matmul((x - x_shift).exp(), (y - y_shift).exp()))
return xy + x_shift + y_shift

# TODO re-enable jitting once _SafeLog is supported by the jit.
# See https://discuss.pytorch.org/t/does-torch-jit-script-support-custom-operators/65759/4
# @torch_jit_script_if_tracing
def _sequential_logmatmulexp(logits):
"""
For a tensor x whose time dimension is -3, computes::

x[..., 0, :, :] @ x[..., 1, :, :] @ ... @ x[..., T-1, :, :]

but does so numerically stably in log space.
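
For example, the result agrees with a naive left-to-right product
(a sketch with arbitrary shapes)::

    logits = torch.randn(5, 3, 3)
    expected = logits[0]
    for t in range(1, 5):
        expected = _logmatmulexp(expected, logits[t])
    assert torch.allclose(_sequential_logmatmulexp(logits), expected, atol=1e-5)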
"""
batch_shape = logits.shape[:-3]
state_dim = logits.size(-1)
while logits.size(-3) > 1:
time = logits.size(-3)
even_time = time // 2 * 2
even_part = logits[..., :even_time, :, :]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2, state_dim, state_dim))
x, y = x_y.unbind(-3)
contracted = _logmatmulexp(x, y)
if time > even_time:
contracted = torch.cat((contracted, logits[..., -1:, :, :]), dim=-3)
logits = contracted
return logits.squeeze(-3)

def _markov_index(x, y):
"""
Join ends of two Markov paths.
"""
y = Vindex(y.unsqueeze(-2))[..., x[..., -1:, :]]
return torch.cat([x, y], dim=-2)

def _sequential_index(samples):
"""
For a tensor samples whose time dimension is -2 and state dimension
is -1, compute Markov paths by sequential indexing.

For example, for samples with 3 states and time duration 5::

tensor([[0, 1, 1],
[1, 0, 2],
[2, 1, 0],
[0, 2, 1],
[1, 1, 0]])

computed paths are::

tensor([[0, 1, 1],
[1, 0, 0],
[1, 2, 2],
[2, 1, 1],
[0, 1, 1]])

# path for a 0th state
#
# 0 1 1
# |
# 1 0 2
#  \
# 2 1 0
#   |
# 0 2 1
#    \
# 1 1 0
#
# paths for 1st and 2nd states
#
# 0 1 1
#   |/
# 1 0 2
#  /
# 2 1 0
#  \
#    \
# 0 2 1
#    /
# 1 1 0
"""
# new Markov time dimension at -2
samples = samples.unsqueeze(-2)
batch_shape = samples.shape[:-3]
state_dim = samples.size(-1)
duration = samples.size(-3)
while samples.size(-3) > 1:
time = samples.size(-3)
even_time = time // 2 * 2
even_part = samples[..., :even_time, :, :]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2, -1, state_dim))
x, y = x_y.unbind(-3)
contracted = _markov_index(x, y)
if time > even_time:
    tail = F.pad(
        input=samples[..., -1:, :, :],
        pad=(0, 0, 0, contracted.size(-2) // 2),
    )
    contracted = torch.cat((contracted, tail), dim=-3)
samples = contracted
return samples.squeeze(-3)[..., :duration, :]

def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
"""
Integrates a GammaGaussian x whose rightmost batch dimension is time, computes::

x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
"""
assert isinstance(gamma_gaussian, GammaGaussian)
assert gamma_gaussian.dim() % 2 == 0, "dim is not even"
batch_shape = gamma_gaussian.batch_shape[:-1]
state_dim = gamma_gaussian.dim() // 2
while gamma_gaussian.batch_shape[-1] > 1:
time = gamma_gaussian.batch_shape[-1]
even_time = time // 2 * 2
even_part = gamma_gaussian[..., :even_time]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
x, y = x_y[..., 0], x_y[..., 1]
contracted = gamma_gaussian_tensordot(x, y, state_dim)
if time > even_time:
contracted = GammaGaussian.cat(
(contracted, gamma_gaussian[..., -1:]), dim=-1
)
gamma_gaussian = contracted
return gamma_gaussian[..., 0]

class HiddenMarkovModel(TorchDistribution):
"""
Abstract base class for Hidden Markov Models.

The purpose of this class is to handle duration logic for homogeneous HMMs.

:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
"""

def __init__(self, duration, batch_shape, event_shape, validate_args=None):
if duration is None:
if event_shape[0] != 1:
# Infer duration from event_shape.
duration = event_shape[0]
elif duration != event_shape[0]:
if event_shape[0] != 1:
raise ValueError(
"duration, event_shape mismatch: {} vs {}".format(
duration, event_shape
)
)
# Infer event_shape from duration.
event_shape = torch.Size((duration,) + event_shape[1:])
self._duration = duration
super().__init__(batch_shape, event_shape, validate_args)

@property
def duration(self):
"""
Returns the size of the time axis, or None if unknown.
"""
return self._duration

def _validate_sample(self, value):
if value.dim() < self.event_dim:
raise ValueError("value has too few dimensions: {}".format(value.shape))

if self.duration is not None:
super()._validate_sample(value)
return

# Temporarily infer duration from value.shape.
duration = value.size(-self.event_dim)
old = self._event_shape
new = torch.Size((duration,)) + self._event_shape[1:]
try:
self._event_shape = new
super()._validate_sample(value)
finally:
self._event_shape = old

class DiscreteHMM(HiddenMarkovModel):
"""
Hidden Markov Model with discrete latent state and arbitrary observation
distribution.

This uses [1] to parallelize over time, achieving O(log(time)) parallel
complexity for computing :meth:log_prob, :meth:filter, and :meth:sample.

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_logits and observation_dist. However,
because time is included in this distribution's event_shape, the
homogeneous+homogeneous case will have a broadcastable event_shape with
num_steps = 1, allowing :meth:log_prob to work with arbitrary length
data::

# homogeneous + homogeneous case:
event_shape = (1,) + observation_dist.event_shape

**References:**

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
"Temporal Parallelization of Bayesian Filters and Smoothers"
https://arxiv.org/pdf/1905.13002.pdf

:param ~torch.Tensor initial_logits: A logits tensor for an initial
categorical distribution over latent states. Should have rightmost size
state_dim and be broadcastable to batch_shape + (state_dim,).
:param ~torch.Tensor transition_logits: A logits tensor for transition
conditional distributions between latent states. Should have rightmost
shape (state_dim, state_dim) (old, new), and be broadcastable to
batch_shape + (num_steps, state_dim, state_dim).
:param ~torch.distributions.Distribution observation_dist: A conditional
distribution of observed data conditioned on latent state. The
.batch_shape should have rightmost size state_dim and be
broadcastable to batch_shape + (num_steps, state_dim). The
.event_shape may be arbitrary.
:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
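
Example (a minimal usage sketch; all shapes and values below are arbitrary
assumptions, not part of the API)::

    import torch
    import pyro.distributions as dist

    num_steps, state_dim = 10, 3
    initial_logits = torch.randn(state_dim)
    transition_logits = torch.randn(num_steps, state_dim, state_dim)
    observation_dist = dist.Normal(torch.randn(num_steps, state_dim), 1.0)
    hmm = dist.DiscreteHMM(initial_logits, transition_logits, observation_dist)
    assert hmm.event_shape == (num_steps,)
    data = torch.randn(num_steps)
    logp = hmm.log_prob(data)  # marginal likelihood of the sequence
    seq = hmm.sample()         # an observed sequence of shape (num_steps,)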
"""

arg_constraints = {
"initial_logits": constraints.real,
"transition_logits": constraints.real,
}

def __init__(
self,
initial_logits,
transition_logits,
observation_dist,
validate_args=None,
duration=None,
):
if initial_logits.dim() < 1:
raise ValueError(
"expected initial_logits to have at least one dim, "
"actual shape = {}".format(initial_logits.shape)
)
if transition_logits.dim() < 2:
raise ValueError(
"expected transition_logits to have at least two dims, "
"actual shape = {}".format(transition_logits.shape)
)
if len(observation_dist.batch_shape) < 1:
raise ValueError(
"expected observation_dist to have at least one batch dim, "
"actual .batch_shape = {}".format(observation_dist.batch_shape)
)
shape = broadcast_shape(
initial_logits.shape[:-1] + (1,),
transition_logits.shape[:-2],
observation_dist.batch_shape[:-1],
)
batch_shape, time_shape = shape[:-1], shape[-1:]
event_shape = time_shape + observation_dist.event_shape
self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
self.transition_logits = transition_logits - transition_logits.logsumexp(
-1, True
)
self.observation_dist = observation_dist
super().__init__(
duration, batch_shape, event_shape, validate_args=validate_args
)

@constraints.dependent_property(event_dim=2)
def support(self):
return constraints.independent(self.observation_dist.support, 1)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(DiscreteHMM, _instance)
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in _sequential_logmatmulexp(),
# we expand only initial_logits, which is applied only after the logmatmulexp.
# This is similar to the ._unbroadcasted_* pattern used elsewhere in distributions.
new.initial_logits = self.initial_logits.expand(batch_shape + (-1,))
new.transition_logits = self.transition_logits
new.observation_dist = self.observation_dist
super(DiscreteHMM, new).__init__(
self.duration, batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
value = value.unsqueeze(-1 - self.observation_dist.event_dim)
observation_logits = self.observation_dist.log_prob(value)
result = self.transition_logits + observation_logits.unsqueeze(-2)

# Eliminate time dimension.
result = _sequential_logmatmulexp(result)

# Combine initial factor.
result = self.initial_logits + result.logsumexp(-1)

# Marginalize out final state.
result = result.logsumexp(-1)
return result

def filter(self, value):
"""
Compute posterior over final state given a sequence of observations.

:param ~torch.Tensor value: A sequence of observations.
:return: A posterior distribution
over latent states at the final time step. result.logits can
then be used as initial_logits in a sequential Pyro model for
prediction.
:rtype: ~pyro.distributions.Categorical
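
Example (a sketch; hmm and data are assumed to be a DiscreteHMM and a
matching observation sequence)::

    post = hmm.filter(data)
    hmm2 = DiscreteHMM(post.logits, hmm.transition_logits, hmm.observation_dist)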
"""
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
value = value.unsqueeze(-1 - self.observation_dist.event_dim)
observation_logits = self.observation_dist.log_prob(value)
logp = self.transition_logits + observation_logits.unsqueeze(-2)

# Eliminate time dimension.
logp = _sequential_logmatmulexp(logp)

# Combine initial factor.
logp = (self.initial_logits.unsqueeze(-1) + logp).logsumexp(-2)

# Convert to a distribution.
return Categorical(logits=logp, validate_args=self._validate_args)

def sample(self, sample_shape=torch.Size()):
assert self.duration is not None

# Sample initial state.
S = self.initial_logits.size(-1)  # state space size
init_shape = torch.Size(sample_shape) + self.batch_shape + (S,)
init_logits = self.initial_logits.expand(init_shape)
x = Categorical(logits=init_logits).sample()

# Sample hidden states over time.
trans_shape = (
torch.Size(sample_shape) + self.batch_shape + (self.duration, S, S)
)
trans_logits = self.transition_logits.expand(trans_shape)
xs = Categorical(logits=trans_logits).sample()
xs = _sequential_index(xs)
x = Vindex(xs)[..., :, x]

# Sample observations conditioned on hidden states.
# Note the simple sample-then-slice approach here generalizes to all
# distributions, but is inefficient. To implement a general optimal
# slice-then-sample strategy would require distributions to support
# slicing https://github.com/pyro-ppl/pyro/issues/3052. A simpler
# implementation might register a few slicing operators as is done with
# pyro.contrib.forecast.util.reshape_batch(). If you as a user need
# this function to be cheaper, feel free to submit a PR implementing
# one of these approaches.
obs_shape = self.batch_shape + (self.duration, S)
obs_dist = self.observation_dist.expand(obs_shape)
y = obs_dist.sample(sample_shape)
y = Vindex(y)[(Ellipsis, x) + (slice(None),) * obs_dist.event_dim]
return y

class GaussianHMM(HiddenMarkovModel):
"""
Hidden Markov Model with Gaussians for initial, transition, and observation
distributions. This adapts [1] to parallelize over time to achieve
O(log(time)) parallel complexity, however it differs in that it tracks the
log normalizer to ensure :meth:log_prob is differentiable.

This corresponds to the generative model::

z = initial_distribution.sample()
x = []
for t in range(num_events):
z = z @ transition_matrix + transition_dist.sample()
x.append(z @ observation_matrix + observation_dist.sample())

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_dist and observation_dist. However,
because time is included in this distribution's event_shape, the
homogeneous+homogeneous case will have a broadcastable event_shape with
num_steps = 1, allowing :meth:log_prob to work with arbitrary length
data::

event_shape = (1, obs_dim)  # homogeneous + homogeneous case

**References:**

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
"Temporal Parallelization of Bayesian Filters and Smoothers"
https://arxiv.org/pdf/1905.13002.pdf

:ivar int hidden_dim: The dimension of the hidden state.
:ivar int obs_dim: The dimension of the observed state.
:param ~torch.distributions.MultivariateNormal initial_dist: A distribution
over initial states. This should have batch_shape broadcastable to
self.batch_shape.  This should have event_shape (hidden_dim,).
:param ~torch.Tensor transition_matrix: A linear transformation of hidden
state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the
rightmost dims are ordered (old, new).
:param ~torch.distributions.MultivariateNormal transition_dist: A process
noise distribution. This should have batch_shape broadcastable to
self.batch_shape + (num_steps,).  This should have event_shape
(hidden_dim,).
:param ~torch.Tensor observation_matrix: A linear transformation from hidden
to observed state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, obs_dim).
:param observation_dist: An observation noise distribution. This should
have batch_shape broadcastable to self.batch_shape + (num_steps,).
This should have event_shape (obs_dim,).
:type observation_dist: ~torch.distributions.MultivariateNormal or
~torch.distributions.Independent of ~torch.distributions.Normal
:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
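
Example (a minimal usage sketch; the shapes below are arbitrary assumptions)::

    import torch
    import pyro.distributions as dist

    hidden_dim, obs_dim, num_steps = 2, 3, 10
    init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                        torch.eye(hidden_dim))
    trans_matrix = 0.5 * torch.eye(hidden_dim)
    trans_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                         torch.eye(hidden_dim))
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))
    hmm = dist.GaussianHMM(init_dist, trans_matrix, trans_dist,
                           obs_matrix, obs_dist, duration=num_steps)
    data = torch.randn(num_steps, obs_dim)
    logp = hmm.log_prob(data)   # differentiable marginal likelihood
    x = hmm.rsample()           # reparameterized sample, shape (num_steps, obs_dim)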
"""

has_rsample = True
arg_constraints = {}
support = constraints.independent(constraints.real, 2)

def __init__(
self,
initial_dist,
transition_matrix,
transition_dist,
observation_matrix,
observation_dist,
validate_args=None,
duration=None,
):
assert isinstance(initial_dist, torch.distributions.MultivariateNormal) or (
isinstance(initial_dist, torch.distributions.Independent)
and isinstance(initial_dist.base_dist, torch.distributions.Normal)
)
assert isinstance(transition_matrix, torch.Tensor)
assert isinstance(transition_dist, torch.distributions.MultivariateNormal) or (
isinstance(transition_dist, torch.distributions.Independent)
and isinstance(transition_dist.base_dist, torch.distributions.Normal)
)
assert isinstance(observation_matrix, torch.Tensor)
assert isinstance(observation_dist, torch.distributions.MultivariateNormal) or (
isinstance(observation_dist, torch.distributions.Independent)
and isinstance(observation_dist.base_dist, torch.distributions.Normal)
)
hidden_dim, obs_dim = observation_matrix.shape[-2:]
assert initial_dist.event_shape == (hidden_dim,)
assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
assert transition_dist.event_shape == (hidden_dim,)
assert observation_dist.event_shape == (obs_dim,)
shape = broadcast_shape(
initial_dist.batch_shape + (1,),
transition_matrix.shape[:-2],
transition_dist.batch_shape,
observation_matrix.shape[:-2],
observation_dist.batch_shape,
)
batch_shape, time_shape = shape[:-1], shape[-1:]
event_shape = time_shape + (obs_dim,)
super().__init__(
duration, batch_shape, event_shape, validate_args=validate_args
)

self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
self._init = mvn_to_gaussian(initial_dist).expand(self.batch_shape)
self._trans = matrix_and_mvn_to_gaussian(transition_matrix, transition_dist)
self._obs = matrix_and_mvn_to_gaussian(observation_matrix, observation_dist)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(GaussianHMM, _instance)
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
new._obs = self._obs
new._trans = self._trans

# To save computation in sequential_gaussian_tensordot(), we expand
# only _init, which is applied only after
# sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)

super(GaussianHMM, new).__init__(
self.duration, batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
result = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
result = sequential_gaussian_tensordot(result.expand(result.batch_shape))

# Combine initial factor.
result = gaussian_tensordot(self._init, result, dims=self.hidden_dim)

# Marginalize out final state.
result = result.event_logsumexp()
return result

def rsample(self, sample_shape=torch.Size()):
assert self.duration is not None
sample_shape = torch.Size(sample_shape)
trans = self._trans + self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim)
trans = trans.expand(trans.batch_shape[:-1] + (self.duration,))
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = z[..., 1:, :]  # drop the initial hidden state
x = self._obs.left_condition(z).rsample()
return x

def rsample_posterior(self, value, sample_shape=torch.Size()):
"""
EXPERIMENTAL Sample from the latent state conditioned on observation.
"""
trans = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)
trans = trans.expand(trans.batch_shape)
z = sequential_gaussian_filter_sample(self._init, trans, sample_shape)
z = z[..., 1:, :]  # drop the initial hidden state
return z

def filter(self, value):
"""
Compute posterior over final state given a sequence of observations.

:param ~torch.Tensor value: A sequence of observations.
:return: A posterior
distribution over latent states at the final time step. result
can then be used as initial_dist in a sequential Pyro model for
prediction.
:rtype: ~pyro.distributions.MultivariateNormal
"""
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

# Convert to a distribution
precision = logp.precision
loc = cholesky_solve(
logp.info_vec.unsqueeze(-1), safe_cholesky(precision)
).squeeze(-1)
return MultivariateNormal(
loc, precision_matrix=precision, validate_args=self._validate_args
)

def conjugate_update(self, other):
"""
EXPERIMENTAL Creates an updated :class:GaussianHMM fusing information
from another compatible distribution.

This should satisfy::

fg, log_normalizer = f.conjugate_update(g)
assert f.log_prob(x) + g.log_prob(x) == fg.log_prob(x) + log_normalizer

:param other: A distribution representing p(data|self.probs) but
normalized over self.probs rather than data.
:type other: ~torch.distributions.Independent of
~torch.distributions.MultivariateNormal or ~torch.distributions.Normal
:return: a pair (updated,log_normalizer) where updated is an
updated :class:GaussianHMM , and log_normalizer is a
:class:~torch.Tensor representing the normalization factor.
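
Example (a sketch; hmm is assumed to be a GaussianHMM with
event_shape == (num_steps, obs_dim), and dist is pyro.distributions)::

    other = dist.Normal(torch.zeros(num_steps, obs_dim), 1.0).to_event(2)
    fused, log_normalizer = hmm.conjugate_update(other)
    x = fused.rsample()  # sample from the fused distribution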
"""
assert isinstance(other, torch.distributions.Independent) and (
isinstance(other.base_dist, torch.distributions.Normal)
or isinstance(other.base_dist, torch.distributions.MultivariateNormal)
)
duration = other.event_shape[0] if self.duration is None else self.duration
event_shape = torch.Size((duration, self.obs_dim))
assert other.event_shape == event_shape

new = self._get_checked_instance(GaussianHMM)
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
new._init = self._init
new._trans = self._trans
new._obs = self._obs + mvn_to_gaussian(other).event_pad(left=self.hidden_dim)

# Normalize.
# TODO cache this computation for the forward pass of .rsample().
logp = new._trans + new._obs.marginalize(right=new.obs_dim).event_pad(left=new.hidden_dim)
logp = sequential_gaussian_tensordot(logp.expand(logp.batch_shape))
logp = gaussian_tensordot(new._init, logp, dims=new.hidden_dim)
log_normalizer = logp.event_logsumexp()
new._init = new._init - log_normalizer

batch_shape = log_normalizer.shape
super(GaussianHMM, new).__init__(
duration, batch_shape, event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new, log_normalizer

def prefix_condition(self, data):
"""
EXPERIMENTAL Given self has event_shape == (t+f, d) and data x
of shape batch_shape + (t, d), compute a conditional distribution
of event_shape (f, d). Typically t is the number of training
time steps, f is the number of forecast time steps, and d is
the data dimension.

:param data: data of dimension at least 2.
:type data: ~torch.Tensor
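
Example (a sketch; hmm is assumed to have duration t + f, and data has
shape (t, d))::

    predictive = hmm.prefix_condition(data)
    forecast = predictive.rsample()  # shape (f, d)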
"""
assert data.dim() >= 2
assert data.size(-1) == self.event_shape[-1]
assert data.size(-2) < self.duration
t = data.size(-2)
f = self.duration - t

left = self._get_checked_instance(GaussianHMM)
left.hidden_dim = self.hidden_dim
left.obs_dim = self.obs_dim
left._init = self._init

right = self._get_checked_instance(GaussianHMM)
right.hidden_dim = self.hidden_dim
right.obs_dim = self.obs_dim

if self._obs.batch_shape == () or self._obs.batch_shape[-1] == 1:  # homogeneous
left._obs = self._obs
right._obs = self._obs
else:  # heterogeneous
left._obs = self._obs[..., :t]
right._obs = self._obs[..., t:]

if (
self._trans.batch_shape == () or self._trans.batch_shape[-1] == 1
):  # homogeneous
left._trans = self._trans
right._trans = self._trans
else:  # heterogeneous
left._trans = self._trans[..., :t]
right._trans = self._trans[..., t:]

super(GaussianHMM, left).__init__(
t, self.batch_shape, (t, self.obs_dim), validate_args=self._validate_args
)
initial_dist = left.filter(data)
right._init = mvn_to_gaussian(initial_dist)
batch_shape = broadcast_shape(self.batch_shape, data.shape[:-2])
super(GaussianHMM, right).__init__(
f, batch_shape, (f, self.obs_dim), validate_args=self._validate_args
)
return right

class GammaGaussianHMM(HiddenMarkovModel):
"""
Hidden Markov Model in which the joint distribution of initial state, hidden
state, and observed state is a :class:~pyro.distributions.MultivariateStudentT
distribution, along the lines of references [2] and [3]. This adapts [1]
to parallelize over time to achieve O(log(time)) parallel complexity.

This GammaGaussianHMM class corresponds to the generative model::

s = Gamma(df/2, df/2).sample()
z = scale(initial_dist, s).sample()
x = []
for t in range(num_events):
z = z @ transition_matrix + scale(transition_dist, s).sample()
x.append(z @ observation_matrix + scale(observation_dist, s).sample())

where scale(mvn(loc, precision), s) := mvn(loc, s * precision).

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_dist and observation_dist. However,
because time is included in this distribution's event_shape, the
homogeneous+homogeneous case will have a broadcastable event_shape with
num_steps = 1, allowing :meth:log_prob to work with arbitrary length
data::

event_shape = (1, obs_dim)  # homogeneous + homogeneous case

**References:**

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
"Temporal Parallelization of Bayesian Filters and Smoothers"
https://arxiv.org/pdf/1905.13002.pdf

[2] F. J. Giron and J. C. Rojano (1994)
"Bayesian Kalman filtering with elliptically contoured errors"

[3] Filip Tronarp, Toni Karvonen, and Simo Sarkka (2019)
"Student's t-filters for noise scale estimation"
https://users.aalto.fi/~ssarkka/pub/SPL2019.pdf

:ivar int hidden_dim: The dimension of the hidden state.
:ivar int obs_dim: The dimension of the observed state.
:param Gamma scale_dist: Prior of the mixing distribution.
:param MultivariateNormal initial_dist: A distribution with unit scale mixing
over initial states. This should have batch_shape broadcastable to
self.batch_shape.  This should have event_shape (hidden_dim,).
:param ~torch.Tensor transition_matrix: A linear transformation of hidden
state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, hidden_dim) where the
rightmost dims are ordered (old, new).
:param MultivariateNormal transition_dist: A process noise distribution
with unit scale mixing. This should have batch_shape broadcastable to
self.batch_shape + (num_steps,). This should have event_shape
(hidden_dim,).
:param ~torch.Tensor observation_matrix: A linear transformation from hidden
to observed state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, obs_dim).
:param MultivariateNormal observation_dist: An observation noise distribution
with unit scale mixing. This should have batch_shape broadcastable to
self.batch_shape + (num_steps,).
This should have event_shape (obs_dim,).
:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
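
Example (a minimal usage sketch; the shapes below are arbitrary assumptions)::

    import torch
    import pyro.distributions as dist

    hidden_dim, obs_dim, num_steps = 2, 3, 10
    scale_dist = dist.Gamma(torch.tensor(2.0), torch.tensor(2.0))
    init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                        torch.eye(hidden_dim))
    trans_matrix = 0.5 * torch.eye(hidden_dim)
    trans_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                         torch.eye(hidden_dim))
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_dist = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))
    hmm = dist.GammaGaussianHMM(scale_dist, init_dist, trans_matrix, trans_dist,
                                obs_matrix, obs_dist, duration=num_steps)
    data = torch.randn(num_steps, obs_dim)
    logp = hmm.log_prob(data)
    scale_posterior, state_posterior = hmm.filter(data)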
"""

arg_constraints = {}
support = constraints.independent(constraints.real, 2)

def __init__(
self,
scale_dist,
initial_dist,
transition_matrix,
transition_dist,
observation_matrix,
observation_dist,
validate_args=None,
duration=None,
):
assert isinstance(scale_dist, Gamma)
assert isinstance(initial_dist, MultivariateNormal)
assert isinstance(transition_matrix, torch.Tensor)
assert isinstance(transition_dist, MultivariateNormal)
assert isinstance(observation_matrix, torch.Tensor)
assert isinstance(observation_dist, MultivariateNormal)
hidden_dim, obs_dim = observation_matrix.shape[-2:]
assert initial_dist.event_shape == (hidden_dim,)
assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
assert transition_dist.event_shape == (hidden_dim,)
assert observation_dist.event_shape == (obs_dim,)
shape = broadcast_shape(
scale_dist.batch_shape + (1,),
initial_dist.batch_shape + (1,),
transition_matrix.shape[:-2],
transition_dist.batch_shape,
observation_matrix.shape[:-2],
observation_dist.batch_shape,
)
batch_shape, time_shape = shape[:-1], shape[-1:]
event_shape = time_shape + (obs_dim,)
super().__init__(
duration, batch_shape, event_shape, validate_args=validate_args
)
self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
self._init = gamma_and_mvn_to_gamma_gaussian(scale_dist, initial_dist)
self._trans = matrix_and_mvn_to_gamma_gaussian(
transition_matrix, transition_dist
)
self._obs = matrix_and_mvn_to_gamma_gaussian(
observation_matrix, observation_dist
)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(GammaGaussianHMM, _instance)
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
super(GammaGaussianHMM, new).__init__(
self.duration, batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
result = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
result = _sequential_gamma_gaussian_tensordot(result.expand(result.batch_shape))

# Combine initial factor.
result = gamma_gaussian_tensordot(self._init, result, dims=self.hidden_dim)

# Marginalize out final state.
result = result.event_logsumexp()

# Marginalize out multiplier.
result = result.logsumexp()
return result

def filter(self, value):
"""
Compute posteriors over the multiplier and the final state
given a sequence of observations. The posterior is a pair of
Gamma and MultivariateNormal distributions (i.e. a GammaGaussian
instance).

:param ~torch.Tensor value: A sequence of observations.
:return: A pair of posterior distributions over the mixing and the latent
state at the final time step.
:rtype: a tuple of ~pyro.distributions.Gamma and ~pyro.distributions.MultivariateNormal
"""
if self._validate_args:
self._validate_sample(value)

# Combine observation and transition factors.
logp = self._trans + self._obs.condition(value).event_pad(left=self.hidden_dim)

# Eliminate time dimension.
logp = _sequential_gamma_gaussian_tensordot(logp.expand(logp.batch_shape))

# Combine initial factor.
logp = gamma_gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

# Posterior of the scale
gamma_dist = logp.event_logsumexp()
scale_post = Gamma(
gamma_dist.concentration, gamma_dist.rate, validate_args=self._validate_args
)
# Conditional of last state on unit scale
scale_tril = safe_cholesky(logp.precision)
loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1)
mvn = MultivariateNormal(
loc, scale_tril=scale_tril, validate_args=self._validate_args
)
return scale_post, mvn

class LinearHMM(HiddenMarkovModel):
r"""
Hidden Markov Model with linear dynamics and observations and arbitrary
noise for initial, transition, and observation distributions.  Each of
those distributions can be e.g.
:class:~pyro.distributions.MultivariateNormal or
:class:~pyro.distributions.Independent of
:class:~pyro.distributions.Normal,
:class:~pyro.distributions.StudentT, or :class:~pyro.distributions.Stable .
Additionally the observation distribution may be constrained, e.g.
:class:~pyro.distributions.LogNormal .

This corresponds to the generative model::

z = initial_distribution.sample()
x = []
for t in range(num_events):
z = z @ transition_matrix + transition_dist.sample()
y = z @ observation_matrix + obs_base_dist.sample()
x.append(obs_transform(y))

where observation_dist is split into obs_base_dist and an optional
obs_transform (defaulting to the identity).

This implements a reparameterized :meth:rsample method but does not
implement a :meth:log_prob method. Derived classes may implement
:meth:log_prob .

Inference without :meth:log_prob can be performed using either
reparameterization with :class:~pyro.infer.reparam.hmm.LinearHMMReparam
or likelihood-free algorithms such as
:class:~pyro.infer.energy_distance.EnergyDistance .  Note that while
stable processes generally require a common shared stability parameter
:math:\alpha , this distribution and the above inference algorithms allow
heterogeneous stability parameters.

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_dist and observation_dist. However at
least one of the distributions or matrices must be expanded to contain the
time dimension.

:ivar int hidden_dim: The dimension of the hidden state.
:ivar int obs_dim: The dimension of the observed state.
:param initial_dist: A distribution over initial states. This should have
batch_shape broadcastable to self.batch_shape.  This should have
event_shape (hidden_dim,).
:param ~torch.Tensor transition_matrix: A linear transformation of hidden
state. This should have shape broadcastable to self.batch_shape +
(num_steps, hidden_dim, hidden_dim) where the rightmost dims are
ordered (old, new).
:param transition_dist: A distribution over process noise. This should have
batch_shape broadcastable to self.batch_shape + (num_steps,).  This
should have event_shape (hidden_dim,).
:param ~torch.Tensor observation_matrix: A linear transformation from hidden
to observed state. This should have shape broadcastable to
self.batch_shape + (num_steps, hidden_dim, obs_dim).
:param observation_dist: An observation noise distribution. This should have
batch_shape broadcastable to self.batch_shape + (num_steps,).  This
should have event_shape (obs_dim,).
:param int duration: Optional size of the time axis event_shape[0].
This is required when sampling from homogeneous HMMs whose parameters
are not expanded along the time axis.
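
Example (a minimal usage sketch with Student's t noise; the shapes below are
arbitrary assumptions)::

    import torch
    import pyro.distributions as dist

    hidden_dim, obs_dim, num_steps = 2, 3, 10
    init_dist = dist.StudentT(3.0, torch.zeros(hidden_dim), 1.0).to_event(1)
    trans_matrix = 0.5 * torch.eye(hidden_dim)
    trans_dist = dist.StudentT(3.0, torch.zeros(hidden_dim), 1.0).to_event(1)
    obs_matrix = torch.randn(hidden_dim, obs_dim)
    obs_dist = dist.StudentT(3.0, torch.zeros(obs_dim), 1.0).to_event(1)
    hmm = dist.LinearHMM(init_dist, trans_matrix, trans_dist,
                         obs_matrix, obs_dist, duration=num_steps)
    x = hmm.rsample()  # shape (num_steps, obs_dim); log_prob is not implemented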
"""
arg_constraints = {}
support = constraints.independent(constraints.real, 2)
has_rsample = True

def __init__(
self,
initial_dist,
transition_matrix,
transition_dist,
observation_matrix,
observation_dist,
validate_args=None,
duration=None,
):
assert initial_dist.has_rsample
assert initial_dist.event_dim == 1
assert (
isinstance(transition_matrix, torch.Tensor) and transition_matrix.dim() >= 2
)
assert transition_dist.has_rsample
assert transition_dist.event_dim == 1
assert (
isinstance(observation_matrix, torch.Tensor)
and observation_matrix.dim() >= 2
)
assert observation_dist.has_rsample
assert observation_dist.event_dim == 1

hidden_dim, obs_dim = observation_matrix.shape[-2:]
assert initial_dist.event_shape == (hidden_dim,)
assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
assert transition_dist.event_shape == (hidden_dim,)
assert observation_dist.event_shape == (obs_dim,)
shape = broadcast_shape(
initial_dist.batch_shape + (1,),
transition_matrix.shape[:-2],
transition_dist.batch_shape,
observation_matrix.shape[:-2],
observation_dist.batch_shape,
)
batch_shape, time_shape = shape[:-1], shape[-1:]
event_shape = time_shape + (obs_dim,)
super().__init__(
duration, batch_shape, event_shape, validate_args=validate_args
)

# Expand eagerly.
if initial_dist.batch_shape != batch_shape:
initial_dist = initial_dist.expand(batch_shape)
if transition_matrix.shape[:-2] != batch_shape + time_shape:
transition_matrix = transition_matrix.expand(
batch_shape + time_shape + (hidden_dim, hidden_dim)
)
if transition_dist.batch_shape != batch_shape + time_shape:
transition_dist = transition_dist.expand(batch_shape + time_shape)
if observation_matrix.shape[:-2] != batch_shape + time_shape:
observation_matrix = observation_matrix.expand(
batch_shape + time_shape + (hidden_dim, obs_dim)
)
if observation_dist.batch_shape != batch_shape + time_shape:
observation_dist = observation_dist.expand(batch_shape + time_shape)

# Extract observation transforms.
transforms = []
while True:
if isinstance(observation_dist, torch.distributions.Independent):
observation_dist = observation_dist.base_dist
elif isinstance(
observation_dist, torch.distributions.TransformedDistribution
):
transforms = observation_dist.transforms + transforms
observation_dist = observation_dist.base_dist
else:
break
if not observation_dist.event_shape:
observation_dist = Independent(observation_dist, 1)

self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
self.initial_dist = initial_dist
self.transition_matrix = transition_matrix
self.transition_dist = transition_dist
self.observation_matrix = observation_matrix
self.observation_dist = observation_dist
self.transforms = transforms

@constraints.dependent_property(event_dim=2)
def support(self):
return constraints.independent(self.observation_dist.support, 1)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LinearHMM, _instance)
batch_shape = torch.Size(batch_shape)
time_shape = self.transition_dist.batch_shape[-1:]
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
new.initial_dist = self.initial_dist.expand(batch_shape)
new.transition_matrix = self.transition_matrix.expand(
batch_shape + time_shape + (self.hidden_dim, self.hidden_dim)
)
new.transition_dist = self.transition_dist.expand(batch_shape + time_shape)
new.observation_matrix = self.observation_matrix.expand(
batch_shape + time_shape + (self.hidden_dim, self.obs_dim)
)
new.observation_dist = self.observation_dist.expand(batch_shape + time_shape)
new.transforms = self.transforms
super(LinearHMM, new).__init__(
self.duration, batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def log_prob(self, value):
raise NotImplementedError("LinearHMM.log_prob() is not implemented")

def rsample(self, sample_shape=torch.Size()):
assert self.duration is not None
init = self.initial_dist.rsample(sample_shape)
trans = self.transition_dist.expand(
self.batch_shape + (self.duration,)
).rsample(sample_shape)
obs = self.observation_dist.expand(self.batch_shape + (self.duration,)).rsample(
sample_shape
)
trans_matrix = self.transition_matrix.expand(
self.batch_shape + (self.duration, -1, -1)
)
z = _linear_integrate(init, trans_matrix, trans)
x = (z.unsqueeze(-2) @ self.observation_matrix).squeeze(-2) + obs
for t in self.transforms:
x = t(x)
return x

class IndependentHMM(TorchDistribution):
"""
Wrapper class to treat a batch of independent univariate HMMs as a single
multivariate distribution. This converts distribution shapes as follows:

+-----------+--------------------+---------------------+
|           |       .batch_shape | .event_shape        |
+===========+====================+=====================+
| base_dist | shape + (obs_dim,) | (duration, 1)       |
+-----------+--------------------+---------------------+
|    result |              shape | (duration, obs_dim) |
+-----------+--------------------+---------------------+

:param HiddenMarkovModel base_dist: A base hidden Markov model instance.
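
Example (a sketch wrapping a batched univariate GaussianHMM; the shapes below
are arbitrary assumptions)::

    import torch
    import pyro.distributions as dist

    obs_dim, hidden_dim, duration = 4, 2, 10
    base = dist.GaussianHMM(
        dist.MultivariateNormal(torch.zeros(obs_dim, hidden_dim),
                                torch.eye(hidden_dim)),
        0.5 * torch.eye(hidden_dim),
        dist.MultivariateNormal(torch.zeros(hidden_dim), torch.eye(hidden_dim)),
        torch.randn(hidden_dim, 1),
        dist.MultivariateNormal(torch.zeros(1), torch.eye(1)),
        duration=duration,
    )
    # base.batch_shape == (obs_dim,) and base.event_shape == (duration, 1)
    hmm = dist.IndependentHMM(base)
    assert hmm.event_shape == (duration, obs_dim)
    x = hmm.rsample()  # shape (duration, obs_dim)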
"""

arg_constraints = {}

def __init__(self, base_dist):
assert base_dist.batch_shape
assert base_dist.event_dim == 2
assert base_dist.event_shape[-1] == 1
batch_shape = base_dist.batch_shape[:-1]
event_shape = base_dist.event_shape[:-1] + base_dist.batch_shape[-1:]
super().__init__(batch_shape, event_shape)
self.base_dist = base_dist

@constraints.dependent_property(event_dim=2)
def support(self):
return self.base_dist.support

@property
def has_rsample(self):
return self.base_dist.has_rsample

@property
def duration(self):
return self.base_dist.duration

def expand(self, batch_shape, _instance=None):
batch_shape = torch.Size(batch_shape)
new = self._get_checked_instance(IndependentHMM, _instance)
new.base_dist = self.base_dist.expand(
batch_shape + self.base_dist.batch_shape[-1:]
)
super(IndependentHMM, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def rsample(self, sample_shape=torch.Size()):
base_value = self.base_dist.rsample(sample_shape)
return base_value.squeeze(-1).transpose(-1, -2)

def log_prob(self, value):
base_value = value.transpose(-1, -2).unsqueeze(-1)
return self.base_dist.log_prob(base_value).sum(-1)

class GaussianMRF(TorchDistribution):
"""
Temporal Markov Random Field with Gaussian factors for initial, transition,
and observation distributions. This adapts [1] to parallelize over time to
achieve O(log(time)) parallel complexity, however it differs in that it
tracks the log normalizer to ensure :meth:log_prob is differentiable.

The event_shape of this distribution includes time on the left::

event_shape = (num_steps,) + observation_dist.event_shape

This distribution supports any combination of homogeneous/heterogeneous
time dependency of transition_dist and observation_dist. However,
because time is included in this distribution's event_shape, the
homogeneous+homogeneous case will have a broadcastable event_shape with
num_steps = 1, allowing :meth:log_prob to work with arbitrary length
data::

event_shape = (1, obs_dim)  # homogeneous + homogeneous case

**References:**

[1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
"Temporal Parallelization of Bayesian Filters and Smoothers"
https://arxiv.org/pdf/1905.13002.pdf

:ivar int hidden_dim: The dimension of the hidden state.
:ivar int obs_dim: The dimension of the observed state.
:param ~torch.distributions.MultivariateNormal initial_dist: A distribution
over initial states. This should have batch_shape broadcastable to
self.batch_shape.  This should have event_shape (hidden_dim,).
:param ~torch.distributions.MultivariateNormal transition_dist: A joint
distribution factor over a pair of successive time steps. This should
have batch_shape broadcastable to self.batch_shape + (num_steps,).
This should have event_shape (hidden_dim + hidden_dim,) (old+new).
:param ~torch.distributions.MultivariateNormal observation_dist: A joint
distribution factor over a hidden and an observed state. This should
have batch_shape broadcastable to self.batch_shape + (num_steps,).
This should have event_shape (hidden_dim + obs_dim,).
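
Example (a minimal usage sketch; the shapes below are arbitrary assumptions)::

    import torch
    import pyro.distributions as dist

    hidden_dim, obs_dim, num_steps = 2, 3, 10
    init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                        torch.eye(hidden_dim))
    trans_dist = dist.MultivariateNormal(torch.zeros(2 * hidden_dim),
                                         torch.eye(2 * hidden_dim))
    obs_dist = dist.MultivariateNormal(torch.zeros(hidden_dim + obs_dim),
                                       torch.eye(hidden_dim + obs_dim))
    mrf = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = torch.randn(num_steps, obs_dim)
    logp = mrf.log_prob(data)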
"""

arg_constraints = {}

def __init__(
self, initial_dist, transition_dist, observation_dist, validate_args=None
):
assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
hidden_dim = initial_dist.event_shape[0]
assert transition_dist.event_shape[0] == hidden_dim + hidden_dim
obs_dim = observation_dist.event_shape[0] - hidden_dim

shape = broadcast_shape(
initial_dist.batch_shape + (1,),
transition_dist.batch_shape,
observation_dist.batch_shape,
)
batch_shape, time_shape = shape[:-1], shape[-1:]
event_shape = time_shape + (obs_dim,)
super().__init__(batch_shape, event_shape, validate_args=validate_args)
self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
self._init = mvn_to_gaussian(initial_dist)
self._trans = mvn_to_gaussian(transition_dist)
self._obs = mvn_to_gaussian(observation_dist)
self._support = constraints.independent(observation_dist.support, 1)

@constraints.dependent_property(event_dim=2)
def support(self):
return self._support

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(GaussianMRF, _instance)
new.hidden_dim = self.hidden_dim
new.obs_dim = self.obs_dim
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in sequential_gaussian_tensordot(),
# we expand only _init, which is applied only after sequential_gaussian_tensordot().
new._init = self._init.expand(batch_shape)
new._trans = self._trans
new._obs = self._obs
new._support = self._support
super(GaussianMRF, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self.__dict__.get("_validate_args")
return new

def log_prob(self, value):
# We compute a normalized distribution as p(obs,hidden) / p(hidden).
logp_oh = self._trans
logp_h = self._trans

# Combine observation and transition factors.
logp_oh = logp_oh + self._obs.condition(value).event_pad(left=self.hidden_dim)
logp_h = logp_h + self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim)

# Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
batch_dim = 1 + max(len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
batch_shape = (1,) * (
batch_dim - len(logp_oh.batch_shape)
) + logp_oh.batch_shape
logp = Gaussian.cat([logp_oh.expand(batch_shape), logp_h.expand(batch_shape)])

# Eliminate time dimension.
logp = sequential_gaussian_tensordot(logp)

# Combine initial factor.
logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

# Marginalize out final state.
logp_oh, logp_h = logp.event_logsumexp()
return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )