Poutine (Effect handlers)

Beneath the built-in inference algorithms, Pyro has a library of composable effect handlers for creating new inference algorithms and working with probabilistic programs. Pyro’s inference algorithms are all built by applying these handlers to stochastic functions.

Handlers

Poutine is a library of composable effect handlers for recording and modifying the behavior of Pyro programs. These lower-level ingredients simplify the implementation of new inference algorithms and behavior.

Handlers can be used as higher-order functions, decorators, or context managers to modify the behavior of functions or blocks of code.

For example, consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

We can mark sample sites as observed using condition, which returns a callable with the same input and output signatures as model:

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})

We can also use handlers as decorators:

>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

Or as context managers:

>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(0., s))
...     y = z ** 2

Handlers compose freely:

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)

Many inference algorithms or algorithmic components can be implemented in just a few lines of code:

guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
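
To make the arithmetic in the last line concrete, here is a minimal standard-library sketch that does not use Pyro at all. Everything in it is an assumption chosen for illustration: a hypothetical model with prior z ~ Normal(0, 1) and likelihood x ~ Normal(z, 1) observed at x = 1.0, a hypothetical guide z ~ Normal(0.5, 1), and the helper names normal_log_prob and single_sample_elbo, which are not Pyro functions.

```python
import math
import random


def normal_log_prob(value, loc, scale):
    """Log-density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def single_sample_elbo(x_obs, rng):
    """One-sample ELBO estimate: log p(x, z) - log q(z) with z drawn from q."""
    # "Guide trace": draw z from the guide and score it under the guide.
    z = rng.gauss(0.5, 1.0)
    guide_log_prob_sum = normal_log_prob(z, 0.5, 1.0)
    # "Replayed model trace": reuse the guide's z, then score prior and likelihood.
    model_log_prob_sum = normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x_obs, z, 1.0)
    return model_log_prob_sum - guide_log_prob_sum


rng = random.Random(0)
estimates = [single_sample_elbo(1.0, rng) for _ in range(1000)]
elbo_mean = sum(estimates) / len(estimates)
```

Each call mirrors one evaluation of the three-line snippet: a value is drawn under the guide, reused when scoring the model, and the two log-probability sums are subtracted. Averaging many such estimates, as elbo_mean does, reduces the variance of the estimate.
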
block(fn=None, hide_fn=None, expose_fn=None, hide=None, expose=None, hide_types=None, expose_types=None)[source]

This handler selectively hides Pyro primitive sites from the outside world. Default behavior: block everything.

A site is hidden if at least one of the following holds:

  1. hide_fn(msg) is True or (not expose_fn(msg)) is True
  2. msg["name"] in hide
  3. msg["type"] in hide_types
  4. msg["name"] not in expose and msg["type"] not in expose_types
  5. hide, hide_types, and expose_types are all None

For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any effect outside of block(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = trace(fn)
>>> fn_outer = trace(block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • hide_fn – function that takes a site and returns True to hide the site or False/None to expose it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
  • expose_fn – function that takes a site and returns True to expose the site or False/None to hide it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
  • hide – list of site names to hide
  • expose – list of site names to be exposed while all others hidden
  • hide_types – list of site types to be hidden
  • expose_types – list of site types to be exposed while all others hidden

Returns:

stochastic function decorated with a BlockMessenger
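
The name-based part of the hiding rules can be sketched in plain Python. This is a toy reading of the documented behavior, not Pyro's implementation: is_hidden is a hypothetical helper, and it deliberately ignores hide_fn, expose_fn, and the type-based arguments.

```python
def is_hidden(name, hide=None, expose=None):
    """Toy name-only version of the hide/expose rules (site types are ignored)."""
    if hide is None and expose is None:
        return True  # default behavior: block everything
    if hide is not None and name in hide:
        return True  # explicitly hidden
    if expose is not None:
        return name not in expose  # everything outside `expose` is hidden
    return False  # only `hide` was given and this site is not in it


sites = ["a", "b"]
visible_default = [s for s in sites if not is_hidden(s)]
visible_hide_a = [s for s in sites if not is_hidden(s, hide=["a"])]
visible_expose_b = [s for s in sites if not is_hidden(s, expose=["b"])]
```

With hide=["a"] only site “b” stays visible, which matches what the outer trace records in the example above.
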

broadcast(fn=None)[source]

Automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.

Notice how model_automatic_broadcast below automates expanding of distribution batch shapes. This makes it easy to modularize a Pyro model as the sub-components are agnostic of the wrapping plate contexts.

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by([100]))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
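
The shape bookkeeping behind this can be sketched without tensors. The function below is a hypothetical helper written for illustration, not part of Pyro; it assumes each plate is described only by a negative dim and a size, and that an existing dimension is compatible when it is 1 or already equal to the plate size.

```python
def broadcast_batch_shape(batch_shape, plates):
    """Return the batch shape after broadcasting against plate sizes.

    `plates` maps a negative dim (counting from the right) to a plate size.
    """
    ndim = max([len(batch_shape)] + [-dim for dim in plates])
    # Left-pad the existing batch shape with ones so every plate dim exists.
    shape = [1] * (ndim - len(batch_shape)) + list(batch_shape)
    for dim, size in plates.items():
        if shape[dim] not in (1, size):
            raise ValueError(
                "batch_shape {} is not broadcastable with plate size {} at dim {}".format(
                    tuple(batch_shape), size, dim))
        shape[dim] = size
    return tuple(shape)


plates = {-2: 100, -1: 3}  # mirrors the "batch" and "components" plates above
scalar_case = broadcast_batch_shape((), plates)
vector_case = broadcast_batch_shape((3,), plates)
```

Both the scalar and the length-3 batch shapes end up as (100, 3), while an incompatible shape such as (2,) raises an error in this sketch.
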
condition(fn=None, data=None)[source]

Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To observe a value for site z, we can write

>>> conditioned_model = condition(model, data={"z": torch.tensor(1.)})

This is equivalent to adding obs=value as a keyword argument to pyro.sample("z", ...) in model.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • data – a dict or a Trace
Returns:

stochastic function decorated with a ConditionMessenger

do(fn=None, data=None)[source]

Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values and hide them from the rest of the stack as if they were hard-coded to those values by using block.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To intervene with a value for site z, we can write

>>> intervened_model = do(model, data={"z": torch.tensor(1.)})

This is equivalent to replacing z = pyro.sample("z", ...) with z = value.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • data – a dict or a Trace
Returns:

stochastic function decorated with a BlockMessenger and pyro.poutine.condition_messenger.ConditionMessenger

enum(fn=None, first_available_dim=None)[source]

Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.

Parameters:
  • first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by Pyro. This should be a negative integer.
escape(fn=None, escape_fn=None)[source]

Given a callable that contains Pyro primitive calls, evaluate escape_fn on each site, and if the result is True, raise a NonlocalExit exception that stops execution and returns the offending site.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • escape_fn – function that takes a partial trace and a site, and returns a boolean value to decide whether to exit at that site
Returns:

stochastic function decorated with EscapeMessenger
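
The control flow of a nonlocal exit can be illustrated with an ordinary Python exception. The sketch below is a toy and makes several assumptions: ToyNonlocalExit and run_with_escape are hypothetical names, sites are plain dicts, the "discrete" key is invented for the example, and escape_fn receives only the site.

```python
class ToyNonlocalExit(Exception):
    """Hypothetical stand-in for NonlocalExit: carries the offending site."""

    def __init__(self, site):
        super().__init__(site["name"])
        self.site = site


def run_with_escape(sites, escape_fn):
    """Visit sites in order; stop at the first one for which escape_fn is True."""
    visited = []
    for site in sites:
        if escape_fn(site):
            raise ToyNonlocalExit(site)
        visited.append(site["name"])
    return visited


sites = [
    {"name": "a", "type": "sample", "discrete": False},
    {"name": "b", "type": "sample", "discrete": True},
    {"name": "c", "type": "sample", "discrete": False},
]

try:
    run_with_escape(sites, escape_fn=lambda site: site["discrete"])
    escaped_at = None
except ToyNonlocalExit as exit_signal:
    escaped_at = exit_signal.site["name"]
```

Execution stops at the first site the predicate flags, and the caller recovers that site from the exception instead of from a return value.
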

infer_config(fn=None, config_fn=None)[source]

Given a callable that contains Pyro primitive calls and a callable taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • config_fn – a callable taking a site and returning an infer dict
Returns:

stochastic function decorated with InferConfigMessenger

lift(fn=None, prior=None)[source]

Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = lift(model, prior={"s": dist.Exponential(0.3)})

lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it were replaced with s = pyro.sample("s", dist.Exponential(0.3)):

>>> tr = trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
Parameters:
  • fn – function whose parameters will be lifted to random values
  • prior – prior function in the form of a Distribution or a dict of stochastic fns
Returns:

fn decorated with a LiftMessenger

markov(fn=None, history=1, keep=False)[source]

Markov dependency declaration.

This can be used in a variety of ways:

  • as a context manager
  • as a decorator for recursive functions
  • as an iterator for Markov chains

Parameters:
  • history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.
  • keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).
mask(fn=None, mask=None)[source]

Given a stochastic function with some batched sample statements and a masking tensor, mask out some of the sample statements elementwise.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • mask (torch.ByteTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)
Returns:

stochastic function decorated with a MaskMessenger
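
The effect on the log-joint can be shown with plain lists. This sketch assumes that a masked-out element simply contributes zero to the sum; masked_log_prob_sum is a hypothetical helper and the log-probabilities are made-up numbers, not output from Pyro.

```python
def masked_log_prob_sum(log_probs, mask):
    """Sum elementwise log-probabilities, keeping only entries whose mask is 1."""
    if len(log_probs) != len(mask):
        raise ValueError("log_probs and mask must have the same length")
    return sum(lp * m for lp, m in zip(log_probs, mask))


log_probs = [-0.5, -1.25, -2.0]  # one entry per batch element
mask = [1, 0, 1]                 # exclude the middle element
total = masked_log_prob_sum(log_probs, mask)
```

Only the first and last elements are counted here, so the middle observation has no influence on the total.
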

queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[source]

Used in sequential enumeration over discrete variables.

Given a stochastic function and a queue, return a return value from a complete trace in the queue.

Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • queue – a queue data structure like multiprocessing.Queue to hold partial traces
  • max_tries – maximum number of attempts to compute a single complete trace
  • extend_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a list of extended traces
  • escape_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a boolean value to decide whether to exit
  • num_samples – optional number of extended traces for extend_fn to return
Returns:

stochastic function decorated with poutine logic

replay(fn=None, trace=None, params=None)[source]

Given a callable that contains Pyro primitive calls, return a callable that runs the original, reusing the values at sites in trace at those sites in the new trace.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:

>>> old_trace = trace(model).get_trace(1.0)
>>> replayed_model = replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • trace – a Trace data structure to replay against
  • params – dict of names of param sites and constrained values in fn to replay against
Returns:

a stochastic function decorated with a ReplayMessenger

scale(fn=None, scale=None)[source]

Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s), obs=1.0)
...     return z ** 2

scale multiplicatively scales the log-probabilities of sample sites:

>>> scaled_model = scale(model, scale=0.5)
>>> scaled_tr = trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • scale – a positive scaling factor
Returns:

stochastic function decorated with a ScaleMessenger

trace(fn=None, graph_type=None, param_only=None, strict_names=None)[source]

Return a handler that records the inputs and outputs of primitive calls and their dependencies.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.

>>> tr = trace(model).get_trace(0.0)
>>> logp = tr.log_prob_sum()
>>> params = [tr.nodes[name]["value"].unconstrained() for name in tr.param_nodes]
Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • graph_type – string that specifies the kind of graph to construct
  • param_only – if true, only records params and not samples
Returns:

stochastic function decorated with a TraceMessenger

config_enumerate(guide=None, default='sequential', expand=False, num_samples=None)[source]

Configures enumeration for all relevant sites in a guide. This is mainly used in conjunction with TraceEnum_ELBO.

When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose distribution satisfies .has_enumerate_support == True. When configuring for local parallel Monte Carlo sampling via default="parallel", num_samples=n, this configures all sample sites. This does not overwrite existing annotations infer={"enumerate": ...}.

This can be used as either a function:

guide = config_enumerate(guide)

or as a decorator:

@config_enumerate
def guide1(*args, **kwargs):
    ...

@config_enumerate(default="parallel", expand=True)
def guide2(*args, **kwargs):
    ...
Parameters:
  • guide (callable) – a Pyro model that will be used as a guide in SVI.
  • default (str) – Which enumerate strategy to use, one of “sequential”, “parallel”, or None.
  • expand (bool) – Whether to expand enumerated sample values. See enumerate_support() for details. This only applies to exhaustive enumeration, where num_samples=None. If num_samples is not None, then samples will always be expanded.
  • num_samples (int or None) – if not None, use local Monte Carlo sampling rather than exhaustive enumeration. This makes sense for both continuous and discrete distributions.
Returns:

an annotated guide

Return type:

callable

Trace

class Trace(*args, **kwargs)[source]

Bases: networkx.classes.digraph.DiGraph

Execution trace data structure built on top of networkx.DiGraph.

An execution trace of a Pyro program is a record of every call to pyro.sample() and pyro.param() in a single execution of that program. Traces are directed graphs whose nodes represent primitive calls or input/output, and whose edges represent conditional dependence relationships between those primitive calls. They are created and populated by poutine.trace.

Each node (or site) in a trace contains the name, input and output value of the site, as well as additional metadata added by inference algorithms or user annotation. In the case of pyro.sample, the trace also includes the stochastic function at the site, and any observed data added by users.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

We can record its execution using pyro.poutine.trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

We can also inspect or manipulate individual nodes in the trace. trace.nodes contains a collections.OrderedDict of site names and metadata corresponding to x, s, z, and the return value:

>>> list(name for name in trace.nodes.keys())  
["_INPUT", "s", "z", "_RETURN"]

As in networkx.DiGraph, values of trace.nodes are dictionaries of node metadata:

>>> trace.nodes["z"]  
{'type': 'sample', 'name': 'z', 'is_observed': False,
 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {},
 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (),
 'done': True, 'stop': False, 'continuation': None}

'infer' is a dictionary of user- or algorithm-specified metadata. 'args' and 'kwargs' are the arguments passed via pyro.sample to fn.__call__ or fn.log_prob. 'scale' is used to scale the log-probability of the site when computing the log-joint. 'cond_indep_stack' contains data structures corresponding to pyro.plate contexts appearing in the execution. 'done', 'stop', and 'continuation' are only used by Pyro’s internals.

add_node(site_name, *args, **kwargs)[source]
Parameters:
  • site_name (string) – the name of the site to be added

Adds a site to the trace.

Identical to networkx.DiGraph.add_node() but raises an error when attempting to add a duplicate node instead of silently overwriting.

compute_log_prob(site_filter=<function <lambda>>)[source]

Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. Both computations are memoized.

compute_score_parts()[source]

Compute the batched local score parts at each site of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. All computations are memoized.

copy()[source]

Makes a shallow copy of self with nodes and edges preserved. Identical to networkx.DiGraph.copy(), but preserves the type and the self.graph_type attribute.

format_shapes(title='Trace Shapes:', last_site=None)[source]

Returns a string showing a table of the shapes of all sites in the trace.

iter_stochastic_nodes()[source]
Returns: an iterator over stochastic nodes in the trace.
log_prob_sum(site_filter=<function <lambda>>)[source]

Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. The computation of log_prob_sum is memoized.

Returns: total log probability.
Return type: torch.Tensor
node_dict_factory

alias of collections.OrderedDict

nonreparam_stochastic_nodes
Returns: a list of names of sample sites whose stochastic functions are not reparameterizable primitive distributions
observation_nodes
Returns: a list of names of observe sites
pack_tensors(plate_to_symbol=None)[source]

Computes packed representations of tensors in the trace. This should be called after compute_log_prob() or compute_score_parts().

param_nodes
Returns: a list of names of param sites
reparameterized_nodes
Returns: a list of names of sample sites whose stochastic functions are reparameterizable primitive distributions
stochastic_nodes
Returns: a list of names of sample sites
symbolize_dims(plate_to_symbol=None)[source]

Assign unique symbols to all tensor dimensions.

Messengers

Messenger objects contain the implementations of the effects exposed by handlers. Advanced users may modify the implementations of messengers behind existing handlers or write new messengers that implement new effects and compose correctly with the rest of the library.

Messenger

class Messenger[source]

Bases: object

Context manager class that modifies behavior and adds side effects to stochastic functions, i.e. callables containing Pyro primitive statements.

This is the base Messenger class. It implements the default behavior for all Pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Messenger()(fn).

Class of transformers for messages passed during inference. Most inference operations are implemented in subclasses of this.

classmethod register(fn=None, type=None, post=None)[source]
Parameters:
  • fn – function implementing operation
  • type (str) – name of the operation (also passed to pyro.poutine.runtime.effectful())
  • post (bool) – if True, use this operation as postprocess

Dynamically add operations to an effect. Useful for generating wrappers for libraries.

Example:

@SomeMessengerClass.register
def some_function(msg):
    ...do_something...
    return msg
classmethod unregister(fn=None, type=None)[source]
Parameters:
  • fn – function implementing operation
  • type (str) – name of the operation (also passed to pyro.poutine.runtime.effectful())

Dynamically remove operations from an effect. Useful for removing wrappers from libraries.

Example:

SomeMessengerClass.unregister(some_function, "name")

BlockMessenger

class BlockMessenger(hide_fn=None, expose_fn=None, hide_all=True, expose_all=False, hide=None, expose=None, hide_types=None, expose_types=None)[source]

Bases: pyro.poutine.messenger.Messenger

This Messenger selectively hides Pyro primitive sites from the outside world. Default behavior: block everything. BlockMessenger has a flexible interface that allows users to specify in several different ways which sites should be hidden or exposed.

A site is hidden if at least one of the following holds:

  1. hide_fn(msg) is True or (not expose_fn(msg)) is True
  2. msg["name"] in hide
  3. msg["type"] in hide_types
  4. msg["name"] not in expose and msg["type"] not in expose_types
  5. hide, hide_types, and expose_types are all None

For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any poutine outside of BlockMessenger(hide=["a"])(fn) will not be applied to site “a” and will only see site “b”:

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = TraceMessenger()(fn)
>>> fn_outer = TraceMessenger()(BlockMessenger(hide=["a"])(TraceMessenger()(fn)))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True

See the constructor for details.

Parameters:
  • hide_fn – function that takes a site and returns True to hide the site or False/None to expose it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
  • expose_fn – function that takes a site and returns True to expose the site or False/None to hide it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
  • hide_all (bool) – hide all sites
  • expose_all (bool) – expose all sites normally
  • hide (list) – list of site names to hide, rest will be exposed normally
  • expose (list) – list of site names to expose, rest will be hidden
  • hide_types (list) – list of site types to hide, rest will be exposed normally
  • expose_types (list) – list of site types to expose normally, rest will be hidden

BroadcastMessenger

class BroadcastMessenger[source]

Bases: pyro.poutine.messenger.Messenger

BroadcastMessenger automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.

ConditionMessenger

class ConditionMessenger(data)[source]

Bases: pyro.poutine.messenger.Messenger

Adds values at observe sites to condition on data and override sampling.

EscapeMessenger

class EscapeMessenger(escape_fn)[source]

Bases: pyro.poutine.messenger.Messenger

Messenger that does a nonlocal exit by raising a util.NonlocalExit exception.

IndepMessenger

class CondIndepStackFrame[source]

Bases: pyro.poutine.indep_messenger.CondIndepStackFrame

vectorized
class IndepMessenger(name=None, size=None, dim=None, device=None)[source]

Bases: pyro.poutine.messenger.Messenger

This messenger keeps track of a stack of independence information declared by nested irange and plate contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TraceMessenger.

Example:

x_axis = IndepMessenger('outer', 320, dim=-1)
y_axis = IndepMessenger('inner', 200, dim=-2)
with x_axis:
    x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
with y_axis:
    y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
with x_axis, y_axis:
    xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
indices
next_context()[source]

Increments the counter.

LiftMessenger

class LiftMessenger(prior)[source]

Bases: pyro.poutine.messenger.Messenger

Messenger which “lifts” parameters to random samples. Given a stochastic function with param calls and a prior, creates a stochastic function where all param calls are replaced by sampling from prior.

Prior should be a callable or a dict of names to callables.

ReplayMessenger

class ReplayMessenger(trace=None, params=None)[source]

Bases: pyro.poutine.messenger.Messenger

Messenger for replaying from an existing execution trace.

ScaleMessenger

class ScaleMessenger(scale)[source]

Bases: pyro.poutine.messenger.Messenger

This messenger rescales the log probability score.

This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).

Parameters:
  • scale (float or torch.Tensor) – a positive scaling factor
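
The subsampling use case can be illustrated with plain arithmetic. The sketch below is a toy, not Pyro code: scaled_minibatch_log_prob is a hypothetical helper, the per-datum log-probabilities are made-up numbers, and the scale factor is assumed to be the dataset size divided by the minibatch size.

```python
from itertools import combinations


def scaled_minibatch_log_prob(log_probs, batch_indices):
    """Estimate the full-data log-probability sum from a minibatch."""
    scale = len(log_probs) / len(batch_indices)  # the positive scaling factor
    return scale * sum(log_probs[i] for i in batch_indices)


log_probs = [-1.0, -2.0, -3.0, -4.0]  # hypothetical per-datum log-probabilities
full_sum = sum(log_probs)

# Average the scaled estimate over every possible minibatch of size 2.
batches = list(combinations(range(len(log_probs)), 2))
estimates = [scaled_minibatch_log_prob(log_probs, batch) for batch in batches]
mean_estimate = sum(estimates) / len(estimates)
```

Individual minibatch estimates differ from the full sum, but their average over all minibatches equals it, which is the reason for rescaling the score when only part of the data is visited.
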

TraceMessenger

class TraceHandler(msngr, fn)[source]

Bases: object

Execution trace poutine.

A TraceHandler records the input and output to every Pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables?)

We can also use this for visualization.

get_trace(*args, **kwargs)[source]
Returns: data structure
Return type: pyro.poutine.Trace

Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.

trace
class TraceMessenger(graph_type=None, param_only=None, strict_names=None)[source]

Bases: pyro.poutine.messenger.Messenger

Execution trace messenger.

A TraceMessenger records the input and output to every Pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables?)

We can also use this for visualization.

get_trace()[source]
Returns: data structure
Return type: pyro.poutine.Trace

Helper method for a very common use case. Returns a shallow copy of self.trace.

identify_dense_edges(trace)[source]

Modifies a trace in-place by adding all edges based on the cond_indep_stack information stored at each site.

Runtime

exception NonlocalExit(site, *args, **kwargs)[source]

Bases: exceptions.Exception

Exception for exiting nonlocally from poutine execution.

Used by poutine.EscapeMessenger to return site information.

reset_stack()[source]

Reset the state of the frames remaining in the stack. Necessary for multiple re-executions in poutine.queue.

am_i_wrapped()[source]

Checks whether the current computation is wrapped in a poutine.

Returns: bool

apply_stack(initial_msg)[source]

Execute the effect stack at a single site according to the following scheme:

  1. For each Messenger in the stack from bottom to top, execute Messenger._process_message with the message; if the message field “stop” is True, stop; otherwise, continue
  2. Apply default behavior (default_process_message) to finish remaining site execution
  3. For each Messenger in the stack from top to bottom, execute _postprocess_message to update the message and internal messenger state with the site results
  4. If the message field “continuation” is not None, call it with the message
Parameters:
  • initial_msg (dict) – the starting version of the trace site
Returns: None
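
The four steps above can be sketched with a toy stack in plain Python. This is not Pyro's implementation: ToyMessenger, toy_apply_stack, and default_fn are hypothetical names, messages are plain dicts, and treating the last list element as the bottom of the stack is an assumption of this sketch.

```python
class ToyMessenger:
    """Hypothetical stand-in for a Messenger; records the order of its calls."""

    def __init__(self, name, log, stop=False):
        self.name = name
        self.log = log
        self.stop = stop

    def process_message(self, msg):
        self.log.append("process:" + self.name)
        if self.stop:
            msg["stop"] = True

    def postprocess_message(self, msg):
        self.log.append("postprocess:" + self.name)


def toy_apply_stack(stack, msg, default_fn):
    """Run the four documented steps on `msg` (stack[-1] is the bottom)."""
    ran = 0
    # Step 1: bottom to top, stopping early if a messenger sets "stop".
    for frame in reversed(stack):
        ran += 1
        frame.process_message(msg)
        if msg["stop"]:
            break
    # Step 2: default behavior fills in the value if nothing has done so yet.
    default_fn(msg)
    # Step 3: back down, visiting only the frames that ran in step 1.
    for frame in stack[len(stack) - ran:]:
        frame.postprocess_message(msg)
    # Step 4: optional continuation.
    if msg["continuation"] is not None:
        msg["continuation"](msg)


def default_fn(msg):
    if msg["value"] is None:
        msg["value"] = msg["fn"]()


log = []
stack = [ToyMessenger("outer", log), ToyMessenger("inner", log)]
msg = {"name": "z", "fn": lambda: 1.0, "value": None, "stop": False, "continuation": None}
toy_apply_stack(stack, msg, default_fn)
```

After the call, log shows the inner messenger processing first and postprocessing last, and msg["value"] holds the value produced by the default behavior. A messenger that sets "stop" prevents the frames above it from seeing the message at all.
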
default_process_message(msg)[source]

Default method for processing messages in inference.
Parameters:
  • msg – a message to be processed
Returns: None

effectful(fn=None, type=None)[source]
Parameters:
  • fn – function or callable that performs an effectful computation
  • type (str) – the type label of the operation, e.g. “sample”

Wrapper for calling pyro.poutine.runtime.apply_stack() to apply any active effects.

Utilities

all_escape(trace, msg)[source]
Parameters:
  • trace – a partial trace
  • msg – the message at a Pyro primitive site
Returns:

boolean decision value

Utility function that checks if a site is not already in a trace.

Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for approximately integrating out variables for variance reduction.

discrete_escape(trace, msg)[source]
Parameters:
  • trace – a partial trace
  • msg – the message at a Pyro primitive site
Returns:

boolean decision value

Utility function that checks if a sample site is discrete and not already in a trace.

Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for integrating out discrete variables for variance reduction.

enable_validation(is_validate)[source]
enum_extend(trace, msg, num_samples=None)[source]
Parameters:
  • trace – a partial trace
  • msg – the message at a Pyro primitive site
  • num_samples – maximum number of extended traces to return.
Returns:

a list of traces, copies of input trace with one extra site

Utility function to copy and extend a trace with sites based on the input site whose values are enumerated from the support of the input site’s distribution.

Used for exact inference and integrating out discrete variables.
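
The copy-and-extend idea can be sketched with dictionaries standing in for traces. This is a toy under stated assumptions, not the Pyro function: toy_enum_extend is a hypothetical helper, a trace is reduced to a mapping from site name to value, and the support is passed in explicitly rather than read from a distribution.

```python
import copy


def toy_enum_extend(trace, site_name, support):
    """Return one copy of `trace` per value in `support`, each with one extra site."""
    extended = []
    for value in support:
        new_trace = copy.deepcopy(trace)  # leave the input trace untouched
        new_trace[site_name] = value
        extended.append(new_trace)
    return extended


partial_trace = {"a": 0}
extended_traces = toy_enum_extend(partial_trace, "b", support=[0, 1])
```

Each returned trace contains the original site plus one value of the new site, and the input trace is left unmodified so it can be extended again along another branch.
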

is_validation_enabled()[source]
mc_extend(trace, msg, num_samples=None)[source]
Parameters:
  • trace – a partial trace
  • msg – the message at a Pyro primitive site
  • num_samples – maximum number of extended traces to return.
Returns:

a list of traces, copies of input trace with one extra site

Utility function to copy and extend a trace with sites based on the input site whose values are sampled from the input site’s function.

Used for Monte Carlo marginalization of individual sample sites.

prune_subsample_sites(trace)[source]

Copies and removes all subsample sites from a trace.

site_is_subsample(site)[source]

Determines whether a trace site originated from a subsample statement inside a plate.