Source code for pyro.poutine.do_messenger

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import numbers
import warnings
from typing import Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, apply_stack


[docs]class DoMessenger(Messenger): """ Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate. Composes freely with :func:`~pyro.poutine.handlers.condition` to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory. Consider the following Pyro program: >>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2 To intervene with a value for site `z`, we can write >>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)}) This is equivalent to replacing `z = pyro.sample("z", ...)` with `z = torch.tensor(1.)` and introducing a fresh sample site pyro.sample("z", ...) whose value is not used elsewhere. References [1] `Single World Intervention Graphs: A Primer`, Thomas Richardson, James Robins :param fn: a stochastic function (callable containing Pyro primitive calls) :param data: a ``dict`` mapping sample site names to interventions :returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger` """ def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None: super().__init__() self.data = data self._intervener_id = str(id(self)) def _pyro_sample(self, msg: Message) -> None: assert isinstance(msg["name"], str) if ( msg.get("_intervener_id") != self._intervener_id and self.data.get(msg["name"]) is not None ): if msg.get("_intervener_id") is not None: warnings.warn( "Attempting to intervene on variable {} multiple times," "this is almost certainly incorrect behavior".format(msg["name"]), RuntimeWarning, ) msg["_intervener_id"] = self._intervener_id # split node, avoid reapplying self recursively to new node new_msg = msg.copy() new_msg["cond_indep_stack"] = () # avoid entering plates twice apply_stack(new_msg) # apply intervention intervention = self.data[msg["name"]] msg["name"] = msg["name"] + "__CF" # mangle old name if isinstance(intervention, numbers.Number): msg["value"] = torch.tensor(intervention) msg["is_observed"] = True msg["stop"] = True elif isinstance(intervention, torch.Tensor): msg["value"] = intervention msg["is_observed"] = True msg["stop"] = True else: raise NotImplementedError( "Interventions of type {} not implemented (yet)".format( type(intervention) ) )