Source code for pyro.poutine.equalize_messenger

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import re
from typing import List, Optional, Union

from typing_extensions import Self

from pyro.distributions import Delta
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class EqualizeMessenger(Messenger):
    """
    Given a stochastic function with some primitive statements and a list of names,
    force the primitive statements at those names to have the same value,
    with that value being the result of the first primitive statement matching
    those names.

    Consider the following Pyro program:

        >>> def per_category_model(category):
        ...     shift = pyro.param(f'{category}_shift', torch.randn(1))
        ...     mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
        ...     std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
        ...     return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))

    Running the program for multiple categories can be done by

        >>> def model(categories):
        ...     return {category:per_category_model(category) for category in categories}

    To make the `std` sample sites have the same value, we can write

        >>> equal_std_model = pyro.poutine.equalize(model, '.+_std')

    If on top of the above we would like to make the 'shift' parameters identical,
    we can write

        >>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

    Alternatively, the ``equalize`` messenger can be used to condition a model on
    primitive statements having the same value by setting `keep_dist` to True.
    Consider the below model:

        >>> def model():
        ...     x = pyro.sample('x', pyro.distributions.Normal(0, 1))
        ...     y = pyro.sample('y', pyro.distributions.Normal(5, 3))
        ...     return x, y

    The model can be conditioned on 'x' and 'y' having the same value by

        >>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)

    Note that the conditioned model defined above calculates the correct unnormalized
    log-probability density, but in order to correctly sample from it one must use
    SVI or MCMC techniques.

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param sites: a string or list of strings to match site names
        (the strings can be regular expressions)
    :param type: a string specifying the site type (default is 'sample')
    :param bool keep_dist: Whether to keep the distributions of the second and
        subsequent matching primitive statements. If kept this is equivalent to
        conditioning the model on all matching primitive statements having the
        same value, as opposed to having the second and subsequent matching
        primitive statements replaced by delta sampling functions.
        Defaults to False.
    :returns: stochastic function decorated with a
        :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
    """

    def __init__(
        self,
        sites: Union[str, List[str]],
        type: Optional[str] = "sample",
        keep_dist: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.sites = [sites] if isinstance(sites, str) else sites
        self.type = type
        self.keep_dist = keep_dist
        # Compile every pattern once up front: `_is_matching` runs for each
        # message this messenger sees, so compiling there would repeat the
        # same work per site and per message.  An invalid regular expression
        # is also reported here, at construction, rather than mid-run.
        self._patterns = [re.compile(site) for site in self.sites]
        # Shared value captured from the first matching site; it is reset on
        # every `__enter__`, and initialised here so the attribute always
        # exists on a constructed instance.
        self.value = None

    def __enter__(self) -> Self:
        """Forget any previously captured value, then enter the messenger."""
        self.value = None
        return super().__enter__()

    def _is_matching(self, msg: Message) -> bool:
        """
        Return True when ``msg`` has the configured site type and its name
        fully matches at least one of the configured patterns.
        """
        if msg["type"] != self.type:
            return False
        name = msg["name"]
        return any(
            pattern.fullmatch(name) is not None  # type: ignore[arg-type]
            for pattern in self._patterns
        )

    def _postprocess_message(self, msg: Message) -> None:
        """
        Capture the value of the first matching site so that later matching
        sites can be forced to it.
        """
        if self.value is None and self._is_matching(msg):
            value = msg["value"]
            assert value is not None
            self.value = value

    def _process_message(self, msg: Message) -> None:
        """
        Once a value has been captured, overwrite the value of every further
        matching site with it.

        Matching sample sites are additionally marked as observed; unless
        ``keep_dist`` is set, their distribution is replaced by a fully masked
        :class:`Delta` at the captured value and the site is flagged as
        deterministic.
        """
        if self.value is not None and self._is_matching(msg):  # type: ignore[unreachable]
            msg["value"] = self.value  # type: ignore[unreachable]
            if msg["type"] == "sample":
                msg["is_observed"] = True
                if not self.keep_dist:
                    msg["infer"] = {"_deterministic": True}
                    # Keep the event shape of the distribution being replaced.
                    event_dim = msg["fn"].event_dim
                    msg["fn"] = Delta(self.value, event_dim=event_dim).mask(False)