# Source code for pyro.infer.reparam.hmm

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions as dist

from .reparam import Reparam

class LinearHMMReparam(Reparam):
    """
    Auxiliary variable reparameterizer for
    :class:`~pyro.distributions.LinearHMM` random variables.

    This defers to component reparameterizers to create auxiliary random
    variables conditioned on which the process becomes a
    :class:`~pyro.distributions.GaussianHMM` . If the ``observation_dist`` is a
    :class:`~pyro.distributions.TransformedDistribution` this reorders those
    transforms so that the result is a
    :class:`~pyro.distributions.TransformedDistribution` of
    :class:`~pyro.distributions.GaussianHMM` .

    This is useful for training the parameters of a
    :class:`~pyro.distributions.LinearHMM` distribution, whose
    :meth:`~pyro.distributions.LinearHMM.log_prob` method is undefined.  To
    perform inference in the presence of non-Gaussian factors such as
    :class:`~pyro.distributions.Stable`,
    :class:`~pyro.distributions.StudentT` or
    :class:`~pyro.distributions.LogNormal` , configure with
    :class:`~pyro.infer.reparam.studentt.StudentTReparam` ,
    :class:`~pyro.infer.reparam.stable.StableReparam` ,
    :class:`~pyro.infer.reparam.stable.SymmetricStableReparam` , etc. component
    reparameterizers for ``init``, ``trans``, and ``obs``. For example::

        hmm = LinearHMM(
            init_dist=Stable(1,0,1,0).expand([2]).to_event(1),
            trans_matrix=torch.eye(2),
            trans_dist=MultivariateNormal(torch.zeros(2), torch.eye(2)),
            obs_matrix=torch.eye(2),
            obs_dist=TransformedDistribution(
                Stable(1.5,-0.5,1.0).expand([2]).to_event(1),
                ExpTransform()))

        rep = LinearHMMReparam(init=SymmetricStableReparam(),
                               obs=StableReparam())
        with poutine.reparam(config={"hmm": rep}):
            pyro.sample("hmm", hmm, obs=data)

    :param init: Optional reparameterizer for the initial distribution.
    :type init: ~pyro.infer.reparam.reparam.Reparam
    :param trans: Optional reparameterizer for the transition distribution.
    :type trans: ~pyro.infer.reparam.reparam.Reparam
    :param obs: Optional reparameterizer for the observation distribution.
    :type obs: ~pyro.infer.reparam.reparam.Reparam
    """

    def __init__(self, init=None, trans=None, obs=None):
        # Each component reparameterizer is optional; ``None`` leaves the
        # corresponding component distribution untouched in ``apply``.
        assert init is None or isinstance(init, Reparam)
        assert trans is None or isinstance(trans, Reparam)
        assert obs is None or isinstance(obs, Reparam)
        self.init = init
        self.trans = trans
        self.obs = obs

    def apply(self, msg):
        """
        Reparameterize a ``LinearHMM`` (or ``IndependentHMM``) sample site.

        :param dict msg: A message with keys ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` describing the site.
        :returns: A dict with keys ``"fn"``, ``"value"`` and
            ``"is_observed"``, where ``"fn"`` is a
            :class:`~pyro.distributions.GaussianHMM` (possibly re-wrapped,
            transformed, or nested in an
            :class:`~pyro.distributions.IndependentHMM`).
        :rtype: dict
        :raises ValueError: If the targeted distribution has
            ``duration is None``.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        # ``_unwrap`` / ``_wrap`` are not defined in this class; they are
        # expected to come from the ``Reparam`` base class.
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, (dist.LinearHMM, dist.IndependentHMM))
        if fn.duration is None:
            raise ValueError(
                "LinearHMMReparam requires duration to be specified "
                "on targeted LinearHMM distributions"
            )

        # Unwrap IndependentHMM.
        if isinstance(fn, dist.IndependentHMM):
            indep_value = None
            if value is not None:
                # Swap the last two dims and append a trailing singleton dim
                # before recursing; the inverse is applied on the way out.
                indep_value = value.transpose(-1, -2).unsqueeze(-1)
            msg = self.apply(
                {
                    "name": name,
                    "fn": fn.base_dist.to_event(1),
                    "value": indep_value,
                    "is_observed": is_observed,
                }
            )
            hmm = msg["fn"]
            hmm = dist.IndependentHMM(hmm.to_event(-1))
            # Only undo the layout change if the recursive call actually
            # produced a new value object.
            if msg["value"] is not indep_value:
                value = msg["value"].squeeze(-1).transpose(-1, -2)
            # Propagate the observation flag reported by the recursive call,
            # mirroring how the non-independent path below forwards the flag
            # returned by the observation reparameterizer.
            return {"fn": hmm, "value": value, "is_observed": msg["is_observed"]}

        # Reparameterize the initial distribution as conditionally Gaussian.
        init_dist = fn.initial_dist
        if self.init is not None:
            msg = self.init.apply(
                {
                    "name": f"{name}_init",
                    "fn": self._wrap(init_dist, event_dim - 1),
                    "value": None,
                    "is_observed": False,
                }
            )
            init_dist = msg["fn"]
            # Adjust the result so it has exactly one event dimension.
            init_dist = init_dist.to_event(1 - init_dist.event_dim)

        # Reparameterize the transition distribution as conditionally Gaussian.
        trans_dist = fn.transition_dist
        if self.trans is not None:
            # Expand the last batch dim to the full duration so the component
            # reparameterizer sees one slot per time step.
            if trans_dist.batch_shape[-1] != fn.duration:
                trans_dist = trans_dist.expand(
                    trans_dist.batch_shape[:-1] + (fn.duration,)
                )
            msg = self.trans.apply(
                {
                    "name": f"{name}_trans",
                    "fn": self._wrap(trans_dist, event_dim),
                    "value": None,
                    "is_observed": False,
                }
            )
            trans_dist = msg["fn"]
            trans_dist = trans_dist.to_event(1 - trans_dist.event_dim)

        # Reparameterize the observation distribution as conditionally Gaussian.
        obs_dist = fn.observation_dist
        if self.obs is not None:
            if obs_dist.batch_shape[-1] != fn.duration:
                obs_dist = obs_dist.expand(obs_dist.batch_shape[:-1] + (fn.duration,))
            # Unlike init/trans, the observation reparameterizer receives the
            # site's value and observation flag, and may replace both.
            msg = self.obs.apply(
                {
                    "name": f"{name}_obs",
                    "fn": self._wrap(obs_dist, event_dim),
                    "value": value,
                    "is_observed": is_observed,
                }
            )
            obs_dist = msg["fn"]
            obs_dist = obs_dist.to_event(1 - obs_dist.event_dim)
            value = msg["value"]
            is_observed = msg["is_observed"]

        # Reparameterize the entire HMM as conditionally Gaussian.
        hmm = dist.GaussianHMM(
            init_dist,
            fn.transition_matrix,
            trans_dist,
            fn.observation_matrix,
            obs_dist,
            duration=fn.duration,
        )
        hmm = self._wrap(hmm, event_dim)

        # Apply any observation transforms.
        if fn.transforms:
            hmm = dist.TransformedDistribution(hmm, fn.transforms)

        return {"fn": hmm, "value": value, "is_observed": is_observed}