# Poutines (Pyro Coroutines)

## Poutine

class Poutine(fn)

Bases: object

Context manager class that modifies the behavior of stochastic functions, i.e. callables containing Pyro primitive statements, and adds side effects to them.

See the Poutine execution model writeup in the documentation for a description of the entire Poutine system.

This is the base Poutine class. It implements the default behavior for all pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Poutine(fn).

## BlockPoutine

class BlockPoutine(fn, hide_all=True, expose_all=False, hide=None, expose=None, hide_types=None, expose_types=None)

This Poutine selectively hides pyro primitive sites from the outside world.

For example, suppose the stochastic function fn has two sample sites "a" and "b". Then any poutine outside of BlockPoutine(fn, hide=["a"]) will not be applied to site "a" and will only see site "b":

>>> fn_inner = TracePoutine(fn)
>>> fn_outer = TracePoutine(BlockPoutine(TracePoutine(fn), hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True


BlockPoutine has a flexible interface that allows users to specify in several different ways which sites should be hidden or exposed. See the constructor for details.

## ConditionPoutine

class ConditionPoutine(fn, data)

Adds values at observe sites to condition on data and override sampling.

## EscapePoutine

class EscapePoutine(fn, escape_fn)

Poutine that performs a nonlocal exit by raising a util.NonlocalExit exception.

## IndepPoutine

class CondIndepStackFrame(name, counter, vectorized)

Bases: tuple

counter

Alias for field number 1

name

Alias for field number 0

vectorized

Alias for field number 2
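The field layout listed above matches what a standard `collections.namedtuple` provides. The sketch below rebuilds an equivalent tuple type for illustration; the field values (`"data_loop"`, `0`, `False`) are made up, and this is not necessarily how Pyro defines the class.

```python
from collections import namedtuple

# Field order follows the aliases listed above: name=0, counter=1, vectorized=2.
CondIndepStackFrame = namedtuple(
    "CondIndepStackFrame", ["name", "counter", "vectorized"]
)

# Hypothetical frame for a non-vectorized independence context.
frame = CondIndepStackFrame(name="data_loop", counter=0, vectorized=False)

# Fields are reachable both by attribute and by position.
same_name = frame.name == frame[0]
same_counter = frame.counter == frame[1]
same_vectorized = frame.vectorized == frame[2]
```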

class IndepPoutine(fn, name, vectorized)

This poutine keeps track of the stack of independence information declared by nested irange and iarange contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TracePoutine.

## LiftPoutine

class LiftPoutine(fn, prior)

Poutine which “lifts” parameters to random samples. Given a stochastic function with param calls and a prior, creates a stochastic function where all param calls are replaced by sampling from prior.

Prior should be a callable or a dict of names to callables.
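The two accepted forms of the prior can be sketched as a small dispatch function. This is a toy illustration: `lift_param` is a hypothetical helper, the priors are zero-argument callables for simplicity, and leaving unmatched parameters at their initial value is an assumption rather than documented behavior.

```python
import random


def lift_param(name, init_value, prior):
    """Return a draw from the prior for `name`, or init_value if none applies."""
    if callable(prior):
        # A single callable is used for every parameter.
        return prior()
    if name in prior:
        # A dict maps parameter names to their own prior callables.
        return prior[name]()
    # Assumption: parameters without a matching prior stay ordinary parameters.
    return init_value


random.seed(1)
prior = {"weight": lambda: random.gauss(0.0, 1.0)}

weight = lift_param("weight", 0.5, prior)  # replaced by a random draw
bias = lift_param("bias", 0.1, prior)      # no prior given, value kept
shared = lift_param("anything", 0.0, lambda: 7.0)  # callable prior applies to all
```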

## ReplayPoutine

class ReplayPoutine(fn, guide_trace, sites=None)

Poutine for replaying from an existing execution trace.

## ScalePoutine

class ScalePoutine(fn, scale)

This poutine rescales the log probability score.

This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).

Parameters:

- fn (callable or None) – an optional function to be scaled
- scale (float or torch.autograd.Variable) – a positive scaling factor
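For data subsampling, a common choice of scale is the ratio of the full data size to the minibatch size, so that the scaled minibatch log probability estimates the full-data log probability. The sketch below shows only that arithmetic in plain Python; `normal_log_pdf` and `scaled_log_pdf` are hypothetical helpers, and the Normal likelihood is an arbitrary choice for illustration.

```python
import math
import random


def normal_log_pdf(x, mu, sigma):
    """Log density of a Normal(mu, sigma) distribution at x."""
    return (-0.5 * math.log(2.0 * math.pi * sigma ** 2)
            - (x - mu) ** 2 / (2.0 * sigma ** 2))


def scaled_log_pdf(batch, n_total, mu, sigma):
    """Minibatch log probability rescaled to the size of the full data set."""
    scale = n_total / len(batch)  # e.g. 1000 data points, batch of 100 -> 10.0
    return scale * sum(normal_log_pdf(x, mu, sigma) for x in batch)


random.seed(0)
data = [random.gauss(1.0, 2.0) for _ in range(1000)]

# Exact log probability of the full data set.
full = sum(normal_log_pdf(x, 1.0, 2.0) for x in data)

# Estimate from a random minibatch of 100 points.
batch = random.sample(data, 100)
estimate = scaled_log_pdf(batch, len(data), 1.0, 2.0)
```

When the batch is the whole data set the scale is 1 and the estimate equals the exact value; for smaller batches the estimate varies with the points drawn.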

## Trace

class Trace(*args, **kwargs)

Bases: networkx.classes.digraph.DiGraph

Execution trace data structure.

add_node(site_name, *args, **kwargs)
Parameters: site_name (string) – the name of the site to be added

Adds a site to the trace.

Identical to super(Trace, self).add_node, but raises an error when attempting to add a duplicate node instead of silently overwriting.
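The duplicate check can be sketched with a plain ordered mapping instead of a networkx graph. This is a toy stand-in: `ToyTrace` is hypothetical, and `ValueError` is an assumed exception type, since the text above does not say which error is raised.

```python
from collections import OrderedDict


class ToyTrace:
    """Toy site container that refuses to overwrite an existing site."""

    def __init__(self):
        self.nodes = OrderedDict()  # preserves the order in which sites were added

    def add_node(self, site_name, **attrs):
        if site_name in self.nodes:
            # Assumed exception type; the real class may raise something else.
            raise ValueError("site {!r} already in trace".format(site_name))
        self.nodes[site_name] = attrs


trace = ToyTrace()
trace.add_node("a", type="sample", value=0.3)

try:
    trace.add_node("a", type="sample", value=0.9)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```

The first value recorded for site "a" survives, which is the point of raising instead of silently overwriting.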

batch_log_pdf(site_filter=<function <lambda>>)

Compute the batched local and overall log-probabilities of the trace.

The local computation is memoized; the local .log_pdf() values are stored as well.

compute_batch_log_pdf(site_filter=<function <lambda>>)

Compute the batched local log-probabilities at each site of the trace.

The local computation is memoized; the local .log_pdf() values are stored as well.

copy()

Makes a shallow copy of self with nodes and edges preserved. Identical to super(Trace, self).copy(), but preserves the type and the self.graph_type attribute.

log_pdf(site_filter=<function <lambda>>)

Compute the local and overall log-probabilities of the trace.

The local computation is memoized.
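A toy version of the memoized computation is sketched below. It assumes sites are plain dicts holding a value and Normal parameters, and that the filter takes a `(name, site)` pair and defaults to accepting every site; the function names and the site layout are hypothetical and do not reflect Pyro's internals.

```python
import math


def normal_log_pdf(x, mu, sigma):
    """Log density of a Normal(mu, sigma) distribution at x."""
    return (-0.5 * math.log(2.0 * math.pi * sigma ** 2)
            - (x - mu) ** 2 / (2.0 * sigma ** 2))


def log_pdf(sites, site_filter=lambda name, site: True):
    """Sum the per-site log probabilities, caching each one on its site."""
    total = 0.0
    for name, site in sites.items():
        if not site_filter(name, site):
            continue
        if "log_pdf" not in site:
            # Local computation happens once; later calls reuse the stored value.
            site["log_pdf"] = normal_log_pdf(site["value"], site["mu"], site["sigma"])
        total += site["log_pdf"]
    return total


sites = {
    "a": {"value": 0.0, "mu": 0.0, "sigma": 1.0},
    "b": {"value": 1.0, "mu": 0.0, "sigma": 1.0},
}

total = log_pdf(sites)
only_a = log_pdf(sites, site_filter=lambda name, site: name == "a")
```

One consequence of caching in this sketch is that changing a site's value after the first call does not change the result of later calls.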

node_dict_factory

alias of OrderedDict

nonreparam_stochastic_nodes

Gets a list of names of sample sites whose stochastic functions are not reparameterizable primitive distributions.

observation_nodes

Gets a list of names of observe sites.

reparameterized_nodes

Gets a list of names of sample sites whose stochastic functions are reparameterizable primitive distributions.

stochastic_nodes

Gets a list of names of sample sites.

## TracePoutine

class TracePoutine(fn, graph_type=None)

Execution trace poutine.

A TracePoutine records the inputs and outputs of every Pyro primitive and stores them as a site in a Trace. In principle, this should be sufficient information for every inference algorithm, possibly together with the implicit computational graph in the Variables.

We can also use this for visualization.

get_trace(*args, **kwargs)
Returns: a pyro.poutine.Trace data structure

Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.

get_vectorized_map_data_info(trace)

Determines whether the vectorized map_data calls are Rao-Blackwellizable by TraceGraph_ELBO. Also gathers information to be consumed downstream by TraceGraph_ELBO.

identify_dense_edges(trace)

Modifies a trace in-place by adding all edges based on the cond_indep_stack information stored at each site.