Automatic Name Generation

The pyro.contrib.autoname module provides tools for automatically generating unique, semantically meaningful names for sample sites.

scope(fn=None, prefix=None, inner=None)
Parameters:
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • prefix – a string to prepend to sample names (optional if fn is provided)
  • inner – switch to determine where duplicate name counters appear
Returns: fn decorated with a ScopeMessenger

scope prepends a prefix followed by a / to the name at a Pyro sample site. It works much like TensorFlow’s name_scope and variable_scope, and can be used as a context manager, a decorator, or a higher-order function.

scope is useful for aligning the site names of compositional models with those of guides or data.

Example:

>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

Example:

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

Scopes compose as expected, with outer scopes appearing before inner scopes in names:

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

When used as a decorator or higher-order function, scope will use the name of the input function as the prefix if no user-specified prefix is provided.

Example:

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
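The fallback to the function name can be sketched in plain Python. This is an illustration under stated assumptions rather than Pyro's code: toy_scope is a hypothetical decorator, and the decorated functions simply return lists of site names instead of calling pyro.sample.

```python
import functools


def toy_scope(fn=None, prefix=None):
    """Usable as @toy_scope or as @toy_scope(prefix=...)."""
    if fn is None:
        # Called with keyword arguments only: return the real decorator.
        return functools.partial(toy_scope, prefix=prefix)

    # Fall back to the function's own name when no prefix is given.
    chosen = prefix if prefix is not None else fn.__name__

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return ["{}/{}".format(chosen, name) for name in fn(*args, **kwargs)]

    return wrapper


@toy_scope
def model():
    return ["x", "y"]


@toy_scope(prefix="a")
def guide():
    return ["x"]


model_names = model()
guide_names = guide()
```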

name_count(fn=None)

name_count is a simple autonaming scheme that appends a suffix “__” plus a counter to any name that appears multiple times in an execution. Only duplicate instances of a name get a suffix; the first instance is not modified.

Example:

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()
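The suffixing rule itself is small enough to sketch in plain Python. The helper below is hypothetical (count_names is not part of Pyro) and only mirrors the rule described above: the first occurrence is kept, and later ones get "__" plus a counter.

```python
from collections import Counter


def count_names(names):
    """Return names with '__<n>' appended to every repeated occurrence."""
    seen = Counter()
    unique = []
    for name in names:
        n = seen[name]                     # how many times we saw it before
        unique.append(name if n == 0 else "{}__{}".format(name, n))
        seen[name] += 1
    return unique


unique_names = count_names(["x", "x", "x", "y"])
```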

name_count also composes with scope() by adding a suffix to duplicate scope entrances:

Example:

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

Example:

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()
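The difference between the two cases above is where the counter is applied: to the scope prefix when the scope is entered repeatedly, or to the site name when one scope contains repeated sites. A plain-Python sketch of that distinction follows; suffix and the two counters are hypothetical helpers for illustration and do not reflect how Pyro stores its counts.

```python
from collections import Counter


def suffix(name, n):
    """Leave the first occurrence unchanged, append '__<n>' afterwards."""
    return name if n == 0 else "{}__{}".format(name, n)


# Case 1: the scope "a" is entered three times, with one site per entrance.
scope_entries = Counter()
entered_repeatedly = []
for _ in range(3):
    prefix = suffix("a", scope_entries["a"])
    scope_entries["a"] += 1
    entered_repeatedly.append(prefix + "/x")

# Case 2: the scope "a" is entered once and contains three sites named "x".
site_counts = Counter()
entered_once = []
for _ in range(3):
    full = "a/x"
    entered_once.append(suffix(full, site_counts[full]))
    site_counts[full] += 1
```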

Named Data Structures

The pyro.contrib.named module is a thin syntactic layer on top of Pyro. It allows Pyro models to be written to look like programs operating on Python data structures, like latent.x.sample_(...), rather than programs with string-labeled statements like x = pyro.sample("x", ...).

This module provides three container data structures: named.Object, named.List, and named.Dict. These data structures are intended to be nested in each other. Together they track the address of each piece of data in each data structure, so that this address can be used as the name of a Pyro site. For example:

>>> state = named.Object("state")
>>> print(str(state))
state

>>> z = state.x.y.z  # z is just a placeholder.
>>> print(str(z))
state.x.y.z

>>> state.xs = named.List()  # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]

>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']
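One way to picture the address tracking is a class that builds child addresses on attribute access. The sketch below is a simplified plain-Python illustration, not the module's implementation: ToyObject and ToyList are hypothetical, and unlike named.List the toy list is given its address directly instead of being bound through attribute assignment.

```python
class ToyObject:
    """Placeholder that remembers the dotted path used to reach it."""

    def __init__(self, name):
        self._name = name

    def __getattr__(self, key):
        # Only called when normal lookup fails, i.e. for new children.
        if key.startswith("_"):
            raise AttributeError(key)
        child = ToyObject("{}.{}".format(self._name, key))
        setattr(self, key, child)          # cache so the child is reused
        return child

    def __str__(self):
        return self._name


class ToyList(list):
    """List whose elements are addressed as '<name>[<index>]'."""

    def __init__(self, name):
        super().__init__()
        self._name = name

    def add(self):
        child = ToyObject("{}[{}]".format(self._name, len(self)))
        self.append(child)
        return child


state = ToyObject("state")
z_address = str(state.x.y.z)
xs = ToyList("state.xs")
first_address = str(xs.add())
second_address = str(xs.add())
```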

These addresses can now be used inside sample, observe and param statements. These named data structures even provide in-place methods that alias Pyro statements. For example:

>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)
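The aliasing idea is that the in-place method supplies the site name from the object's address, so the caller never writes the string. A minimal plain-Python sketch of that forwarding, with everything hypothetical: toy_sample stands in for a sampling primitive, trace for a recorded execution, and Placeholder for a named object.

```python
import random

# Hypothetical stand-in for a trace: site name -> sampled value.
trace = {}


def toy_sample(name, draw):
    """Draw a value and record it under the given site name."""
    value = draw()
    trace[name] = value
    return value


class Placeholder:
    """Holds an address and forwards it as the site name."""

    def __init__(self, address):
        self.address = address

    def sample_(self, draw):
        # The caller passes no name; the address is used instead.
        return toy_sample(self.address, draw)


rng = random.Random(0)
z = Placeholder("state.z").sample_(lambda: rng.gauss(0.0, 1.0))
```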

For deeper examples of how these can be used in model code, see the Tree Data and Mixture examples.

Authors: Fritz Obermeyer, Alexander Rush

class Object(name)

Bases: object

Object to hold immutable latent state.

This object can serve either as a container for nested latent state or as a placeholder to be replaced by a tensor via a named.sample, named.observe, or named.param statement. When used as a placeholder, Object objects take the place of strings in normal pyro.sample statements.

Parameters: name (str) – The name of the object.

Example:

state = named.Object("state")
state.x = 0
state.ys = named.List()
state.zs = named.Dict()
state.a.b.c.d.e.f.g = 0  # Creates a chain of named.Objects.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

sample_(fn, *args, **kwargs)

Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Intro I and Intro II for a discussion.

Parameters:
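  • name – name of sample (not passed explicitly; the address of the named object, e.g. state.z, takes the place of the name string)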
  • fn – distribution class or function
  • obs – observed datum (optional; should only be used in context of inference) optionally specified in kwargs
  • infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.
Returns: sample

param_(*args, **kwargs)

Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.

Parameters:
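  • name (str) – name of parameter (not passed explicitly; the address of the named object, e.g. state.loc, takes the place of the name string)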
  • init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which will only be evaluated on the initial statement.
  • constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.
  • event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of this will be considered batch dimensions; if the param statement is inside a subsampled plate, then corresponding batch dimensions of the parameter will be correspondingly subsampled. If unspecified, all dimensions will be considered event dims and no subsampling will be performed.
Returns: parameter
Return type: torch.Tensor
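The benefit of a lazy init_tensor callable is that it runs only when the parameter does not yet exist. The sketch below illustrates that behavior in plain Python with a hypothetical dictionary store and toy_param helper; it is not Pyro's param store and uses a list in place of a tensor.

```python
# Hypothetical stand-in for a parameter store: name -> value.
store = {}
init_calls = []


def toy_param(name, init):
    """Return the stored value, initializing it only on first use."""
    if name not in store:
        # Evaluate a lazy initializer only when the name is new.
        store[name] = init() if callable(init) else init
    return store[name]


def expensive_init():
    init_calls.append(1)                   # record each evaluation
    return [0.0] * 5


first = toy_param("state.loc", expensive_init)
second = toy_param("state.loc", expensive_init)
```

The second call returns the stored value without running the initializer again.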

class List(name=None)

Bases: list

List-like object to hold immutable latent state.

This must either be given a name when constructed:

latent = named.List("root")

or must be immediately stored in a named.Object:

latent = named.Object("root")
latent.xs = named.List()  # Must be bound to an Object before use.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

add()

Append one new named.Object.

Returns: a new latent object at the end
Return type: named.Object

class Dict(name=None)

Bases: dict

Dict-like object to hold immutable latent state.

This must either be given a name when constructed:

latent = named.Dict("root")

or must be immediately stored in a named.Object:

latent = named.Object("root")
latent.xs = named.Dict()  # Must be bound to an Object before use.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

Scoping

pyro.contrib.autoname.scoping contains the implementation of pyro.contrib.autoname.scope(), a tool for automatically prepending a semantically meaningful prefix to names of sample sites.

class NameCountMessenger

Bases: pyro.poutine.messenger.Messenger

NameCountMessenger is the implementation of pyro.contrib.autoname.name_count().

class ScopeMessenger(prefix=None, inner=None)

Bases: pyro.poutine.messenger.Messenger

ScopeMessenger is the implementation of pyro.contrib.autoname.scope().
