# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import TYPE_CHECKING, Callable, List, Optional

from pyro.poutine.messenger import Messenger

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message
def _block_fn(
expose: List[str],
expose_types: List[str],
hide: List[str],
hide_types: List[str],
hide_all: bool,
msg: "Message",
) -> bool:
# handle observes
if msg["type"] == "sample" and msg["is_observed"]:
msg_type = "observe"
else:
msg_type = msg["type"]
is_not_exposed = (msg["name"] not in expose) and (msg_type not in expose_types)
# decision rule for hiding:
if (
(msg["name"] in hide)
or (msg_type in hide_types)
or (is_not_exposed and hide_all)
): # noqa: E129
return True
# otherwise expose
else:
return False
def _make_default_hide_fn(
hide_all: bool,
expose_all: bool,
hide: Optional[List[str]],
expose: Optional[List[str]],
hide_types: Optional[List[str]],
expose_types: Optional[List[str]],
) -> Callable[["Message"], bool]:
# first, some sanity checks:
# hide_all and expose_all intersect?
assert (hide_all is False and expose_all is False) or (
hide_all != expose_all
), "cannot hide and expose a site"
# hide and expose intersect?
if hide is None:
hide = []
else:
hide_all = False
if expose is None:
expose = []
else:
hide_all = True
assert set(hide).isdisjoint(set(expose)), "cannot hide and expose a site"
# hide_types and expose_types intersect?
if hide_types is None:
hide_types = []
else:
hide_all = False
if expose_types is None:
expose_types = []
else:
hide_all = True
assert set(hide_types).isdisjoint(
set(expose_types)
), "cannot hide and expose a site type"
return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)
def _negate_fn(
fn: Callable[["Message"], Optional[bool]]
) -> Callable[["Message"], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: "Message") -> bool:
return not fn(msg)
return negated_fn
class BlockMessenger(Messenger):
    """
    This handler selectively hides Pyro primitive sites from the outside world.
    Default behavior: block everything.

    If ``hide_fn`` or ``expose_fn`` is given, it alone decides: a site is
    hidden when ``hide_fn(msg)`` is truthy, or when ``expose_fn(msg)`` is
    falsy. Otherwise a site is hidden if at least one of the following holds:

        1. ``msg["name"] in hide``
        2. the site's type is in ``hide_types``
        3. sites are hidden by default (see below), ``msg["name"]`` is not in
           ``expose``, and the site's type is not in ``expose_types``

    Sites are hidden by default when ``hide_all`` is true. Each of the
    following arguments, when not ``None``, overrides that default in this
    order, so the last one given wins: ``hide`` (expose by default),
    ``expose`` (hide by default), ``hide_types`` (expose by default),
    ``expose_types`` (hide by default). With all arguments left at their
    defaults, every site is hidden.

    For type matching, an observed ``"sample"`` site has type ``"observe"``.
    So ``hide_types=["observe"]`` hides observations but not latent samples,
    and ``hide_types=["sample"]`` hides latent samples but not observations.

    For example, suppose the stochastic function fn has two sample sites "a" and "b".
    Then any effect outside of ``pyro.poutine.block(fn, hide=["a"])``
    will not be applied to site "a" and will only see site "b":

        >>> def fn():
        ...     a = pyro.sample("a", dist.Normal(0., 1.))
        ...     return pyro.sample("b", dist.Normal(a, 1.))
        >>> fn_inner = pyro.poutine.trace(fn)
        >>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
        >>> trace_inner = fn_inner.get_trace()
        >>> trace_outer = fn_outer.get_trace()
        >>> "a" in trace_inner
        True
        >>> "a" in trace_outer
        False
        >>> "b" in trace_inner
        True
        >>> "b" in trace_outer
        True

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param hide_fn: function that takes a site and returns a truthy value to
        hide the site or False/None to expose it. If specified, all other
        parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
    :param expose_fn: function that takes a site and returns a truthy value to
        expose the site or False/None to hide it. If specified, all other
        parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
    :param bool hide_all: hide all sites not otherwise exposed (default
        ``True``); overridden by the list arguments as described above
    :param bool expose_all: expose all sites normally. Because ``hide_all``
        defaults to ``True``, this must be combined with ``hide_all=False``;
        setting both is rejected ("cannot hide and expose a site"). The list
        arguments can still re-enable hiding.
    :param list hide: list of site names to hide
    :param list expose: list of site names to be exposed while all others hidden
    :param list hide_types: list of site types to be hidden
    :param list expose_types: list of site types to be exposed while all others hidden
    :raises ValueError: if both ``hide_fn`` and ``expose_fn`` are given
    :returns: stochastic function decorated with a :class:`~pyro.poutine.block_messenger.BlockMessenger`
    """

    def __init__(
        self,
        hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
        expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
        hide_all: bool = True,
        expose_all: bool = False,
        hide: Optional[List[str]] = None,
        expose: Optional[List[str]] = None,
        hide_types: Optional[List[str]] = None,
        expose_types: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        if not (hide_fn is None or expose_fn is None):
            raise ValueError("Only specify one of hide_fn or expose_fn")
        # Every configuration is normalized into a single predicate,
        # ``self.hide_fn``, whose truthy result means "hide this site".
        if hide_fn is not None:
            self.hide_fn = hide_fn
        elif expose_fn is not None:
            self.hide_fn = _negate_fn(expose_fn)
        else:
            self.hide_fn = _make_default_hide_fn(
                hide_all, expose_all, hide, expose, hide_types, expose_types
            )

    def _process_message(self, msg: "Message") -> None:
        # Setting ``stop`` is this messenger's only effect on the message; it
        # is how a hidden site is kept from effects applied outside it.
        # bool() normalizes None/truthy predicate results.
        msg["stop"] = bool(self.hide_fn(msg))