# Source code for pyro.poutine.collapse_messenger

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import reduce, singledispatch

import pyro
from pyro.poutine.util import site_is_subsample
from pyro.distributions.distribution import COERCIONS

from .runtime import _PYRO_STACK
from .trace_messenger import TraceMessenger

# TODO Remove import guard once funsor is a required dependency.
# funsor is optional: when it is missing, placeholder classes are bound to the
# same names so that the singledispatch registrations below can still be
# evaluated at import time.
try:
    import funsor
    from funsor.cnf import Contraction
    from funsor.delta import Delta
    from funsor.terms import Funsor, Variable
except ImportError:
    # Create fake types for singledispatch.
    Contraction = type("Contraction", (), {})
    Delta = type("Delta", (), {})
    Funsor = type("Funsor", (), {})
    Variable = type("Variable", (), {})

@singledispatch
def _get_free_vars(x):
    """
    Fallback overload: return ``x`` unchanged for any type that has no more
    specific registration below.
    """
    return x


@_get_free_vars.register(Variable)
def _(x):
    """
    Overload for ``Variable``: return a one-element frozenset holding the
    variable's name.
    """
    # NOTE(review): ``x.name`` is reconstructed (the attribute access was lost
    # from the extracted source) -- confirm against funsor's ``Variable``.
    return frozenset((x.name,))


@_get_free_vars.register(tuple)
def _(x, subs=None):
    """
    Overload for tuples: union of ``_get_free_vars`` over every element.

    :param tuple x: the parts to inspect.
    :param subs: unused; kept (now optional) for backward compatibility.
        The recursive calls pass a single argument, so a required ``subs``
        made this overload impossible to call through the dispatcher.
    """
    return frozenset().union(*map(_get_free_vars, x))

@singledispatch
def _substitute(x, subs):
    """
    Replace placeholders in ``x`` using the mapping ``subs``.

    Fallback overload: values of any type without a more specific
    registration are returned unchanged.

    :param x: the value to substitute into.
    :param dict subs: mapping from name to replacement value.
    """
    return x


@_substitute.register(str)
def _(x, subs):
    """Overload for ``str``: look ``x`` up in ``subs``, defaulting to ``x``."""
    return subs.get(x, x)


@_substitute.register(Variable)
def _(x, subs):
    """Overload for ``Variable``: look the variable up in ``subs`` by name."""
    # NOTE(review): ``x.name`` is reconstructed (the lookup key was lost from
    # the extracted source) -- confirm against funsor's ``Variable``.
    return subs.get(x.name, x)


@_substitute.register(tuple)
def _(x, subs):
    """Overload for tuples: substitute into each part, returning a new tuple."""
    return tuple(_substitute(part, subs) for part in x)

@singledispatch
def _extract_deltas(f):
    """
    Return the ``Delta`` term found in ``f``.

    Fallback overload: any type without a registration below is unsupported.

    :raises NotImplementedError: always, naming the unmatched type.
    """
    raise NotImplementedError("unmatched {}".format(type(f).__name__))


@_extract_deltas.register(Delta)
def _(f):
    """Overload for ``Delta``: ``f`` is already the result."""
    return f


@_extract_deltas.register(Contraction)
def _(f):
    """
    Overload for ``Contraction``: return the first ``Delta`` among
    ``f.terms``, or ``None`` when there is none.
    """
    for d in f.terms:
        if isinstance(d, Delta):
            return d
    return None

