# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from functools import reduce, singledispatch
from typing import TYPE_CHECKING, Any, FrozenSet, Tuple
from typing_extensions import Self
import pyro
from pyro.distributions.distribution import COERCIONS
from pyro.ops.linalg import ignore_torch_deprecation_warnings
from pyro.poutine.runtime import _PYRO_STACK
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.util import site_is_subsample
# TODO Remove import guard once funsor is a required dependency.
try:
import funsor
from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.terms import Funsor, Variable
except ImportError:
# Create fake types for singledispatch.
Contraction = type("Contraction", (), {})
Delta = type("Delta", (), {})
Funsor = type("Funsor", (), {})
Variable = type("Variable", (), {})
if TYPE_CHECKING:
from funsor.distribution import Distribution
from pyro.poutine.runtime import Message
@singledispatch
def _substitute(x, subs):
return x
@_substitute.register(str)
def _(x, subs):
    """
    Substitution rule for strings.

    Looks ``x`` up in ``subs`` and returns the mapped value, or ``x`` itself
    when ``subs`` has no entry for it.
    """
    replacement = subs.get(x, x)
    return replacement
@_substitute.register(Variable)
def _(x, subs):
    """
    Substitution rule for ``Variable`` objects.

    The lookup key is ``x.name``; when ``subs`` has no entry under that name
    the variable ``x`` itself is returned.
    """
    key = x.name
    replacement = subs.get(key, x)
    return replacement
@_substitute.register(tuple)
def _(x, subs):
    """
    Substitution rule for tuples.

    Applies ``_substitute`` to every element in order and returns the
    results as a new plain ``tuple``.
    """
    replaced = [_substitute(element, subs) for element in x]
    return tuple(replaced)
@singledispatch
def _extract_deltas(f):
raise NotImplementedError("unmatched {}".format(type(f).__name__))
@_extract_deltas.register(Delta)
def _(f):
    """
    Extraction rule for ``Delta``: ``f`` is already the object being sought,
    so it is returned as-is.
    """
    # Identity: a Delta needs no unwrapping.
    return f
@_extract_deltas.register(Contraction)
def _(f):
    """
    Extraction rule for ``Contraction``.

    Scans ``f.terms`` in order and returns the first term that is a
    ``Delta``. When no term is a ``Delta`` the result is ``None``.
    """
    delta_terms = (term for term in f.terms if isinstance(term, Delta))
    return next(delta_terms, None)
