Poutines (Pyro Coroutines)


class Poutine(fn)

Bases: object

Context manager class that modifies behavior and adds side effects to stochastic functions, i.e. callables containing pyro primitive statements.

See the Poutine execution model writeup in the documentation for a description of the entire Poutine system.

This is the base Poutine class. It implements the default behavior for all pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Poutine(fn).


class BlockPoutine(fn, hide_all=True, expose_all=False, hide=None, expose=None, hide_types=None, expose_types=None)

Bases: pyro.poutine.poutine.Poutine

This Poutine selectively hides pyro primitive sites from the outside world.

For example, suppose the stochastic function fn has two sample sites “a” and “b”. Then any poutine outside of BlockPoutine(fn, hide=["a"]) will not be applied to site “a” and will only see site “b”:

>>> fn_inner = TracePoutine(fn)
>>> fn_outer = TracePoutine(BlockPoutine(TracePoutine(fn), hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True

BlockPoutine has a flexible interface that lets users specify, in several different ways, which sites should be hidden or exposed (hide_all, expose_all, hide, expose, hide_types and expose_types). See the constructor for details.


class ConditionPoutine(fn, data)

Bases: pyro.poutine.poutine.Poutine

Adds values at observe sites to condition on data, overriding sampling at those sites.


class EscapePoutine(fn, escape_fn)

Bases: pyro.poutine.poutine.Poutine

Poutine that performs a nonlocal exit by raising a util.NonlocalExit exception at a site where escape_fn signals that execution should stop.


class CondIndepStackFrame(name, counter, vectorized)

Bases: tuple


counter
Alias for field number 1

name
Alias for field number 0

vectorized
Alias for field number 2

class IndepPoutine(fn, name, vectorized)

Bases: pyro.poutine.poutine.Poutine

This poutine keeps track of a stack of independence information declared by nested irange and iarange contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TracePoutine.


class LiftPoutine(fn, prior)

Bases: pyro.poutine.poutine.Poutine

Poutine which “lifts” parameters to random samples. Given a stochastic function with param calls and a prior, creates a stochastic function where all param calls are replaced by sampling from prior.

prior should be a callable or a dict mapping parameter names to callables.


class ReplayPoutine(fn, guide_trace, sites=None)

Bases: pyro.poutine.poutine.Poutine

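Poutine for replaying from an existing execution trace: at sample sites recorded in guide_trace (optionally restricted by sites), the recorded value is reused instead of drawing a new sample.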


class ScalePoutine(fn, scale)

Bases: pyro.poutine.poutine.Poutine

This poutine rescales the log probability score.

This is typically used for data subsampling or for stratified sampling of data (e.g. in fraud detection where negatives vastly outnumber positives).

Parameters:
  • fn (callable or None) – an optional function to be scaled
  • scale (float or torch.autograd.Variable) – a positive scaling factor
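The arithmetic behind data subsampling can be checked directly. In this sketch the per-datapoint log-likelihood values are made up, and scaled_minibatch_log_lik is a hypothetical helper, not part of Pyro. Each minibatch log-likelihood is multiplied by scale = N / B. When the data are split into equal minibatches, the average of the scaled estimates equals the full-data log-likelihood.

```python
import math


def scaled_minibatch_log_lik(log_lik, batch, full_size):
    """Hypothetical helper: scale a minibatch log-likelihood by N / B."""
    scale = full_size / len(batch)  # positive scaling factor
    return scale * sum(log_lik[i] for i in batch)


# Made-up per-datapoint log-likelihoods for N = 6 observations.
log_lik = [-1.2, -0.7, -2.3, -0.4, -1.9, -0.8]
full = sum(log_lik)

# Three disjoint minibatches of size B = 2.
batches = [[0, 1], [2, 3], [4, 5]]
estimates = [scaled_minibatch_log_lik(log_lik, b, len(log_lik)) for b in batches]

# Averaged over this partition, the scaled estimates recover the full total.
average = sum(estimates) / len(estimates)
```

For minibatches drawn uniformly at random, the same scaling makes the minibatch estimate unbiased for the full-data log-likelihood.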


class Trace(*args, **kwargs)

Bases: networkx.classes.digraph.DiGraph

Execution trace data structure.

add_node(site_name, *args, **kwargs)
Parameters: site_name (string) – the name of the site to be added

Adds a site to the trace.

Identical to super(Trace, self).add_node, but raises an error when attempting to add a duplicate node instead of silently overwriting.
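The override pattern described for add_node can be sketched without networkx. ToyDiGraph is a hypothetical minimal stand-in for networkx.DiGraph whose add_node silently overwrites existing nodes, and ToyTrace refuses duplicates. The text only says that an error is raised, so the choice of RuntimeError is an assumption.

```python
class ToyDiGraph:
    """Hypothetical minimal stand-in for networkx.DiGraph."""

    def __init__(self):
        self.nodes = {}

    def add_node(self, name, **attrs):
        self.nodes[name] = dict(attrs)  # silently overwrites an existing node

    def __contains__(self, name):
        return name in self.nodes


class ToyTrace(ToyDiGraph):
    def add_node(self, site_name, **attrs):
        # Refuse to overwrite an existing site instead of replacing it silently.
        if site_name in self:
            raise RuntimeError("site {!r} is already in the trace".format(site_name))
        super().add_node(site_name, **attrs)


tr = ToyTrace()
tr.add_node("x", type="sample", value=1.0)
message = None
try:
    tr.add_node("x", type="sample", value=2.0)
except RuntimeError as err:
    message = str(err)
```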

batch_log_pdf(site_filter=<function <lambda>>)

Compute the batched local and overall log-probabilities of the trace.

The local computation is memoized, and also stores the local .log_pdf().

compute_batch_log_pdf(site_filter=<function <lambda>>)

Compute the batched local log-probabilities at each site of the trace.

The local computation is memoized, and also stores the local .log_pdf().


copy()

Makes a shallow copy of self with nodes and edges preserved. Identical to super(Trace, self).copy(), but preserves the type and the self.graph_type attribute.

log_pdf(site_filter=<function <lambda>>)

Compute the local and overall log-probabilities of the trace.

The local computation is memoized.

Returns: total log probability.
Return type: torch.autograd.Variable
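Here is a sketch of memoized per-site scoring with a site filter. It makes several assumptions for illustration. The trace is reduced to a dict of site dicts, every sample and observe site uses a unit normal density, and site_filter takes (name, site). The site_filter signature, the site field names and the choice of distribution are all assumptions. Batched values are computed once per site, and the local total is stored alongside them.

```python
import math


def normal_log_pdf(x, mu=0.0, sigma=1.0):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


# Hypothetical trace: site name -> site dict (a plain-dict stand-in for Trace).
trace = {
    "z": {"type": "sample", "value": [0.1, -0.3]},
    "x": {"type": "observe", "value": [1.0, 0.5]},
    "w": {"type": "param", "value": 2.0},
}


def batch_log_pdf(site):
    # Memoize the per-element (batched) log-probabilities on the site itself.
    if "batch_log_pdf" not in site:
        site["batch_log_pdf"] = [normal_log_pdf(v) for v in site["value"]]
        site["log_pdf"] = sum(site["batch_log_pdf"])  # also store the local total
    return site["batch_log_pdf"]


def log_pdf(trace, site_filter=lambda name, site: True):
    total = 0.0
    for name, site in trace.items():
        if site["type"] in ("sample", "observe") and site_filter(name, site):
            batch_log_pdf(site)
            total += site["log_pdf"]
    return total


total = log_pdf(trace)
only_observed = log_pdf(trace, site_filter=lambda name, site: site["type"] == "observe")
```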

alias of OrderedDict


Gets a list of the names of sample sites whose stochastic functions are not reparameterizable primitive distributions.


Gets a list of the names of observe sites.


Gets a list of the names of sample sites whose stochastic functions are reparameterizable primitive distributions.


Gets a list of the names of sample sites.
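The four site lists can be sketched as filters over site records. The type field and the reparameterized flag are hypothetical names chosen for this illustration. How Pyro actually marks a distribution as reparameterizable is not described here.

```python
# Hypothetical site records; only the fields used below are modeled.
sites = {
    "z": {"type": "sample", "reparameterized": True},
    "k": {"type": "sample", "reparameterized": False},
    "x": {"type": "observe", "reparameterized": False},
    "w": {"type": "param", "reparameterized": False},
}

sample_sites = [n for n, s in sites.items() if s["type"] == "sample"]
observe_sites = [n for n, s in sites.items() if s["type"] == "observe"]
reparam_sites = [n for n in sample_sites if sites[n]["reparameterized"]]
nonreparam_sites = [n for n in sample_sites if not sites[n]["reparameterized"]]
```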


class TracePoutine(fn, graph_type=None)

Bases: pyro.poutine.poutine.Poutine

Execution trace poutine.

A TracePoutine records the input and output of every pyro primitive and stores them as a site in a Trace(). In principle, this should be sufficient information for every inference algorithm, possibly together with the implicit computational graph carried by the Variables.

Traces can also be used for visualization.

get_trace(*args, **kwargs)
Returns: data structure
Return type: pyro.poutine.Trace

Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.


This determines whether the vectorized map_datas are Rao-Blackwellizable by TraceGraph_ELBO. It also gathers information to be consumed downstream by TraceGraph_ELBO.


Modifies a trace in-place by adding all edges based on the cond_indep_stack information stored at each site.
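Here is a sketch of building edges from cond_indep_stack information. The rule used here is a simplifying assumption and does not come from the text. An edge goes from every earlier site to every later site, unless their stacks contain frames with the same name and different counters, which the sketch treats as sibling branches of one independence context. Frame, independent and add_edges are hypothetical names. The sketch returns a set of edges rather than modifying a trace in place.

```python
from collections import namedtuple

# Hypothetical frame type with the same fields as CondIndepStackFrame.
Frame = namedtuple("Frame", ["name", "counter", "vectorized"])


def independent(stack_a, stack_b):
    """Assumed rule: sites in sibling branches (same context name,
    different counter) are conditionally independent."""
    for fa in stack_a:
        for fb in stack_b:
            if fa.name == fb.name and fa.counter != fb.counter:
                return True
    return False


def add_edges(order, stacks):
    """Return (earlier, later) edges for every pair of sites not marked independent."""
    edges = set()
    for i, earlier in enumerate(order):
        for later in order[i + 1:]:
            if not independent(stacks[earlier], stacks[later]):
                edges.add((earlier, later))
    return edges


stacks = {
    "mu": (),
    "x0": (Frame("data", 0, False),),
    "x1": (Frame("data", 1, False),),
}
edges = add_edges(["mu", "x0", "x1"], stacks)  # no edge between x0 and x1
```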