class CollapseMessenger(TraceMessenger):
    """
    EXPERIMENTAL Collapses all sites in the context by lazily sampling and
    attempting to use conjugacy relations. If no conjugacy is known this will
    fail. Code using the results of sample sites must be written to accept
    Funsors rather than Tensors. This requires ``funsor`` to be installed.
    """

    # Coercion object shared by all instances; created on first construction.
    _coerce = None

    def __init__(self, *args, **kwargs):
        """
        Lazily set up the funsor torch backend and the shared coercion, then
        forward all arguments to ``TraceMessenger``.
        """
        if CollapseMessenger._coerce is None:
            # Imported here so that the module can be imported without funsor.
            import funsor
            from funsor.distribution import CoerceDistributionToFunsor

            funsor.set_backend("torch")
            CollapseMessenger._coerce = CoerceDistributionToFunsor("torch")
        # While True, this messenger ignores every message it sees.
        self._block = False
        super().__init__(*args, **kwargs)

    def _process_message(self, msg):
        """
        Forward ``msg`` to the parent handler, then stop propagation of sample
        messages whose ``fn`` is a Funsor or whose value is a str or Funsor.
        """
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._process_message(msg)

        # Block sample statements.
        if msg["type"] == "sample":
            if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
                msg["stop"] = True

    def _pyro_sample(self, msg):
        """
        Use the site name as the value of a sample site that has no value yet,
        and mark the message as done.
        """
        if self._block:
            return
        if msg["value"] is None:
            msg["value"] = msg["name"]
        # NOTE(review): statement nesting was lost in the extracted source;
        # ``done`` is assumed to be set unconditionally -- confirm.
        msg["done"] = True

    def _pyro_post_sample(self, msg):
        """Forward to the parent handler unless blocked or a subsample site."""
        if self._block:
            return
        if site_is_subsample(msg):
            return
        super()._pyro_post_sample(msg)

    def _pyro_barrier(self, msg):
        """
        Collapse the sites traced so far into one factor, sample the collapsed
        variables, and substitute the samples into the barrier's argument.
        """
        # Get log_prob and record factor.
        name, log_prob, log_joint, sampled_vars = self._get_log_prob()
        self._block = True
        try:
            # NOTE(review): the second argument was truncated in the extracted
            # source; ``log_prob.data`` is a reconstruction -- confirm.
            pyro.factor(name, log_prob.data)
        finally:
            # Always unblock, even if recording the factor fails.
            self._block = False

        # Sample
        if sampled_vars:
            samples = log_joint.sample(sampled_vars)
            deltas = _extract_deltas(samples)
            # NOTE(review): the dict value was truncated in the extracted
            # source; ``point.data`` is a reconstruction -- confirm.
            samples = {name: point.data for name, (point, _) in deltas.terms}
        else:
            samples = {}

        # Update value.
        assert len(msg["args"]) == 1
        value = msg["args"][0]
        value = _substitute(value, samples)
        msg["value"] = value

    def __enter__(self):
        """
        Record the plates already on the handler stack and install the
        coercion before entering the parent context.
        """
        # NOTE(review): the generator element was truncated in the extracted
        # source; ``h.name`` is a reconstruction -- confirm.
        self.preserved_plates = frozenset(
            h.name for h in _PYRO_STACK if isinstance(h, pyro.plate)
        )
        COERCIONS.append(self._coerce)
        return super().__enter__()

    def __exit__(self, *args):
        """
        Remove the coercion, leave the parent context, and record a final
        factor if any sample sites remain in the trace.
        """
        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args)

        if any(site["type"] == "sample" for site in self.trace.nodes.values()):
            name, log_prob, _, _ = self._get_log_prob()
            # NOTE(review): the second argument was truncated in the extracted
            # source; ``log_prob.data`` is a reconstruction -- confirm.
            pyro.factor(name, log_prob.data)

    def _get_log_prob(self):
        """
        Combine the log-probability terms of all traced sites and clear the
        trace.

        :returns: a tuple ``(name, log_prob, log_joint, reduced_vars)``:
            ``name`` is the first unobserved site name, ``log_prob`` is the
            reduced log-probability, ``log_joint`` is the sum of all terms
            (``NotImplemented`` when plates are reduced), and ``reduced_vars``
            is the frozenset of unobserved site names.
        """
        # Convert delayed statements to pyro.factor()
        reduced_vars = []
        log_prob_terms = []
        plates = frozenset()
        for name, site in self.trace.nodes.items():
            if not site["is_observed"]:
                reduced_vars.append(name)
            # NOTE(review): the ``f.name`` expressions in this loop were
            # truncated in the extracted source and are reconstructed.
            dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
            fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
            value = site["value"]
            # A str value is a site-name placeholder and is passed as is.
            if not isinstance(value, str):
                value = funsor.to_funsor(site["value"], fn.inputs["value"], dim_to_name)
            log_prob_terms.append(fn(value=value))
            plates |= frozenset(
                f.name for f in site["cond_indep_stack"] if f.vectorized
            )
        # NOTE(review): raises IndexError when there are no unobserved sites.
        name = reduced_vars[0]
        reduced_vars = frozenset(reduced_vars)
        assert log_prob_terms, "nothing to collapse"
        self.trace.nodes.clear()
        # Plates that were already open when the context was entered are kept.
        reduced_plates = plates - self.preserved_plates
        if reduced_plates:
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                log_prob_terms,
                eliminate=reduced_vars | reduced_plates,
                plates=plates,
            )
            log_joint = NotImplemented
        else:
            log_joint = reduce(funsor.ops.add, log_prob_terms)
            log_prob = log_joint.reduce(funsor.ops.logaddexp, reduced_vars)

        return name, log_prob, log_joint, reduced_vars