Poutine (Effect handlers)¶
Beneath the built-in inference algorithms, Pyro has a library of composable effect handlers for creating new inference algorithms and working with probabilistic programs. Pyro’s inference algorithms are all built by applying these handlers to stochastic functions.
Handlers¶
Poutine is a library of composable effect handlers for recording and modifying the behavior of Pyro programs. These lower-level ingredients simplify the implementation of new inference algorithms and behavior.
Handlers can be used as higher-order functions, decorators, or context managers to modify the behavior of functions or blocks of code.
For example, consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can mark sample sites as observed using condition, which returns a callable with the same input and output signatures as model:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
We can also use handlers as decorators:
>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
Or as context managers:
>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     y = z ** 2
Handlers compose freely:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)
Many inference algorithms or algorithmic components can be implemented in just a few lines of code:
>>> guide_tr = poutine.trace(guide).get_trace(...)
>>> model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
>>> monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
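The arithmetic behind the last line can be checked without Pyro. The following is a hand-rolled sketch, not Pyro code: it assumes a hypothetical model z ~ Normal(0, 1), x ~ Normal(z, 1) with x observed, and a hypothetical guide z ~ Normal(mu_q, sigma_q). The names normal_logpdf and single_sample_elbo are made up for this illustration.

```python
import math
import random

def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    var = scale ** 2
    return -0.5 * math.log(2 * math.pi * var) - (value - loc) ** 2 / (2 * var)

def single_sample_elbo(x_obs, mu_q, sigma_q, rng):
    # Stand-in for the guide trace: draw z from the guide and score it.
    z = rng.gauss(mu_q, sigma_q)
    guide_log_prob = normal_logpdf(z, mu_q, sigma_q)
    # Stand-in for the replayed model trace: reuse z, score prior and likelihood.
    model_log_prob = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x_obs, z, 1.0)
    return model_log_prob - guide_log_prob

rng = random.Random(0)
estimates = [single_sample_elbo(1.0, 0.5, 0.7, rng) for _ in range(2000)]
elbo_estimate = sum(estimates) / len(estimates)
```

For this toy model the marginal of x is Normal(0, sqrt(2)), and the chosen guide is close to the exact posterior, so the averaged estimate should sit just below the log evidence.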

block
(fn=None, hide=None, expose=None, hide_types=None, expose_types=None)[source]¶ This handler selectively hides pyro primitive sites from the outside world. Default behavior: block everything.
A site is hidden if at least one of the following holds:
 msg["name"] in hide
 msg["type"] in hide_types
 msg["name"] not in expose and msg["type"] not in expose_types
 hide, hide_types, and expose_types are all None
For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any effect outside of BlockMessenger(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:
>>> fn_inner = trace(fn)
>>> fn_outer = trace(block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 hide – list of site names to hide
 expose – list of site names to be exposed while all others hidden
 hide_types – list of site types to be hidden
 expose_types – list of site types to be exposed while all others hidden
Returns: stochastic function decorated with a BlockMessenger
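The hiding rules can be read as a single predicate over a site message. The sketch below is one possible reading, not Pyro's implementation: is_hidden is a hypothetical helper, and it assumes that the expose rule only applies when expose or expose_types is given, and that the "block everything" default applies only when no argument is given at all.

```python
def is_hidden(msg, hide=None, expose=None, hide_types=None, expose_types=None):
    """Toy reading of the hiding rules for a single site message."""
    # Assumption: with no arguments at all, every site is blocked.
    if hide is None and expose is None and hide_types is None and expose_types is None:
        return True
    if hide is not None and msg["name"] in hide:
        return True
    if hide_types is not None and msg["type"] in hide_types:
        return True
    # Assumption: the expose rule is only active when an expose list was passed.
    if expose is not None or expose_types is not None:
        return (msg["name"] not in (expose or [])
                and msg["type"] not in (expose_types or []))
    return False

site_a = {"name": "a", "type": "sample"}
site_b = {"name": "b", "type": "sample"}
site_s = {"name": "s", "type": "param"}

hidden_by_default = is_hidden(site_a)
a_hidden = is_hidden(site_a, hide=["a"])
b_hidden = is_hidden(site_b, hide=["a"])
```

Under this reading, hide=["a"] hides only site “a”, which matches the trace_inner and trace_outer example above.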

condition
(fn=None, data=None)[source]¶ Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To observe a value for site z, we can write
>>> conditioned_model = condition(model, data={"z": value})
This is equivalent to adding obs=value as a keyword argument to pyro.sample("z", ...) in model.
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 data – a dict or a Trace
Returns: stochastic function decorated with a ConditionMessenger

do
(fn=None, data=None)[source]¶ Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values and hide them from the rest of the stack as if they were hardcoded to those values by using block.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To intervene with a value for site z, we can write
>>> intervened_model = do(model, data={"z": value})
This is equivalent to replacing z = pyro.sample("z", ...) with z = value.
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 data – a dict or a Trace
Returns: stochastic function decorated with a BlockMessenger and pyro.poutine.condition_messenger.ConditionMessenger
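The difference between condition and do can be illustrated with a toy runtime that does not use Pyro at all. Everything in this sketch (ToyRuntime, its sample method, the fixed scale 0.5) is hypothetical: an observed site keeps its value but is still recorded and scored, while an intervened site returns a hard-coded value and is invisible to the bookkeeping.

```python
import math
import random

def normal_logpdf(value, loc, scale):
    var = scale ** 2
    return -0.5 * math.log(2 * math.pi * var) - (value - loc) ** 2 / (2 * var)

class ToyRuntime:
    """Hypothetical stand-in for a handler stack; not part of Pyro."""

    def __init__(self, observed=None, intervened=None, seed=0):
        self.observed = observed or {}
        self.intervened = intervened or {}
        self.rng = random.Random(seed)
        self.log_joint = 0.0
        self.sites = []

    def sample(self, name, loc, scale):
        if name in self.intervened:
            # "do": return the hard-coded value; the site is hidden and unscored.
            return self.intervened[name]
        if name in self.observed:
            value = self.observed[name]          # "condition": acts like obs=value
        else:
            value = self.rng.gauss(loc, scale)
        self.sites.append(name)
        self.log_joint += normal_logpdf(value, loc, scale)
        return value

def model(rt, x):
    z = rt.sample("z", x, 0.5)
    return z ** 2

conditioned = ToyRuntime(observed={"z": 1.0})
ret_conditioned = model(conditioned, 0.0)
intervened = ToyRuntime(intervened={"z": 1.0})
ret_intervened = model(intervened, 0.0)
```

Both runs return the same value, but only the conditioned run records site z and accumulates its log-probability.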

enum
(fn=None, first_available_dim=None)[source]¶ Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.
Parameters: first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions left may be used internally by Pyro.

escape
(fn=None, escape_fn=None)[source]¶ Given a callable that contains Pyro primitive calls, evaluate escape_fn on each site, and if the result is True, raise a NonlocalExit exception that stops execution and returns the offending site.
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 escape_fn – function that takes a partial trace and a site, and returns a boolean value to decide whether to exit at that site
Returns: stochastic function decorated with EscapeMessenger

indep
(fn=None, name=None, size=None, dim=None)[source]¶ Note: Low-level; use iarange instead.
This messenger keeps track of a stack of independence information declared by nested irange and iarange contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TraceMessenger.

infer_config
(fn=None, config_fn=None)[source]¶ Given a callable that contains Pyro primitive calls and a callable taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site)
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 config_fn – a callable taking a site and returning an infer dict
Returns: stochastic function decorated with InferConfigMessenger

lift
(fn=None, prior=None)[source]¶ Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})
lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = pyro.sample("s", dist.Exponential(0.3)):
>>> tr = trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = trace(lifted_model).get_trace(0.0)
>>> (tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all()
False
Parameters:  fn – function whose parameters will be lifted to random values
 prior – prior function in the form of a Distribution or a dict of stochastic fns
Returns: fn decorated with a LiftMessenger
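The effect of lifting can be mimicked in plain Python. This is only a sketch of the idea, with hypothetical names (make_param, a dict-backed store, priors given as functions of a random generator): without a prior, a parameter keeps its stored value; with a prior, every run draws a fresh value instead.

```python
import random

def make_param(prior=None, rng=None):
    """Build a toy `param` function; with a prior, named params become random draws."""
    store = {}

    def param(name, init):
        if prior is not None and name in prior:
            return prior[name](rng)              # lifted: behaves like a sample statement
        return store.setdefault(name, init)      # plain: stored, deterministic value

    return param

def model(param):
    return param("s", 0.5)

plain = make_param()
lifted = make_param(prior={"s": lambda rng: rng.expovariate(0.3)},
                    rng=random.Random(0))

plain_values = [model(plain), model(plain)]
lifted_values = [model(lifted), model(lifted)]
```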

replay
(fn=None, trace=None, sites=None)[source]¶ Given a callable that contains Pyro primitive calls, return a callable that runs the original, reusing the values at sites in trace at those sites in the new trace
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:
>>> replayed_model = replay(model, trace=old_trace)
>>> replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]
True
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 trace – a Trace data structure to replay against
 sites – list or dict of names of sample sites in fn to replay against, defaulting to all sites
Returns: a stochastic function decorated with a ReplayMessenger
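As a rough illustration of replaying, the sketch below records a toy run as a plain dict and then reuses the recorded value in a second run. It is independent of Pyro; run_model, the replay_from argument, and the fixed scale 0.5 are assumptions made for the example.

```python
import random

def run_model(x, rng, replay_from=None):
    """Toy traced model: returns a dict of recorded sites, including the return value."""
    recorded = {}
    if replay_from is not None and "z" in replay_from:
        z = replay_from["z"]          # reuse the old value instead of sampling
    else:
        z = rng.gauss(x, 0.5)         # fresh draw
    recorded["z"] = z
    recorded["_RETURN"] = z ** 2
    return recorded

old_trace = run_model(0.0, random.Random(1))
fresh = run_model(0.0, random.Random(2))
replayed = run_model(0.0, random.Random(2), replay_from=old_trace)
```

The replayed run reproduces the old return value even though it was given a different random generator.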

queue
(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[source]¶ Used in sequential enumeration over discrete variables.
Given a stochastic function and a queue, return a return value from a complete trace in the queue
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 queue – a queue data structure like multiprocessing.Queue to hold partial traces
 max_tries – maximum number of attempts to compute a single complete trace
 extend_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a list of extended traces
 escape_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a boolean value to decide whether to exit
 num_samples – optional number of extended traces for extend_fn to return
Returns: stochastic function decorated with poutine logic
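The interplay of the queue, escape_fn, and extend_fn can be sketched in plain Python. This is a simplified illustration and not Pyro's implementation: partial traces are plain dicts, the escape rule is "the site is not yet in the partial trace", the extension rule enumerates a finite support, and the loop collects every return value rather than returning one per call. All names are hypothetical.

```python
from collections import deque

class NonlocalExit(Exception):
    """Toy exception carrying the site at which execution stopped."""

    def __init__(self, site):
        super().__init__(site["name"])
        self.site = site

def program(sample):
    a = sample("a", [0, 1])
    b = sample("b", [0, 1, 2])
    return a + b

def enumerate_returns(program):
    queue = deque([{}])                      # queue of partial traces (name -> value)
    results = []
    while queue:
        partial = queue.popleft()

        def sample(name, support):
            if name not in partial:          # escape rule: site missing from the trace
                raise NonlocalExit({"name": name, "support": support})
            return partial[name]             # otherwise reuse the stored value

        try:
            results.append(program(sample))  # complete trace: keep its return value
        except NonlocalExit as exit_:
            site = exit_.site
            for value in site["support"]:    # extension rule: one copy per support value
                extended = dict(partial)
                extended[site["name"]] = value
                queue.append(extended)
    return results

all_returns = enumerate_returns(program)
```

Each escape replaces one partial trace with several longer ones, so the queue eventually holds only complete traces, one per joint assignment of the discrete sites.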

scale
(fn=None, scale=None)[source]¶ Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s), obs=1.0)
...     return z ** 2
scale multiplicatively scales the log-probabilities of sample sites:
>>> scaled_model = scale(model, scale=0.5)
>>> scaled_tr = trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = trace(model).get_trace(0.0)
>>> (scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()
True
Parameters:  fn – a stochastic function (callable containing Pyro primitive calls)
 scale – a positive scaling factor
Returns: stochastic function decorated with a ScaleMessenger

trace
(fn=None, graph_type=None, param_only=None)[source]¶ Return a handler that records the inputs and outputs of primitive calls and their dependencies.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
Parameters:  fn – a stochastic function (callable containing pyro primitive calls)
 graph_type – string that specifies the kind of graph to construct
 param_only – if true, only records params and not samples
Returns: stochastic function decorated with a TraceMessenger
Trace¶

class
Trace
(*args, **kwargs)[source]¶ Bases:
object
Execution trace data structure built on top of networkx.DiGraph.
An execution trace of a Pyro program is a record of every call to pyro.sample() and pyro.param() in a single execution of that program. Traces are directed graphs whose nodes represent primitive calls or input/output, and whose edges represent conditional dependence relationships between those primitive calls. They are created and populated by poutine.trace.
Each node (or site) in a trace contains the name, input and output value of the site, as well as additional metadata added by inference algorithms or user annotation. In the case of pyro.sample, the trace also includes the stochastic function at the site, and any observed data added by users.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can record its execution using pyro.poutine.trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
We can also inspect or manipulate individual nodes in the trace. trace.nodes contains a collections.OrderedDict of site names and metadata corresponding to x, s, z, and the return value:
>>> print(list(name for name in trace.nodes.keys()))
["_INPUT", "s", "z", "_RETURN"]
As in networkx.DiGraph, values of trace.nodes are dictionaries of node metadata:
>>> print(trace.nodes["z"])
{'type': 'sample', 'name': 'z', 'is_observed': False, 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {}, 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}
'infer' is a dictionary of user- or algorithm-specified metadata. 'args' and 'kwargs' are the arguments passed via pyro.sample to fn.__call__ or fn.log_prob. 'scale' is used to scale the log-probability of the site when computing the log-joint. 'cond_indep_stack' contains data structures corresponding to pyro.iarange contexts appearing in the execution. 'done', 'stop', and 'continuation' are only used by Pyro’s internals.
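To make the node layout concrete, the sketch below builds a hand-written stand-in for trace.nodes with an OrderedDict and sums the scaled log-probabilities of its sample sites. It does not use Pyro or networkx. The "fn_args" field, the "args" and "return" type labels, and the two-argument site_filter signature are assumptions made for this illustration.

```python
import math
from collections import OrderedDict

def normal_logpdf(value, loc, scale):
    var = scale ** 2
    return -0.5 * math.log(2 * math.pi * var) - (value - loc) ** 2 / (2 * var)

# Hypothetical stand-in for trace.nodes after running model(0.0).
nodes = OrderedDict()
nodes["_INPUT"] = {"type": "args", "name": "_INPUT", "args": (0.0,)}
nodes["s"] = {"type": "param", "name": "s", "value": 0.5}
nodes["z"] = {"type": "sample", "name": "z", "value": 0.648,
              "fn_args": (0.0, 0.5), "scale": 1.0, "is_observed": False}
nodes["_RETURN"] = {"type": "return", "name": "_RETURN", "value": 0.648 ** 2}

def log_prob_sum(nodes, site_filter=lambda name, site: True):
    """Sum scale * log-probability over the sample sites accepted by site_filter."""
    total = 0.0
    for name, site in nodes.items():
        if site["type"] == "sample" and site_filter(name, site):
            total += site["scale"] * normal_logpdf(site["value"], *site["fn_args"])
    return total

site_names = list(nodes)
param_names = [name for name, site in nodes.items() if site["type"] == "param"]
logp = log_prob_sum(nodes)
logp_without_z = log_prob_sum(nodes, site_filter=lambda name, site: name != "z")
```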
add_edge
¶ Identical to
networkx.DiGraph.add_edge()

add_node
(site_name, *args, **kwargs)[source]¶ Adds a site to the trace.
Identical to networkx.DiGraph.add_node() but raises an error when attempting to add a duplicate node instead of silently overwriting.
Parameters: site_name (string) – the name of the site to be added
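The duplicate check can be pictured with a dict-backed toy class. This is a sketch only: ToyTrace is hypothetical, and the choice of ValueError is arbitrary, since the text above does not say which error type is raised.

```python
class ToyTrace:
    """Hypothetical minimal trace: refuses to overwrite an existing site."""

    def __init__(self):
        self.nodes = {}

    def add_node(self, site_name, **metadata):
        if site_name in self.nodes:
            # A plain dict assignment would silently overwrite; fail loudly instead.
            raise ValueError("site {!r} already exists".format(site_name))
        self.nodes[site_name] = metadata

toy = ToyTrace()
toy.add_node("z", type="sample", value=0.3)
try:
    toy.add_node("z", type="sample", value=0.9)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```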

compute_log_prob
(site_filter=<function <lambda>>)[source]¶ Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. Both computations are memoized.

compute_score_parts
()[source]¶ Compute the batched local score parts at each site of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. All computations are memoized.

copy
()[source]¶ Makes a shallow copy of self with nodes and edges preserved. Identical to networkx.DiGraph.copy(), but preserves the type and the self.graph_type attribute.

edges
¶ Identical to
networkx.DiGraph.edges

graph
¶ Identical to
networkx.DiGraph.graph

in_degree
¶ Identical to
networkx.DiGraph.in_degree()

is_directed
¶ Identical to
networkx.DiGraph.is_directed

log_prob_sum
(site_filter=<function <lambda>>)[source]¶ Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. The computation of log_prob_sum is memoized.
Returns: total log probability.
Return type: torch.Tensor

nodes
¶ Identical to
networkx.DiGraph.nodes

nonreparam_stochastic_nodes
¶ Returns: a list of names of sample sites whose stochastic functions are not reparameterizable primitive distributions

observation_nodes
¶ Returns: a list of names of observe sites

param_nodes
¶ Returns: a list of names of param sites

remove_node
¶ Identical to
networkx.DiGraph.remove_node()

reparameterized_nodes
¶ Returns: a list of names of sample sites whose stochastic functions are reparameterizable primitive distributions

stochastic_nodes
¶ Returns: a list of names of sample sites

successors
¶ Identical to
networkx.DiGraph.successors()

Messengers¶
Messenger objects contain the implementations of the effects exposed by handlers. Advanced users may modify the implementations of messengers behind existing handlers or write new messengers that implement new effects and compose correctly with the rest of the library.
Messenger¶

class
Messenger
[source]¶ Bases:
object
Context manager class that modifies behavior and adds side effects to stochastic functions, i.e. callables containing pyro primitive statements.
This is the base Messenger class. It implements the default behavior for all pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Messenger()(fn).
Class of transformers for messages passed during inference. Most inference operations are implemented in subclasses of this.
BlockMessenger¶

class
BlockMessenger
(hide_all=True, expose_all=False, hide=None, expose=None, hide_types=None, expose_types=None)[source]¶ Bases:
pyro.poutine.messenger.Messenger
This Messenger selectively hides pyro primitive sites from the outside world. Default behavior: block everything. BlockMessenger has a flexible interface that allows users to specify in several different ways which sites should be hidden or exposed.
A site is hidden if at least one of the following holds:
 msg["name"] in hide
 msg["type"] in hide_types
 msg["name"] not in expose and msg["type"] not in expose_types
 hide_all == True and hide, hide_types, and expose_types are all None
For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any poutine outside of BlockMessenger(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:
>>> fn_inner = TraceMessenger()(fn)
>>> fn_outer = TraceMessenger()(BlockMessenger(hide=["a"])(TraceMessenger()(fn)))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
See the constructor for details.
Parameters:  hide_all (bool) – hide all sites
 expose_all (bool) – expose all sites normally
 hide (list) – list of site names to hide, rest will be exposed normally
 expose (list) – list of site names to expose, rest will be hidden
 hide_types (list) – list of site types to hide, rest will be exposed normally
 expose_types (list) – list of site types to expose normally, rest will be hidden
ConditionMessenger¶

class
ConditionMessenger
(data)[source]¶ Bases:
pyro.poutine.messenger.Messenger
Adds values at observe sites to condition on data and override sampling
EscapeMessenger¶

class
EscapeMessenger
(escape_fn)[source]¶ Bases:
pyro.poutine.messenger.Messenger
Messenger that does a nonlocal exit by raising a util.NonlocalExit exception
IndepMessenger¶

class
CondIndepStackFrame
[source]¶ Bases:
pyro.poutine.indep_messenger.CondIndepStackFrame

vectorized
¶


class
IndepMessenger
(name, size, dim=None)[source]¶ Bases:
pyro.poutine.messenger.Messenger
This messenger keeps track of a stack of independence information declared by nested irange and iarange contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TraceMessenger.
LiftMessenger¶

class
LiftMessenger
(prior)[source]¶ Bases:
pyro.poutine.messenger.Messenger
Messenger which “lifts” parameters to random samples. Given a stochastic function with param calls and a prior, creates a stochastic function where all param calls are replaced by sampling from prior.
Prior should be a callable or a dict of names to callables.
ReplayMessenger¶

class
ReplayMessenger
(trace=None, sites=None)[source]¶ Bases:
pyro.poutine.messenger.Messenger
Messenger for replaying from an existing execution trace.
ScaleMessenger¶

class
ScaleMessenger
(scale)[source]¶ Bases:
pyro.poutine.messenger.Messenger
This messenger rescales the log probability score.
This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).
Parameters: scale (float or torch.Tensor) – a positive scaling factor
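The subsampling use case comes down to simple arithmetic: multiplying a mini-batch log-likelihood by N / batch_size gives an unbiased estimate of the full-data log-likelihood. The sketch below checks this numerically in plain Python. The dataset, the Normal(5, 3) likelihood, and the batch size are hypothetical choices for the illustration.

```python
import math
import random

def normal_logpdf(value, loc, scale):
    var = scale ** 2
    return -0.5 * math.log(2 * math.pi * var) - (value - loc) ** 2 / (2 * var)

data = [0.1 * i for i in range(100)]                 # hypothetical dataset, N = 100
full_log_lik = sum(normal_logpdf(x, 5.0, 3.0) for x in data)

rng = random.Random(0)
batch_size = 10
scale = len(data) / batch_size                       # the positive scaling factor

estimates = []
for _ in range(5000):
    batch = rng.sample(data, batch_size)             # subsample without replacement
    batch_log_lik = sum(normal_logpdf(x, 5.0, 3.0) for x in batch)
    estimates.append(scale * batch_log_lik)          # rescaled mini-batch score

mean_estimate = sum(estimates) / len(estimates)
```

Averaged over many random mini-batches, the rescaled score should agree closely with full_log_lik.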
TraceMessenger¶

class
TraceHandler
(msngr, fn)[source]¶ Bases:
pyro.poutine.messenger.Handler
Execution trace poutine.
A TraceHandler records the input and output to every pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables?)
We can also use this for visualization.

get_trace
(*args, **kwargs)[source]¶ Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.
Returns: data structure
Return type: pyro.poutine.Trace

trace
¶


class
TraceMessenger
(graph_type=None, param_only=None)[source]¶ Bases:
pyro.poutine.messenger.Messenger
Execution trace messenger.
A TraceMessenger records the input and output to every pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables?)
We can also use this for visualization.

get_trace
()[source]¶ Helper method for a very common use case. Returns a shallow copy of self.trace.
Returns: data structure
Return type: pyro.poutine.Trace

Runtime¶

exception
NonlocalExit
(site, *args, **kwargs)[source]¶ Bases:
exceptions.Exception
Exception for exiting nonlocally from poutine execution.
Used by poutine.EscapeMessenger to return site information.

am_i_wrapped
()[source]¶ Checks whether the current computation is wrapped in a poutine.
Returns: bool

apply_stack
(initial_msg)[source]¶ Execute the effect stack at a single site according to the following scheme:
 For each Messenger in the stack from bottom to top, execute Messenger._process_message with the message; if the message field “stop” is True, stop; otherwise, continue
 Apply default behavior (default_process_message) to finish remaining site execution
 For each Messenger in the stack from top to bottom, execute _postprocess_message to update the message and internal messenger state with the site results
 If the message field “continuation” is not None, call it with the message
Parameters: initial_msg (dict) – the starting version of the trace site
Returns: None
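The scheme above can be sketched as a toy effect system in plain Python. This is an illustration under stated assumptions, not Pyro's code: all class and function names are hypothetical, “bottom” is taken to mean the innermost (most recently entered) handler, and post-processing is applied only to the handlers that saw the message before a stop.

```python
import random

STACK = []  # index 0 is the outermost handler; the last entry is the innermost

class Messenger:
    def __enter__(self):
        STACK.append(self)
        return self

    def __exit__(self, *exc_info):
        STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

class Record(Messenger):
    """Toy tracing handler: remembers the value of every finished site it sees."""
    def __init__(self):
        self.sites = {}

    def postprocess_message(self, msg):
        self.sites[msg["name"]] = msg["value"]

class Fix(Messenger):
    """Toy conditioning handler: supplies a value before the default sampling."""
    def __init__(self, data):
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]

class Hide(Messenger):
    """Toy blocking handler: stops propagation to outer handlers."""
    def __init__(self, names):
        self.names = names

    def process_message(self, msg):
        if msg["name"] in self.names:
            msg["stop"] = True

def apply_stack(msg):
    visited = 0
    for handler in reversed(STACK):          # innermost handler first
        visited += 1
        handler.process_message(msg)
        if msg["stop"]:
            break
    if msg["value"] is None:                 # default behavior: actually sample
        msg["value"] = msg["fn"]()
    if visited:
        for handler in STACK[-visited:]:     # opposite order, visited handlers only
            handler.postprocess_message(msg)
    if msg["continuation"] is not None:
        msg["continuation"](msg)

def sample(name, fn):
    msg = {"name": name, "fn": fn, "value": None, "stop": False, "continuation": None}
    apply_stack(msg)
    return msg["value"]

def model():
    a = sample("a", lambda: random.gauss(0.0, 1.0))
    return sample("b", lambda: random.gauss(a, 1.0))

with Record() as outer, Hide(["a"]), Record() as inner, Fix({"b": 2.0}):
    result = model()
```

In this run the inner recorder sees both sites, while the outer recorder sees only site b because the hiding handler stops the message for site a before it reaches the outer handler.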
Utilities¶

all_escape
(trace, msg)[source]¶ Parameters:  trace – a partial trace
 msg – the message at a pyro primitive site
Returns: boolean decision value
Utility function that checks if a site is not already in a trace.
Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for approximately integrating out variables for variance reduction.

discrete_escape
(trace, msg)[source]¶ Parameters:  trace – a partial trace
 msg – the message at a pyro primitive site
Returns: boolean decision value
Utility function that checks if a sample site is discrete and not already in a trace.
Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for integrating out discrete variables for variance reduction.

enum_extend
(trace, msg, num_samples=None)[source]¶ Parameters:  trace – a partial trace
 msg – the message at a pyro primitive site
 num_samples – maximum number of extended traces to return.
Returns: a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site whose values are enumerated from the support of the input site’s distribution.
Used for exact inference and integrating out discrete variables.

mc_extend
(trace, msg, num_samples=None)[source]¶ Parameters:  trace – a partial trace
 msg – the message at a pyro primitive site
 num_samples – maximum number of extended traces to return.
Returns: a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site whose values are sampled from the input site’s function.
Used for Monte Carlo marginalization of individual sample sites.