Source code for pyro.infer.reparam.hmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions as dist

from .reparam import Reparam


class LinearHMMReparam(Reparam):
    """
    Auxiliary variable reparameterizer for
    :class:`~pyro.distributions.LinearHMM` random variables.

    This defers to component reparameterizers to create auxiliary random
    variables conditioned on which the process becomes a
    :class:`~pyro.distributions.GaussianHMM` . If the ``observation_dist`` is
    a :class:`~pyro.distributions.TransformedDistribution` this reorders those
    transforms so that the result is a
    :class:`~pyro.distributions.TransformedDistribution` of
    :class:`~pyro.distributions.GaussianHMM` .

    This is useful for training the parameters of a
    :class:`~pyro.distributions.LinearHMM` distribution, whose
    :meth:`~pyro.distributions.LinearHMM.log_prob` method is undefined.
    To perform inference in the presence of non-Gaussian factors such as
    :class:`~pyro.distributions.Stable`,
    :class:`~pyro.distributions.StudentT` or
    :class:`~pyro.distributions.LogNormal` , configure with
    :class:`~pyro.infer.reparam.studentt.StudentTReparam` ,
    :class:`~pyro.infer.reparam.stable.StableReparam` ,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , etc.
    component reparameterizers for ``init``, ``trans``, and ``obs``.
    For example::

        hmm = LinearHMM(
            initial_dist=Stable(1,0,1,0).expand([2]).to_event(1),
            transition_matrix=torch.eye(2),
            transition_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
            observation_matrix=torch.eye(2),
            observation_dist=TransformedDistribution(
                Stable(1.5,-0.5,1.0).expand([2]).to_event(1),
                ExpTransform()))

        rep = LinearHMMReparam(init=SymmetricStableReparam(),
                               obs=StableReparam())
        with poutine.reparam(config={"hmm": rep}):
            pyro.sample("hmm", hmm, obs=data)

    :param init: Optional reparameterizer for the initial distribution.
    :type init: ~pyro.infer.reparam.reparam.Reparam
    :param trans: Optional reparameterizer for the transition distribution.
    :type trans: ~pyro.infer.reparam.reparam.Reparam
    :param obs: Optional reparameterizer for the observation distribution.
    :type obs: ~pyro.infer.reparam.reparam.Reparam
    """

    def __init__(self, init=None, trans=None, obs=None):
        # Every component reparameterizer is optional, but when supplied it
        # must be a Reparam instance.
        for component in (init, trans, obs):
            assert component is None or isinstance(component, Reparam)
        self.init, self.trans, self.obs = init, trans, obs

    def __call__(self, name, fn, obs):
        """
        Rewrite the linear HMM ``fn`` as a (possibly transformed)
        :class:`~pyro.distributions.GaussianHMM`.

        :param str name: Name of the site being reparameterized; the suffixes
            ``_init``, ``_trans`` and ``_obs`` are appended to it to name the
            sites handed to the component reparameterizers.
        :param fn: The distribution to reparameterize. Its ``initial_dist``,
            ``transition_matrix``, ``transition_dist``,
            ``observation_matrix``, ``observation_dist`` and ``transforms``
            attributes are read.
        :param obs: Observed value, or None. It is forwarded to the ``obs``
            component reparameterizer when one is configured.
        :returns: A pair ``(hmm, obs)`` of the new distribution and the
            (possibly updated) observed value.
        :rtype: tuple
        """
        # Initial distribution: passed to its reparameterizer unchanged.
        initial = fn.initial_dist
        if self.init is not None:
            init_site = "{}_init".format(name)
            initial, _ = self.init(init_site, initial, None)

        # Transition noise: the component reparameterizer sees it with one
        # extra event dimension (``to_event(1)``), undone afterwards.
        transition = fn.transition_dist
        if self.trans is not None:
            trans_site = "{}_trans".format(name)
            transition, _ = self.trans(trans_site, transition.to_event(1), None)
            transition = transition.to_event(-1)

        # Observation noise: handled like the transition noise, except that
        # the observed value is threaded through and may be replaced.
        observation = fn.observation_dist
        if self.obs is not None:
            obs_site = "{}_obs".format(name)
            observation, obs = self.obs(obs_site, observation.to_event(1), obs)
            observation = observation.to_event(-1)

        # Assemble the conditionally Gaussian HMM from the new components.
        gaussian_hmm = dist.GaussianHMM(
            initial,
            fn.transition_matrix,
            transition,
            fn.observation_matrix,
            observation,
        )

        # Without observation transforms the Gaussian HMM is the result.
        if not fn.transforms:
            return gaussian_hmm, obs

        # Otherwise re-apply the transforms on top of the Gaussian HMM.
        return dist.TransformedDistribution(gaussian_hmm, fn.transforms), obs