# Source code for pyro.poutine.substitute_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import TYPE_CHECKING, Dict, Set

from typing_extensions import Self

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled

if TYPE_CHECKING:
    import torch

    from pyro.poutine.runtime import Message


class SubstituteMessenger(Messenger):
    """
    Given a stochastic function with param calls and a set of parameter values,
    create a stochastic function where all param calls are substituted with
    the fixed values.
    data should be a dict of names to values.
    Consider the following Pyro program:

        >>> def model(x):
        ...     a = pyro.param("a", torch.tensor(0.5))
        ...     x = pyro.sample("x", dist.Bernoulli(probs=a))
        ...     return x
        >>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

    In this example, site `a` will now have value `torch.tensor(0.3)`.

    :param data: dictionary of values keyed by site names.
    :returns: ``fn`` decorated with a
        :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger`
    """

    def __init__(self, data: Dict[str, "torch.Tensor"]) -> None:
        """
        :param data: values for the parameters, keyed by the user-facing
            param name (the name as returned by ``params.user_param_name``).
        """
        super().__init__()
        self.data = data
        # Messages of substituted param sites seen in the current context,
        # keyed by the site name carried in the message.
        self._data_cache: Dict[str, "Message"] = {}
        # Bookkeeping for the validation warning emitted in __exit__.
        # Always initialized (and reset on each __enter__) so that
        # _pyro_param / __exit__ never touch a missing attribute, e.g. when
        # ``data`` is a non-dict mapping or validation is toggled mid-context.
        self._param_hits: Set[str] = set()
        self._param_misses: Set[str] = set()

    def __enter__(self) -> Self:
        # Start every context with a clean cache and clean hit/miss records.
        self._data_cache = {}
        self._param_hits = set()
        self._param_misses = set()
        return super().__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        self._data_cache = {}
        if is_validation_enabled() and isinstance(self.data, dict):
            # Keys supplied in ``data`` that no pyro.param call ever matched.
            extra = set(self.data) - self._param_hits
            if extra:
                # Names are sorted so the message is deterministic.
                # NOTE(review): the text mentions "pyro.module" although this
                # is the substitute handler; kept verbatim in case callers
                # match on it.
                warnings.warn(
                    "pyro.module data did not find params ['{}']. "
                    "Did you instead mean one of ['{}']?".format(
                        "', '".join(sorted(extra)),
                        "', '".join(sorted(self._param_misses)),
                    )
                )
        return super().__exit__(*args, **kwargs)

    def _pyro_sample(self, msg: "Message") -> None:
        # Sample sites are left untouched by this handler.
        return None

    def _pyro_param(self, msg: "Message") -> None:
        """
        Overrides the `pyro.param` with substituted values.
        If the param name does not match any of the keys in `data`,
        that param value is unchanged.
        """
        assert msg["name"] is not None
        name = msg["name"]
        # ``data`` is keyed by the user-facing name, so translate the site name.
        param_name = params.user_param_name(name)

        if param_name not in self.data.keys():
            # Not substituted; remember it so the exit-time warning can
            # suggest candidate names.
            self._param_misses.add(param_name)
            return None

        msg["value"] = self.data[param_name]
        self._param_hits.add(param_name)

        if name in self._data_cache:
            # Multiple pyro.param statements with the same name: reuse the
            # value recorded for the first occurrence so they all agree.
            msg["value"] = self._data_cache[name]["value"]
        else:
            self._data_cache[name] = msg
        return None