# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, List, Optional
from pyro import settings
if TYPE_CHECKING:
from pyro.distributions.distribution import Distribution
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace
# Module-level switch returned by is_validation_enabled() and overwritten by
# enable_validation(). It defaults to the interpreter's __debug__ flag, which is
# False when Python runs with -O.
_VALIDATION_ENABLED = __debug__
# NOTE(review): presumably exposes _VALIDATION_ENABLED through pyro.settings
# under the alias "validate_poutine"; confirm against pyro.settings.register.
# This must come after the assignment above so the attribute already exists.
settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED")
def enable_validation(is_validate: bool) -> None:
    """
    Turn poutine validation on or off for this module.

    :param is_validate: the new value of the module-level validation flag,
        subsequently returned by :func:`is_validation_enabled`.
    """
    # Rebind the module global rather than creating a function-local name.
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = is_validate
def is_validation_enabled() -> bool:
    """
    Report whether poutine validation is currently switched on.

    :returns: the current value of the module-level validation flag, as last
        set by :func:`enable_validation` (or its initial default).
    """
    return _VALIDATION_ENABLED
def site_is_subsample(site: "Message") -> bool:
    """
    Determines whether a trace site originated from a subsample statement inside a `plate`.

    A site qualifies when it is a ``"sample"`` site whose ``fn`` is an instance of
    a class named ``_Subsample``. The match is by class name, not ``isinstance``.

    :param site: a trace site / message.
    :returns: ``True`` for subsample sites, ``False`` otherwise.
    """
    if site["type"] != "sample":
        return False
    fn_class_name = type(site["fn"]).__name__
    return fn_class_name == "_Subsample"
def site_is_factor(site: "Message") -> bool:
    """
    Determines whether a trace site originated from a factor statement.

    A site qualifies when it is a ``"sample"`` site whose ``fn`` is an instance of
    a class named ``Unit``. The match is by class name, not ``isinstance``.

    :param site: a trace site / message.
    :returns: ``True`` for factor sites, ``False`` otherwise.
    """
    if site["type"] != "sample":
        return False
    fn_class_name = type(site["fn"]).__name__
    return fn_class_name == "Unit"
def prune_subsample_sites(trace: "Trace") -> "Trace":
    """
    Copies and removes all subsample sites from a trace.

    :param trace: the trace to prune. It is not modified; all removals happen
        on ``trace.copy()``.
    :returns: the pruned copy, with every node for which
        :func:`site_is_subsample` is true removed.
    """
    pruned = trace.copy()
    # Collect names first so nodes are not removed while iterating the mapping.
    doomed = [name for name, site in pruned.nodes.items() if site_is_subsample(site)]
    for name in doomed:
        pruned.remove_node(name)
    return pruned
[docs]def enum_extend(
trace: "Trace", msg: "Message", num_samples: Optional[int] = None
) -> List["Trace"]:
"""
:param trace: a partial trace
:param msg: the message at a Pyro primitive site
:param num_samples: maximum number of extended traces to return.
:returns: a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site
whose values are enumerated from the support of the input site's distribution.
Used for exact inference and integrating out discrete variables.
"""
if num_samples is None:
num_samples = -1
extended_traces = []
assert msg["name"] is not None
if TYPE_CHECKING:
assert isinstance(msg["fn"], Distribution)
for i, s in enumerate(msg["fn"].enumerate_support(*msg["args"], **msg["kwargs"])):
if i > num_samples and num_samples >= 0:
break
msg_copy = msg.copy()
msg_copy.update(value=s) # type: ignore[call-arg]
tr_cp = trace.copy()
tr_cp.add_node(msg["name"], **msg_copy)
extended_traces.append(tr_cp)
return extended_traces
[docs]def mc_extend(
trace: "Trace", msg: "Message", num_samples: Optional[int] = None
) -> List["Trace"]:
"""
:param trace: a partial trace
:param msg: the message at a Pyro primitive site
:param num_samples: maximum number of extended traces to return.
:returns: a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site
whose values are sampled from the input site's function.
Used for Monte Carlo marginalization of individual sample sites.
"""
if num_samples is None:
num_samples = 1
extended_traces = []
for i in range(num_samples):
msg_copy = msg.copy()
msg_copy["value"] = msg_copy["fn"](*msg_copy["args"], **msg_copy["kwargs"])
tr_cp = trace.copy()
assert msg_copy["name"] is not None
tr_cp.add_node(msg_copy["name"], **msg_copy)
extended_traces.append(tr_cp)
return extended_traces
def discrete_escape(trace: "Trace", msg: "Message") -> bool:
    """
    Decide whether EscapeMessenger should do a nonlocal exit at a discrete site.

    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :returns: boolean decision value; truthy only for an unobserved, named
        ``"sample"`` site that is not yet in ``trace`` and whose ``fn`` reports
        ``has_enumerate_support``.

    Subroutine for integrating out discrete variables for variance reduction.
    """
    if msg["type"] != "sample" or msg["is_observed"]:
        return False
    site_name = msg["name"]
    if site_name is None or site_name in trace:
        return False
    # fn objects without the attribute are treated as non-enumerable.
    return getattr(msg["fn"], "has_enumerate_support", False)
def all_escape(trace: "Trace", msg: "Message") -> bool:
    """
    Decide whether EscapeMessenger should do a nonlocal exit at any new sample site.

    :param trace: a partial trace
    :param msg: the message at a Pyro primitive site
    :returns: boolean decision value; ``True`` only for an unobserved, named
        ``"sample"`` site that is not already in ``trace``.

    Subroutine for approximately integrating out variables for variance reduction.
    """
    if msg["type"] != "sample" or msg["is_observed"]:
        return False
    site_name = msg["name"]
    return site_name is not None and site_name not in trace