Source code for pyro.infer.reparam.studentt

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pyro
import pyro.distributions as dist

from .reparam import Reparam


[docs]class StudentTReparam(Reparam): """ Auxiliary variable reparameterizer for :class:`~pyro.distributions.StudentT` random variables. This is useful in combination with :class:`~pyro.infer.reparam.hmm.LinearHMMReparam` because it allows StudentT processes to be treated as conditionally Gaussian processes, permitting cheap inference via :class:`~pyro.distributions.GaussianHMM` . This reparameterizes a :class:`~pyro.distributions.StudentT` by introducing an auxiliary :class:`~pyro.distributions.Gamma` variable conditioned on which the result is :class:`~pyro.distributions.Normal` . """
[docs] def apply(self, msg): name = msg["name"] fn = msg["fn"] value = msg["value"] is_observed = msg["is_observed"] fn, event_dim = self._unwrap(fn) assert isinstance(fn, dist.StudentT) # Draw a sample that depends only on df. half_df = fn.df * 0.5 gamma = pyro.sample( "{}_gamma".format(name), self._wrap(dist.Gamma(half_df, half_df), event_dim) ) # Construct a scaled Normal. loc = fn.loc scale = fn.scale * gamma.rsqrt() new_fn = self._wrap(dist.Normal(loc, scale), event_dim) return {"fn": new_fn, "value": value, "is_observed": is_observed}