Source code for pyro.poutine.collapse_messenger

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from functools import reduce, singledispatch
from typing import TYPE_CHECKING, Any, FrozenSet, Tuple

from typing_extensions import Self

import pyro
from pyro.distributions.distribution import COERCIONS
from pyro.ops.linalg import ignore_torch_deprecation_warnings
from pyro.poutine.runtime import _PYRO_STACK
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.util import site_is_subsample

# TODO Remove import guard once funsor is a required dependency.
try:
    import funsor
    from funsor.cnf import Contraction
    from funsor.delta import Delta
    from funsor.terms import Funsor, Variable
except ImportError:
    # Create fake types for singledispatch.
    Contraction = type("Contraction", (), {})
    Delta = type("Delta", (), {})
    Funsor = type("Funsor", (), {})
    Variable = type("Variable", (), {})

if TYPE_CHECKING:
    from funsor.distribution import Distribution

    from pyro.poutine.runtime import Message


@singledispatch
def _substitute(x, subs):
    return x


@_substitute.register(str)
def _(x, subs):
    """String handler: replace ``x`` by ``subs[x]`` when present.

    :param str x: key to look up.
    :param subs: mapping queried via ``subs.get``.
    :return: the mapped value, or ``x`` itself when ``x`` is not in ``subs``.
    """
    replacement = subs.get(x, x)
    return replacement


@_substitute.register(Variable)
def _(x, subs):
    """``Variable`` handler: substitute by the variable's ``name``.

    :param x: object exposing a ``name`` attribute used as the lookup key.
    :param subs: mapping queried via ``subs.get``.
    :return: ``subs[x.name]`` when present, otherwise ``x`` itself.
    """
    key = x.name
    return subs.get(key, x)


@_substitute.register(tuple)
def _(x, subs):
    """Tuple handler: substitute element-wise and rebuild a plain ``tuple``.

    :param tuple x: elements are each dispatched through ``_substitute``.
    :param subs: mapping of substitutions passed down to every element.
    :return: a new ``tuple`` of the substituted elements, in order.
    """
    substituted = [_substitute(item, subs) for item in x]
    return tuple(substituted)


@singledispatch
def _extract_deltas(f):
    raise NotImplementedError("unmatched {}".format(type(f).__name__))


@_extract_deltas.register(Delta)
def _(f):
    """``Delta`` handler: ``f`` is already a delta, so return it as is.

    :param f: a ``Delta`` instance.
    :return: ``f`` itself.
    """
    delta = f
    return delta


@_extract_deltas.register(Contraction)
def _(f):
    """``Contraction`` handler: return the first ``Delta`` among ``f.terms``.

    :param f: a ``Contraction`` whose ``terms`` are scanned in order.
    :return: the first term that is an instance of ``Delta``.
    :raises NotImplementedError: if no term is a ``Delta``. Previously this
        case fell off the end of the function and returned ``None``
        implicitly, which only surfaced later as an attribute error on
        ``None`` at the call site.
    """
    for term in f.terms:
        if isinstance(term, Delta):
            return term
    # Mirror the generic fallback's error style for an unmatched structure.
    raise NotImplementedError(
        "unmatched {}: no Delta term found".format(type(f).__name__)
    )


