# Source code for pyro.poutine.seed_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from types import TracebackType
from typing import Optional, Type

from pyro.poutine.messenger import Messenger
from pyro.util import get_rng_state, set_rng_seed, set_rng_state


class SeedMessenger(Messenger):
    """
    Handler to set the random number generator to a pre-defined state by setting
    its seed. This is the same as calling :func:`pyro.set_rng_seed` before the
    call to `fn`. This handler has no additional effect on primitive statements
    on the standard Pyro backend, but it might intercept ``pyro.sample`` calls in
    other backends, e.g. the NumPy backend.

    On exit, the RNG state that was active on the matching entry is restored.
    Saved states are kept on a stack, so the same handler instance may be
    re-entered (for example by a seeded function that calls itself) and each
    exit still restores the state captured by its own entry.

    :param fn: a stochastic function (callable containing Pyro primitive calls).
    :param int rng_seed: rng seed.
    """

    def __init__(self, rng_seed: int) -> None:
        assert isinstance(rng_seed, int)
        self.rng_seed = rng_seed
        # One saved RNG state per active (possibly nested) entry. A single
        # attribute would be overwritten on re-entry, losing the caller's
        # original state.
        self._saved_rng_states: list = []
        super().__init__()

    def __enter__(self) -> None:  # type: ignore[override]
        # Note: Messenger.__enter__ is not called here; this handler only swaps
        # the global RNG state. NOTE(review): presumably intentional so the
        # handler never processes messages on the standard backend — confirm.
        old_state = get_rng_state()
        try:
            set_rng_seed(self.rng_seed)
        except BaseException:
            # Seeding may have failed after partially mutating the global RNGs;
            # put back what was there before propagating the error.
            set_rng_state(old_state)
            raise
        self._saved_rng_states.append(old_state)
        # Backward compatibility: mirrors the most recently saved state, as the
        # original single-attribute implementation exposed it.
        self.old_state = old_state

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Restore the state saved by the matching __enter__ (LIFO), whether or
        # not an exception is propagating; returning None never suppresses it.
        set_rng_state(self._saved_rng_states.pop())