# Source code for pyro.poutine.block_messenger

from functools import partial

from pyro.poutine.messenger import Messenger


def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
    """
    Decide whether the site described by ``msg`` should be hidden.

    An observed sample site (``msg["type"] == "sample"`` with a truthy
    ``msg["is_observed"]``) is treated as having the type ``"observe"``
    for the purposes of the ``hide_types`` / ``expose_types`` checks.

    :param expose: collection of site names that are exposed
    :param expose_types: collection of site types that are exposed
    :param hide: collection of site names that are hidden
    :param hide_types: collection of site types that are hidden
    :param hide_all: if truthy, every site that is not explicitly exposed
        (by name or by type) is hidden
    :param dict msg: the site message; ``"type"`` and ``"name"`` are read,
        and ``"is_observed"`` is read for sample sites
    :returns: ``True`` to hide the site, ``False`` to expose it
    :rtype: bool
    """
    # Observed sample statements get their own pseudo-type.
    site_type = msg["type"]
    if site_type == "sample" and msg["is_observed"]:
        site_type = "observe"

    site_name = msg["name"]
    explicitly_exposed = (site_name in expose) or (site_type in expose_types)

    # Explicit hide rules win regardless of the expose lists.
    if site_name in hide:
        return True
    if site_type in hide_types:
        return True

    # Otherwise the site is hidden only by the blanket rule, and only when
    # nothing exposes it explicitly.
    return (not explicitly_exposed) and bool(hide_all)


def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types):
    """
    Build the default hide predicate from the name/type based options.

    The arguments are checked for contradictions, ``None`` lists are
    replaced by empty lists, and the effective ``hide_all`` flag is derived:
    passing ``hide`` or ``hide_types`` switches it off, passing ``expose``
    or ``expose_types`` switches it on.  The options are applied in the
    order ``hide``, ``expose``, ``hide_types``, ``expose_types``, so a later
    one overrides an earlier one.

    :param bool hide_all: blanket "hide everything" flag
    :param bool expose_all: blanket "expose everything" flag; only used for
        the consistency check against ``hide_all``
    :param hide: site names to hide, or ``None``
    :param expose: site names to expose, or ``None``
    :param hide_types: site types to hide, or ``None``
    :param expose_types: site types to expose, or ``None``
    :returns: a one-argument callable (``_block_fn`` with everything but
        ``msg`` bound) that returns ``True`` for sites to hide
    :raises AssertionError: if the blanket flags are both set, or if a name
        or a type appears in both the hide and the expose list
    """
    # The two blanket flags may both be off, but must otherwise disagree.
    assert (hide_all is False and expose_all is False) or (hide_all != expose_all), \
        "cannot hide and expose a site"

    # --- site names: an explicit list overrides the blanket flag ---
    if hide is not None:
        hide_all = False
    else:
        hide = []

    if expose is not None:
        hide_all = True
    else:
        expose = []

    assert not (set(hide) & set(expose)), \
        "cannot hide and expose a site"

    # --- site types: same treatment, applied after the names ---
    if hide_types is not None:
        hide_all = False
    else:
        hide_types = []

    if expose_types is not None:
        hide_all = True
    else:
        expose_types = []

    assert not (set(hide_types) & set(expose_types)), \
        "cannot hide and expose a site type"

    return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


class BlockMessenger(Messenger):
    """
    This Messenger selectively hides Pyro primitive sites from the outside world.
    Default behavior: block everything.

    BlockMessenger has a flexible interface that allows users
    to specify in several different ways
    which sites should be hidden or exposed.

    A site is hidden if at least one of the following holds:

        0. ``hide_fn(msg) is True`` or ``(not expose_fn(msg)) is True``
        1. ``msg["name"] in hide``
        2. ``msg["type"] in hide_types``
        3. ``msg["name"] not in expose and msg["type"] not in expose_types``
        4. ``hide``, ``hide_types``, and ``expose_types`` are all ``None``

    For example, suppose the stochastic function fn has two sample sites "a" and "b".
    Then any poutine outside of ``BlockMessenger(hide=["a"])(fn)``
    will not be applied to site "a" and will only see site "b":

    .. doctest::
       :hide:

       >>> from pyro.poutine.trace_messenger import TraceMessenger
       >>> def fn():
       ...     a = pyro.sample("a", dist.Normal(0., 1.))
       ...     return pyro.sample("b", dist.Normal(a, 1.))

    >>> fn_inner = TraceMessenger()(fn)
    >>> fn_outer = TraceMessenger()(BlockMessenger(hide=["a"])(TraceMessenger()(fn)))
    >>> trace_inner = fn_inner.get_trace()
    >>> trace_outer = fn_outer.get_trace()
    >>> "a" in trace_inner
    True
    >>> "a" in trace_outer
    False
    >>> "b" in trace_inner
    True
    >>> "b" in trace_outer
    True

    See the constructor for details.

    :param hide_fn: function that takes a site and returns True to hide the site
        or False/None to expose it.  If specified, all other parameters are ignored.
        Only specify one of hide_fn or expose_fn, not both.
    :param expose_fn: function that takes a site and returns True to expose the site
        or False/None to hide it.  If specified, all other parameters are ignored.
        Only specify one of hide_fn or expose_fn, not both.
    :param bool hide_all: hide all sites
    :param bool expose_all: expose all sites normally
    :param list hide: list of site names to hide, rest will be exposed normally
    :param list expose: list of site names to expose, rest will be hidden
    :param list hide_types: list of site types to hide, rest will be exposed normally
    :param list expose_types: list of site types to expose normally, rest will be hidden
    :raises ValueError: if both ``hide_fn`` and ``expose_fn`` are given
    """

    def __init__(self, hide_fn=None, expose_fn=None,
                 hide_all=True, expose_all=False,
                 hide=None, expose=None,
                 hide_types=None, expose_types=None):
        super(BlockMessenger, self).__init__()
        # The two callables are mutually exclusive ways to give the predicate.
        if not (hide_fn is None or expose_fn is None):
            raise ValueError("Only specify one of hide_fn or expose_fn")
        if hide_fn is not None:
            self.hide_fn = hide_fn
        elif expose_fn is not None:
            # Store everything as a "hide" predicate by negating expose_fn.
            self.hide_fn = lambda msg: not expose_fn(msg)
        else:
            # No callable given: derive the predicate from the name/type options.
            self.hide_fn = _make_default_hide_fn(hide_all, expose_all,
                                                 hide, expose,
                                                 hide_types, expose_types)

    def _process_message(self, msg):
        """
        Set ``msg["stop"]`` to the boolean result of ``self.hide_fn(msg)``.

        :param dict msg: the site message, updated in place
        :returns: None
        """
        # bool() so that a predicate returning None/truthy values still
        # yields a strict True/False flag.
        msg["stop"] = bool(self.hide_fn(msg))
        return None