class CollapseMessenger(TraceMessenger):
    """
    EXPERIMENTAL Collapses all sites in the context by lazily sampling and
    attempting to use conjugacy relations. If no conjugacy is known this will
    fail. Code using the results of sample sites must be written to accept
    Funsors rather than Tensors. This requires ``funsor`` to be installed.

    .. warning:: This is not compatible with automatic guessing of
        ``max_plate_nesting``. If any plates appear within the collapsed
        context, you should manually declare ``max_plate_nesting`` to your
        inference algorithm (e.g. ``Trace_ELBO(max_plate_nesting=1)``).
    """

    # Coercion object shared by every instance; built on first instantiation.
    _coerce = None

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the shared funsor coercion (once) and initialise the messenger.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        if CollapseMessenger._coerce is None:
            # Imported here rather than relying on the guarded module-level
            # import, so a missing funsor fails at construction time.
            import funsor
            from funsor.distribution import CoerceDistributionToFunsor

            funsor.set_backend("torch")
            CollapseMessenger._coerce = CoerceDistributionToFunsor("torch")
        # While True, _process_message and _pyro_post_sample ignore messages.
        self._block = False
        super().__init__(*args, **kwargs)

    def _process_message(self, msg: "Message") -> None:
        """Forward ``msg`` to the parent unless blocked or a subsample site."""
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._process_message(msg)

    def _pyro_sample(self, msg: "Message") -> None:
        """Replace the site's ``fn`` and ``value`` with funsors and stop the message.

        When the site has no value yet, a ``funsor.Variable`` named after the
        site is used as a lazy placeholder.
        """
        # Eagerly convert fn and value to Funsor.
        dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
        dim_to_name.update(self.preserved_plates)
        msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
        if TYPE_CHECKING:
            assert isinstance(msg["fn"], Distribution)
        domain = msg["fn"].inputs["value"]
        if msg["value"] is None:
            msg["value"] = funsor.Variable(msg["name"], domain)
        else:
            msg["value"] = funsor.to_funsor(msg["value"], domain, dim_to_name)

        msg["done"] = True
        msg["stop"] = True

    def _pyro_post_sample(self, msg: "Message") -> None:
        """Forward ``msg`` to the parent unless blocked or a subsample site."""
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._pyro_post_sample(msg)

    def _pyro_barrier(self, msg: "Message") -> None:
        """Collapse the sites recorded so far and substitute samples into the barrier value.

        :raises NotImplementedError: if variables must be sampled but
            ``_get_log_prob`` returned no joint (``NotImplemented``), which it
            does when plates are reduced.
        """
        # Get log_prob and record factor.
        name, log_prob, log_joint, sampled_vars = self._get_log_prob()
        self._block = True
        try:
            pyro.factor(name, log_prob.data)
        finally:
            # Always unblock, even if pyro.factor raises; otherwise this
            # messenger would silently ignore every later message.
            self._block = False

        # Sample
        if sampled_vars:
            if log_joint is NotImplemented:
                # Fail with a clear message instead of an attribute error on
                # the NotImplemented sentinel.
                raise NotImplementedError(
                    "cannot sample collapsed sites: the joint log-density is "
                    "unavailable when plates are reduced"
                )
            drawn = log_joint.sample(sampled_vars)
            deltas = _extract_deltas(drawn)
            samples = {
                var_name: point.data for var_name, (point, _) in deltas.terms
            }
        else:
            samples = {}

        # Update value.
        assert len(msg["args"]) == 1
        value = msg["args"][0]
        value = _substitute(value, samples)
        msg["value"] = value

    def __enter__(self) -> Self:
        """Record the plates already on the Pyro stack, install the coercion, and enter."""
        self.preserved_plates = {
            h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate)
        }
        COERCIONS.append(self._coerce)
        return super().__enter__()

    def __exit__(self, *args) -> None:
        """Remove the coercion, exit, and emit a factor for any remaining sample sites."""
        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args)

        if any(site["type"] == "sample" for site in self.trace.nodes.values()):
            name, log_prob, _, _ = self._get_log_prob()
            pyro.factor(name, log_prob.data)

    @ignore_torch_deprecation_warnings()
    def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]:
        """Combine the recorded sites into a collapsed log-probability.

        Clears ``self.trace.nodes`` as a side effect.

        :return: a tuple ``(name, log_prob, log_joint, reduced_vars)`` where
            ``name`` is the first non-observed site name, ``log_prob`` is the
            log-density with ``reduced_vars`` eliminated, ``log_joint`` is the
            sum of the individual terms (or ``NotImplemented`` when plates are
            reduced), and ``reduced_vars`` is the set of non-observed site
            names.
        :raises AssertionError: if no site was recorded.
        """
        # Convert delayed statements to pyro.factor()
        reduced_vars_list = []
        log_prob_terms = []
        plates: FrozenSet[str] = frozenset()
        for name, site in self.trace.nodes.items():
            if not site["is_observed"]:
                reduced_vars_list.append(name)
            log_prob_terms.append(site["fn"](value=site["value"]))
            plates |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.vectorized
            )
        # Check emptiness before indexing, so an empty trace reports
        # "nothing to collapse" instead of a bare IndexError.
        assert log_prob_terms, "nothing to collapse"
        # NOTE(review): if every recorded site is observed this still raises
        # IndexError below; confirm whether that case should be supported.
        name = reduced_vars_list[0]
        reduced_vars = frozenset(reduced_vars_list)
        self.trace.nodes.clear()
        reduced_plates = plates - frozenset(self.preserved_plates.values())
        if reduced_plates:
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                log_prob_terms,
                eliminate=reduced_vars | reduced_plates,
                plates=plates,
            )
            # No joint is built on this path; callers must not sample from it.
            log_joint = NotImplemented
        else:
            log_joint = reduce(funsor.ops.add, log_prob_terms)
            log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)

        return name, log_prob, log_joint, reduced_vars