# Source code for pyro.infer.reparam.softmax

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist

from .reparam import Reparam

class GumbelSoftmaxReparam(Reparam):
    """
    Reparametrizer for :class:`~pyro.distributions.RelaxedOneHotCategorical`
    latent variables.

    This is useful for transforming multimodal posteriors to unimodal
    posteriors. Note this increases the latent dimension by 1 per event.

    This reparameterization works only for latent variables, not likelihoods.
    """

    def apply(self, msg):
        """
        Replace a relaxed one-hot categorical site by a uniform noise site
        followed by a deterministic Gumbel-softmax transform.

        :param dict msg: Sample-site message. The keys ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` are read.
        :returns: A dict with keys ``"fn"``, ``"value"`` and
            ``"is_observed"`` describing the replacement site.
        :rtype: dict
        :raises NotImplementedError: If the site is observed.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.RelaxedOneHotCategorical)
        if is_observed:
            raise NotImplementedError(
                "GumbelSoftmaxReparam does not support observe statements"
                f" (at sample site {name!r})"
            )

        # Differentiably invert transform.
        u = None
        if value is not None:
            # The forward transform is value = softmax(logits / temperature),
            # so the logits are recovered (up to an additive constant, which
            # softmax discards) from the log of the simplex-valued ``value``.
            logits = value.log() * fn.temperature
            # Inverse of ``logits = fn.logits - log(-log(u))`` used below.
            u = (fn.logits - logits).exp().neg().exp()

        # Draw parameter-free noise.
        proto = fn.logits
        new_fn = dist.Uniform(torch.zeros_like(proto), torch.ones_like(proto))
        u = pyro.sample(
            f"{name}_uniform",
            self._wrap(new_fn, event_dim),
            obs=u,
            infer={"is_observed": is_observed},
        )

        # Differentiably transform.
        if value is None:
            # -log(-log(u)) is the Gumbel noise added to the logits.
            logits = fn.logits - u.log().neg().log()
            value = (logits / fn.temperature).softmax(dim=-1)

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return {"fn": new_fn, "value": value, "is_observed": True}