class CollapseMessenger(TraceMessenger):
    """
    EXPERIMENTAL Collapses all sites in the context by lazily sampling and
    attempting to use conjugacy relations. If no conjugacy is known this will
    fail. Code using the results of sample sites must be written to accept
    Funsors rather than Tensors. This requires ``funsor`` to be installed.

    .. warning:: This is not compatible with automatic guessing of
        ``max_plate_nesting``. If any plates appear within the collapsed
        context, you should manually declare ``max_plate_nesting`` to your
        inference algorithm (e.g. ``Trace_ELBO(max_plate_nesting=1)``).
    """

    # Class-level coercion object shared by all instances; built lazily by the
    # first ``__init__`` call so that funsor is only imported when needed.
    _coerce = None

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Create the messenger, initialising the shared coercion on first use.

        All positional and keyword arguments are forwarded unchanged to the
        parent class constructor.
        """
        if CollapseMessenger._coerce is None:
            import funsor
            from funsor.distribution import CoerceDistributionToFunsor

            funsor.set_backend("torch")
            CollapseMessenger._coerce = CoerceDistributionToFunsor("torch")
        # While True, ``_process_message`` and ``_pyro_post_sample`` ignore
        # every message (see ``_pyro_barrier``).
        self._block = False
        super().__init__(*args, **kwargs)

    def _process_message(self, msg: "Message") -> None:
        """
        Forward ``msg`` to the parent handler unless this messenger is
        currently blocked or ``site_is_subsample(msg)`` is true.
        """
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._process_message(msg)

    def _pyro_sample(self, msg: "Message") -> None:
        """
        Replace the site's ``fn`` and ``value`` with funsors and mark the
        message as ``done`` and ``stop``.

        A site without a value gets a ``funsor.Variable`` named after the
        site; an existing value is converted with ``funsor.to_funsor``.
        """
        # Eagerly convert fn and value to Funsor.
        dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
        dim_to_name.update(self.preserved_plates)
        msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
        if TYPE_CHECKING:
            assert isinstance(msg["fn"], Distribution)
        domain = msg["fn"].inputs["value"]
        if msg["value"] is None:
            msg["value"] = funsor.Variable(msg["name"], domain)
        else:
            msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)

        msg["done"] = True
        msg["stop"] = True

    def _pyro_post_sample(self, msg: "Message") -> None:
        """
        Forward ``msg`` to the parent post-sample handler unless this
        messenger is currently blocked or ``site_is_subsample(msg)`` is true.
        """
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._pyro_post_sample(msg)

    def _pyro_barrier(self, msg: "Message") -> None:
        """
        Emit a ``pyro.factor`` for everything recorded so far, draw samples
        for the collapsed variables, and substitute them into the barrier's
        single argument, storing the result in ``msg["value"]``.

        :raises NotImplementedError: if samples are required but
            ``_get_log_prob`` could not provide a joint log-density (it
            returns ``NotImplemented`` when plates had to be reduced).
        """
        # Get log_prob and record factor.
        name, log_prob, log_joint, sampled_vars = self._get_log_prob()
        # Block this messenger while the factor is emitted so that its own
        # handlers ignore the resulting message. The flag is restored in
        # ``finally`` so a failure inside ``pyro.factor`` cannot leave the
        # messenger blocked for good.
        self._block = True
        try:
            pyro.factor(name, log_prob.data)
        finally:
            self._block = False

        # Sample
        if sampled_vars:
            if log_joint is NotImplemented:
                # ``_get_log_prob`` only builds ``log_joint`` when no plates
                # are reduced; fail with a clear message instead of an
                # AttributeError on the ``NotImplemented`` placeholder.
                raise NotImplementedError(
                    "sampling at a barrier is not supported when plates are "
                    "reduced inside the collapsed context"
                )
            samples = log_joint.sample(sampled_vars)
            deltas = _extract_deltas(samples)
            samples = {
                var_name: point.data for var_name, (point, _) in deltas.terms
            }
        else:
            samples = {}

        # Update value.
        assert len(msg["args"]) == 1
        value = msg["args"][0]
        value = _substitute(value, samples)
        msg["value"] = value

    def __enter__(self) -> Self:
        """
        Record the plates already on the Pyro stack, push the shared
        coercion onto ``COERCIONS``, and enter the parent context.
        """
        # dim -> name for every ``pyro.plate`` handler on the stack at entry.
        self.preserved_plates = {
            h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate)
        }
        COERCIONS.append(self._coerce)
        return super().__enter__()

    def __exit__(self, *args) -> None:
        """
        Pop the coercion pushed by ``__enter__``, leave the parent context,
        and, if any ``"sample"`` sites remain in the trace, emit one final
        ``pyro.factor`` with their collapsed log-probability.
        """
        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args)

        if any(site["type"] == "sample" for site in self.trace.nodes.values()):
            name, log_prob, _, _ = self._get_log_prob()
            pyro.factor(name, log_prob.data)

    @ignore_torch_deprecation_warnings()
    def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]:
        """
        Turn the recorded sites into a single collapsed log-probability.

        Side effect: clears ``self.trace.nodes``.

        :returns: a 4-tuple ``(name, log_prob, log_joint, reduced_vars)``:
            ``name`` is the name of the first unobserved recorded site;
            ``log_prob`` is the log-density with ``reduced_vars`` (and any
            non-preserved plates) eliminated; ``log_joint`` is the sum of all
            site terms, or ``NotImplemented`` when plates were reduced;
            ``reduced_vars`` is the frozenset of unobserved site names.
        :raises AssertionError: if no sites have been recorded.
        """
        # Convert delayed statements to pyro.factor()
        reduced_vars_list = []
        log_prob_terms = []
        plates: FrozenSet[str] = frozenset()
        for name, site in self.trace.nodes.items():
            if not site["is_observed"]:
                reduced_vars_list.append(name)
            log_prob_terms.append(site["fn"](value=site["value"]))
            plates |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.vectorized
            )
        # Check for an empty trace before indexing, so the failure carries
        # the intended message rather than surfacing as an IndexError.
        assert log_prob_terms, "nothing to collapse"
        # NOTE(review): if every recorded site is observed this index still
        # raises IndexError; confirm whether that case can occur.
        name = reduced_vars_list[0]
        reduced_vars = frozenset(reduced_vars_list)
        self.trace.nodes.clear()
        reduced_plates = plates - frozenset(self.preserved_plates.values())
        if reduced_plates:
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                log_prob_terms,
                eliminate=reduced_vars | reduced_plates,
                plates=plates,
            )
            # No joint density is constructed on this path; callers must
            # treat this placeholder as "sampling unavailable".
            log_joint = NotImplemented
        else:
            log_joint = reduce(funsor.ops.add, log_prob_terms)
            log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)

        return name, log_prob, log_joint, reduced_vars