# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import numbers
import warnings
from typing import Dict, Union
import torch
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, apply_stack
[docs]class DoMessenger(Messenger):
"""
Given a stochastic function with some sample statements
and a dictionary of values at names,
set the return values of those sites equal to the values
as if they were hard-coded to those values
and introduce fresh sample sites with the same names
whose values do not propagate.
Composes freely with :func:`~pyro.poutine.handlers.condition`
to represent counterfactual distributions over potential outcomes.
See Single World Intervention Graphs [1] for additional details and theory.
Consider the following Pyro program:
>>> def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
To intervene with a value for site `z`, we can write
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
This is equivalent to replacing `z = pyro.sample("z", ...)` with
`z = torch.tensor(1.)`
and introducing a fresh sample site pyro.sample("z", ...) whose value is not used elsewhere.
References
[1] `Single World Intervention Graphs: A Primer`,
Thomas Richardson, James Robins
:param fn: a stochastic function (callable containing Pyro primitive calls)
:param data: a ``dict`` mapping sample site names to interventions
:returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger`
"""
def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None:
super().__init__()
self.data = data
self._intervener_id = str(id(self))
def _pyro_sample(self, msg: Message) -> None:
assert isinstance(msg["name"], str)
if (
msg.get("_intervener_id") != self._intervener_id
and self.data.get(msg["name"]) is not None
):
if msg.get("_intervener_id") is not None:
warnings.warn(
"Attempting to intervene on variable {} multiple times,"
"this is almost certainly incorrect behavior".format(msg["name"]),
RuntimeWarning,
)
msg["_intervener_id"] = self._intervener_id
# split node, avoid reapplying self recursively to new node
new_msg = msg.copy()
new_msg["cond_indep_stack"] = () # avoid entering plates twice
apply_stack(new_msg)
# apply intervention
intervention = self.data[msg["name"]]
msg["name"] = msg["name"] + "__CF" # mangle old name
if isinstance(intervention, numbers.Number):
msg["value"] = torch.tensor(intervention)
msg["is_observed"] = True
msg["stop"] = True
elif isinstance(intervention, torch.Tensor):
msg["value"] = intervention
msg["is_observed"] = True
msg["stop"] = True
else:
raise NotImplementedError(
"Interventions of type {} not implemented (yet)".format(
type(intervention)
)
)