Source code for pyro.poutine.reentrant_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
from types import TracebackType
from typing import Callable, Optional, Type, TypeVar

from typing_extensions import ParamSpec, Self

from pyro.poutine.messenger import Messenger

_P = ParamSpec("_P")
_T = TypeVar("_T")


class ReentrantMessenger(Messenger):
    """
    A :class:`Messenger` whose context may be entered repeatedly.

    An internal counter tracks how many ``with`` blocks are currently
    active on this instance. The base class's ``__enter__`` runs only when
    the counter goes from 0 to 1, and its ``__exit__`` runs only when the
    counter returns to 0; nested entries in between merely adjust the
    counter.
    """

    def __init__(self) -> None:
        # The counter is set before delegating to the base initializer.
        self._ref_count = 0
        super().__init__()

    def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
        """
        Wrap ``fn`` via the base class and copy ``fn``'s metadata onto it.

        :param fn: the callable to wrap.
        :returns: the base class's wrapper around ``fn``, updated with
            ``fn``'s name, docstring and other ``functools`` wrapper
            attributes.
        """
        handled = super().__call__(fn)
        return functools.update_wrapper(handled, fn)

    def __enter__(self) -> Self:
        """
        Record one more active entry; enter the base class on the first.

        :returns: this messenger instance.
        """
        self._ref_count += 1
        is_outermost = self._ref_count == 1
        if is_outermost:
            super().__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """
        Record one fewer active entry; exit the base class on the last.

        :param exc_type: type of the exception raised in the ``with`` body,
            or ``None``.
        :param exc_value: the exception instance, or ``None``.
        :param traceback: the associated traceback, or ``None``.
        """
        self._ref_count -= 1
        if self._ref_count != 0:
            # Still inside an enclosing entry: leave the base class alone.
            return
        super().__exit__(exc_type, exc_value, traceback)