Source code for pyro.poutine.util

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, List, Optional

from pyro import settings

if TYPE_CHECKING:
    from pyro.distributions.distribution import Distribution
    from pyro.poutine.runtime import Message
    from pyro.poutine.trace_struct import Trace

# Module-level switch for poutine validation. It starts out equal to
# ``__debug__``, i.e. on in a normal interpreter and off under ``python -O``.
# ``enable_validation`` overwrites it and ``is_validation_enabled`` reads it.
_VALIDATION_ENABLED = __debug__
# Register the flag with ``pyro.settings`` under the alias "validate_poutine",
# identified by this module's name and the name of the global above.
# NOTE(review): presumably this lets the settings API get/set the flag by
# alias -- confirm against ``pyro.settings.register``.
settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED")


def enable_validation(is_validate: bool) -> None:
    """
    Switches poutine validation on or off by overwriting the module-level
    ``_VALIDATION_ENABLED`` flag.

    :param is_validate: the new value of the validation flag
    """
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = is_validate
def is_validation_enabled() -> bool:
    """
    Reports the current value of the module-level ``_VALIDATION_ENABLED``
    flag.

    :returns: the current validation flag
    """
    return _VALIDATION_ENABLED
def site_is_subsample(site: "Message") -> bool:
    """
    Determines whether a trace site originated from a subsample statement
    inside a `plate`.

    :param site: a trace site
    :returns: ``True`` if the site is a ``"sample"`` site whose ``fn`` is an
        instance of a class named ``_Subsample``
    """
    # Non-sample sites are rejected before ``site["fn"]`` is ever looked up.
    if site["type"] != "sample":
        return False
    # The match is on the class name only; no class object is imported here.
    fn_class_name = type(site["fn"]).__name__
    return fn_class_name == "_Subsample"
def site_is_factor(site: "Message") -> bool:
    """
    Determines whether a trace site originated from a factor statement.

    :param site: a trace site
    :returns: ``True`` if the site is a ``"sample"`` site whose ``fn`` is an
        instance of a class named ``Unit``
    """
    # Non-sample sites are rejected before ``site["fn"]`` is ever looked up.
    if site["type"] != "sample":
        return False
    # The match is on the class name only; no class object is imported here.
    fn_class_name = type(site["fn"]).__name__
    return fn_class_name == "Unit"
def prune_subsample_sites(trace: "Trace") -> "Trace":
    """
    Copies and removes all subsample sites from a trace.

    :param trace: the trace to prune; it is copied, not modified
    :returns: a copy of ``trace`` without the sites for which
        :func:`site_is_subsample` is true
    """
    pruned = trace.copy()
    # Collect the names first so the node mapping is never mutated while it
    # is being iterated.
    subsample_names = [
        site_name
        for site_name, site in pruned.nodes.items()
        if site_is_subsample(site)
    ]
    for site_name in subsample_names:
        pruned.remove_node(site_name)
    return pruned
def enum_extend(
    trace: "Trace", msg: "Message", num_samples: Optional[int] = None
) -> List["Trace"]:
    """
    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :param num_samples: maximum number of extended traces to return.
        ``None`` or any negative value means "no limit".
    :returns: a list of traces, copies of input trace with one extra site

    Utility function to copy and extend a trace with sites based on the input
    site whose values are enumerated from the support of the input site's
    distribution.

    Used for exact inference and integrating out discrete variables.
    """
    if num_samples is None:
        # A negative limit disables truncation in the loop below.
        num_samples = -1

    extended_traces = []
    assert msg["name"] is not None
    if TYPE_CHECKING:
        assert isinstance(msg["fn"], Distribution)

    support = msg["fn"].enumerate_support(*msg["args"], **msg["kwargs"])
    for i, s in enumerate(support):
        # Stop once ``num_samples`` traces have been produced. The comparison
        # is ``>=`` so the limit is a true maximum: a strict ``>`` would let
        # ``num_samples + 1`` traces through.
        if 0 <= num_samples <= i:
            break
        # Copy the message so the caller's ``msg`` never gains a "value" key.
        msg_copy = msg.copy()
        msg_copy.update(value=s)  # type: ignore[call-arg]
        tr_cp = trace.copy()
        tr_cp.add_node(msg["name"], **msg_copy)
        extended_traces.append(tr_cp)
    return extended_traces
def mc_extend(
    trace: "Trace", msg: "Message", num_samples: Optional[int] = None
) -> List["Trace"]:
    """
    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :param num_samples: number of extended traces to return; ``None`` is
        treated as 1.
    :returns: a list of traces, copies of input trace with one extra site

    Utility function to copy and extend a trace with sites based on the input
    site whose values are sampled from the input site's function.

    Used for Monte Carlo marginalization of individual sample sites.
    """
    draw_count = 1 if num_samples is None else num_samples

    results = []
    for _ in range(draw_count):
        # Work on a copy so the caller's ``msg`` is never given a "value".
        site = msg.copy()
        # Each iteration calls the site's function again for a fresh value.
        drawn = site["fn"](*site["args"], **site["kwargs"])
        site["value"] = drawn
        extended = trace.copy()
        assert site["name"] is not None
        extended.add_node(site["name"], **site)
        results.append(extended)
    return results
def discrete_escape(trace: "Trace", msg: "Message") -> bool:
    """
    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :returns: boolean decision value

    Utility function that checks if a sample site is discrete and not already
    in a trace.

    Used by EscapeMessenger to decide whether to do a nonlocal exit at a site.
    Subroutine for integrating out discrete variables for variance reduction.
    """
    # The checks run in this order, and each one short-circuits the rest.
    if msg["type"] != "sample":
        return False
    if msg["is_observed"]:
        return False
    site_name = msg["name"]
    if site_name is None:
        return False
    if site_name in trace:
        return False
    # A site counts as discrete when its ``fn`` has a truthy
    # ``has_enumerate_support``; a missing attribute counts as False.
    return getattr(msg["fn"], "has_enumerate_support", False)
def all_escape(trace: "Trace", msg: "Message") -> bool:
    """
    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :returns: boolean decision value

    Utility function that checks if a site is not already in a trace.

    Used by EscapeMessenger to decide whether to do a nonlocal exit at a site.
    Subroutine for approximately integrating out variables for variance
    reduction.
    """
    # The checks run in this order, and each one short-circuits the rest.
    if msg["type"] != "sample":
        return False
    if msg["is_observed"]:
        return False
    site_name = msg["name"]
    if site_name is None:
        return False
    return site_name not in trace