# Source code for pyro.infer.reparam.hmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions as dist

from .reparam import Reparam


class LinearHMMReparam(Reparam):
    """
    Auxiliary variable reparameterizer for
    :class:`~pyro.distributions.LinearHMM` random variables.

    This defers to component reparameterizers to create auxiliary random
    variables conditioned on which the process becomes a
    :class:`~pyro.distributions.GaussianHMM` . If the ``observation_dist`` is a
    :class:`~pyro.distributions.TransformedDistribution` this reorders those
    transforms so that the result is a
    :class:`~pyro.distributions.TransformedDistribution` of
    :class:`~pyro.distributions.GaussianHMM` .

    This is useful for training the parameters of a
    :class:`~pyro.distributions.LinearHMM` distribution, whose
    :meth:`~pyro.distributions.LinearHMM.log_prob` method is undefined.  To
    perform inference in the presence of non-Gaussian factors such as
    :class:`~pyro.distributions.Stable`, :class:`~pyro.distributions.StudentT`
    or :class:`~pyro.distributions.LogNormal` , configure with
    :class:`~pyro.infer.reparam.studentt.StudentTReparam` ,
    :class:`~pyro.infer.reparam.stable.StableReparam` ,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , etc. component
    reparameterizers for ``init``, ``trans``, and ``obs``. For example::

        hmm = LinearHMM(
            init_dist=Stable(1,0,1,0).expand([2]).to_event(1),
            trans_matrix=torch.eye(2),
            trans_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
            obs_matrix=torch.eye(2),
            obs_dist=TransformedDistribution(
                Stable(1.5,-0.5,1.0).expand([2]).to_event(1),
                ExpTransform()))

        rep = LinearHMMReparam(init=SymmetricStableReparam(),
                               obs=StableReparam())

        with poutine.reparam(config={"hmm": rep}):
            pyro.sample("hmm", hmm, obs=data)

    :param init: Optional reparameterizer for the initial distribution.
    :type init: ~pyro.infer.reparam.reparam.Reparam
    :param trans: Optional reparameterizer for the transition distribution.
    :type trans: ~pyro.infer.reparam.reparam.Reparam
    :param obs: Optional reparameterizer for the observation distribution.
    :type obs: ~pyro.infer.reparam.reparam.Reparam
    """

    def __init__(self, init=None, trans=None, obs=None):
        # Every component reparameterizer is optional; a ``None`` component
        # leaves the corresponding distribution of the LinearHMM untouched.
        assert init is None or isinstance(init, Reparam)
        assert trans is None or isinstance(trans, Reparam)
        assert obs is None or isinstance(obs, Reparam)
        self.init = init
        self.trans = trans
        self.obs = obs

    @staticmethod
    def _expand_to_duration(d, duration):
        """
        Expand distribution ``d`` along its rightmost batch dim to length
        ``duration``; return ``d`` unchanged if it already has that length.
        """
        if d.batch_shape[-1] != duration:
            d = d.expand(d.batch_shape[:-1] + (duration,))
        return d

    def __call__(self, name, fn, obs):
        """
        Reparameterize a :class:`~pyro.distributions.LinearHMM` or
        :class:`~pyro.distributions.IndependentHMM` sample site.

        :param str name: Name of the sample site. Component reparameterizers
            are invoked with the derived names ``"{name}_init"``,
            ``"{name}_trans"`` and ``"{name}_obs"``.
        :param fn: The site's distribution. It is first passed through
            ``self._unwrap``; the unwrapped distribution must be a
            :class:`~pyro.distributions.LinearHMM` or
            :class:`~pyro.distributions.IndependentHMM`.
        :param obs: Observed value of the site, or ``None``.
        :returns: A pair ``(new_fn, obs)``. ``new_fn`` is a
            :class:`~pyro.distributions.GaussianHMM` re-wrapped via
            ``self._wrap`` (further wrapped in a
            :class:`~pyro.distributions.TransformedDistribution` when ``fn``
            carries observation transforms, or in an
            :class:`~pyro.distributions.IndependentHMM` when ``fn`` was one).
            ``obs`` is the observation as returned by the ``obs`` component
            reparameterizer (unchanged if there is none).
        :rtype: tuple
        :raises ValueError: If ``fn.duration`` is ``None``.
        """
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        if fn.duration is None:
            raise ValueError("LinearHMMReparam requires duration to be specified "
                             "on targeted LinearHMM distributions")

        # Unwrap IndependentHMM: recurse on its base distribution, then
        # re-wrap the result.
        if isinstance(fn, dist.IndependentHMM):
            if obs is not None:
                # Swap the last two axes and append a singleton axis so obs
                # lines up with ``fn.base_dist.to_event(1)`` -- TODO confirm
                # the exact axis layout expected by IndependentHMM.
                obs = obs.transpose(-1, -2).unsqueeze(-1)
            hmm, obs = self(name, fn.base_dist.to_event(1), obs)
            hmm = dist.IndependentHMM(hmm.to_event(-1))
            if obs is not None:
                # Exact inverse of the reshape applied before recursing.
                obs = obs.squeeze(-1).transpose(-1, -2)
            return hmm, obs

        # Reparameterize the initial distribution as conditionally Gaussian.
        # It has no time axis, hence one fewer event dim when re-wrapping.
        init_dist = fn.initial_dist
        if self.init is not None:
            init_dist, _ = self.init("{}_init".format(name),
                                     self._wrap(init_dist, event_dim - 1), None)
            # Reinterpret dims so the result has exactly one event dim.
            init_dist = init_dist.to_event(1 - init_dist.event_dim)

        # Reparameterize the transition distribution as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            # Give the component reparameterizer an explicit factor for each
            # of the ``fn.duration`` steps, even if the original distribution
            # was shared across steps via broadcasting.
            trans_dist = self._expand_to_duration(trans_dist, fn.duration)
            trans_dist, _ = self.trans("{}_trans".format(name),
                                       self._wrap(trans_dist, event_dim), None)
            trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

        # Reparameterize the observation distribution as conditionally Gaussian.
        # Only this component sees (and may replace) the observed value.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            obs_dist = self._expand_to_duration(obs_dist, fn.duration)
            obs_dist, obs = self.obs("{}_obs".format(name),
                                     self._wrap(obs_dist, event_dim), obs)
            obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)

        # Reparameterize the entire HMM as conditionally Gaussian.
        hmm = dist.GaussianHMM(init_dist, fn.transition_matrix, trans_dist,
                               fn.observation_matrix, obs_dist,
                               duration=fn.duration)
        hmm = self._wrap(hmm, event_dim)

        # Apply any observation transforms outside the GaussianHMM, so the
        # result is a TransformedDistribution of a GaussianHMM.
        if fn.transforms:
            hmm = dist.TransformedDistribution(hmm, fn.transforms)

        return hmm, obs