Automatic Name Generation

The pyro.contrib.autoname module provides tools for automatically generating unique, semantically meaningful names for sample sites.

scope(fn=None, prefix=None, inner=None)
Parameters
  • fn – a stochastic function (callable containing Pyro primitive calls)

  • prefix – a string to prepend to sample names (optional if fn is provided)

  • inner – switch to determine where duplicate name counters appear

Returns

fn decorated with a ScopeMessenger

scope prepends a prefix followed by a / to the name at a Pyro sample site. It works much like TensorFlow’s name_scope and variable_scope, and can be used as a context manager, a decorator, or a higher-order function.

scope is useful for aligning the site names of compositional models with those of guides or data.

Example:

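>>> import pyro
>>> import pyro.distributions as dist
>>> import pyro.poutine as poutine
>>> from pyro.contrib.autoname import name_count, scope
>>> @scope(prefix="a")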
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

Example:

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

Scopes compose as expected, with outer scopes appearing before inner scopes in names:

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

When used as a decorator or higher-order function, scope will use the name of the input function as the prefix if no user-specified prefix is provided.

Example:

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
name_count(fn=None)

name_count is a very simple autonaming scheme that appends the suffix “__” plus a counter to any name that appears multiple times in an execution. Only duplicate instances of a name get a suffix; the first instance is not modified.
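
The suffix rule can be sketched over a flat sequence of names with a hypothetical dedupe_names helper (not part of Pyro). This toy only models the counting rule and ignores possible collisions with user-chosen names such as "x__1".

```python
from collections import Counter


def dedupe_names(names):
    """Hypothetical toy: append "__<count>" to repeats; the first occurrence is unchanged."""
    seen = Counter()
    result = []
    for name in names:
        count = seen[name]  # how many times this name has already appeared
        seen[name] += 1
        result.append(name if count == 0 else f"{name}__{count}")
    return result


print(dedupe_names(["x", "x", "y", "x"]))  # ['x', 'x__1', 'y', 'x__2']
```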

Example:

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()

name_count also composes with scope() by adding a suffix to duplicate scope entrances:

Example:

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

Example:

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()
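
Both composition examples above can be reproduced by a toy in which scope entrances and sample sites share a single counter keyed by the full, prefixed name. ToyNameCounter is a hypothetical illustration under that assumption, not Pyro's NameCountMessenger or ScopeMessenger.

```python
from collections import Counter
from contextlib import contextmanager


class ToyNameCounter:
    """Hypothetical toy: a scope prefix stack plus name_count-style counters."""

    def __init__(self):
        self.counts = Counter()  # keyed by the full "/"-joined name
        self.prefixes = []

    def _unique(self, name):
        full = "/".join(self.prefixes + [name])
        count = self.counts[full]
        self.counts[full] += 1
        return name if count == 0 else f"{name}__{count}"

    @contextmanager
    def scope(self, prefix):
        # The scope entrance itself is counted, so re-entering "a" gives a__1, a__2, ...
        self.prefixes.append(self._unique(prefix))
        try:
            yield
        finally:
            self.prefixes.pop()

    def sample(self, name):
        return "/".join(self.prefixes + [self._unique(name)])


outer = ToyNameCounter()
sites = []
for _ in range(3):
    with outer.scope("a"):
        sites.append(outer.sample("x"))
print(sites)  # ['a/x', 'a__1/x', 'a__2/x']

inner = ToyNameCounter()
with inner.scope("a"):
    print([inner.sample("x") for _ in range(3)])  # ['a/x', 'a/x__1', 'a/x__2']
```
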
autoname(fn=None, name=None)

Convenience wrapper around AutonameMessenger.

Assign unique names to random variables.

  1. For a new variable, use its declared name if given; otherwise, use the distribution name:

    sample("x", dist.Bernoulli ... )  # -> x
    sample(dist.Bernoulli ... )  # -> Bernoulli
    
  2. For repeated variable names, append the counter as a suffix:

    sample(dist.Bernoulli ... )  # -> Bernoulli
    sample(dist.Bernoulli ... )  # -> Bernoulli1
    sample(dist.Bernoulli ... )  # -> Bernoulli2
    
  3. Functions and iterators can be used as a name scope:

    @autoname
    def f1():
        sample(dist.Bernoulli ... )
    
    @autoname
    def f2():
        f1()  # -> f2/f1/Bernoulli
        f1()  # -> f2/f1__1/Bernoulli
        sample(dist.Bernoulli ... )  # -> f2/Bernoulli
    
    @autoname(name="model")
    def f3():
        for i in autoname(range(3), name="time"):
            # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli
            sample(dist.Bernoulli ... )
            # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli
            f1()
    
  4. Or scopes can be added using the with statement:

    def f4():
        with autoname(name="prefix"):
            f1()  # -> prefix/f1/Bernoulli
            f1()  # -> prefix/f1__1/Bernoulli
            sample(dist.Bernoulli ... )  # -> prefix/Bernoulli
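
The rules above can be approximated by a toy in which every name scope is a frame with its own counters. ToyAutonamer, _Frame and the stand-in Bernoulli class are hypothetical and are not Pyro's AutonameMessenger. The two suffix styles (a bare counter for repeated variables, "__" for repeated scopes) simply follow the examples above.

```python
from collections import Counter
from contextlib import contextmanager


class _Frame:
    """One name scope: its label plus separate counters for variables and sub-scopes."""

    def __init__(self, label):
        self.label = label
        self.var_counts = Counter()
        self.scope_counts = Counter()


class ToyAutonamer:
    """Hypothetical toy following rules 1-4 above; not Pyro's implementation."""

    def __init__(self):
        self.stack = [_Frame(None)]  # the root frame contributes no prefix

    def _prefix(self):
        return [frame.label for frame in self.stack[1:]]

    @contextmanager
    def scope(self, name):
        parent = self.stack[-1]
        count = parent.scope_counts[name]
        parent.scope_counts[name] += 1
        # Repeated scopes get a "__" suffix, as in f2/f1__1/Bernoulli.
        self.stack.append(_Frame(name if count == 0 else f"{name}__{count}"))
        try:
            yield
        finally:
            self.stack.pop()

    def sample(self, dist, name=None):
        # Rule 1: default to the distribution's class name.
        name = name if name is not None else type(dist).__name__
        frame = self.stack[-1]
        count = frame.var_counts[name]
        frame.var_counts[name] += 1
        # Rule 2: repeated variables get a bare counter, as in Bernoulli1.
        label = name if count == 0 else f"{name}{count}"
        return "/".join(self._prefix() + [label])


class Bernoulli:
    """Stand-in for a distribution class; only its class name is used."""


def f1(namer):
    with namer.scope("f1"):
        return namer.sample(Bernoulli())


def f2(namer):
    with namer.scope("f2"):
        return [f1(namer), f1(namer), namer.sample(Bernoulli())]


print(f2(ToyAutonamer()))  # ['f2/f1/Bernoulli', 'f2/f1__1/Bernoulli', 'f2/Bernoulli']
```

Each new frame starts with fresh counters, which is why the sample in f2 is named f2/Bernoulli rather than receiving a suffix.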
    
sample(*args)
sample(name: str, fn, *args, **kwargs)
sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)

Named Data Structures

The pyro.contrib.named module is a thin syntactic layer on top of Pyro. It allows Pyro models to be written as programs that operate on Python data structures, like latent.x.sample_(...), rather than as programs with string-labeled statements, like x = pyro.sample("x", ...).

This module provides three container data structures: named.Object, named.List, and named.Dict. These data structures are intended to be nested inside one another. Together they track the address of each piece of data in each data structure, so that this address can be used as a Pyro site name. For example:

>>> from pyro.contrib.autoname import named
>>> state = named.Object("state")
>>> print(str(state))
state

>>> z = state.x.y.z  # z is just a placeholder.
>>> print(str(z))
state.x.y.z

>>> state.xs = named.List()  # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]

>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']
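
The address tracking above can be imitated with three small classes. ToyObject, ToyList and ToyDict are hypothetical stand-ins, not the real named classes: attribute access creates and caches a child placeholder, add() hands out indexed children, and a missing dictionary key creates a keyed child. Assigning a ToyList or ToyDict to an attribute is what gives it its address.

```python
class ToyObject:
    """Hypothetical stand-in for named.Object that only remembers its address."""

    def __init__(self, address):
        object.__setattr__(self, "_address", address)

    def __str__(self):
        return self._address

    def __getattr__(self, key):
        # Called only for missing attributes: create, cache and return a child.
        if key.startswith("_"):
            raise AttributeError(key)  # leave private and special names alone
        child = ToyObject(f"{self._address}.{key}")
        object.__setattr__(self, key, child)
        return child

    def __setattr__(self, key, value):
        # Binding a container to an attribute gives it its address.
        if isinstance(value, (ToyList, ToyDict)):
            value._address = f"{self._address}.{key}"
        object.__setattr__(self, key, value)


class ToyList(list):
    def __init__(self, address=None):
        super().__init__()
        self._address = address

    def add(self):
        child = ToyObject(f"{self._address}[{len(self)}]")
        self.append(child)
        return child


class ToyDict(dict):
    def __init__(self, address=None):
        super().__init__()
        self._address = address

    def __missing__(self, key):
        child = ToyObject(f"{self._address}[{key!r}]")
        self[key] = child
        return child


state = ToyObject("state")
print(state.x.y.z)  # state.x.y.z
state.xs = ToyList()
print(state.xs.add())  # state.xs[0]
state.ys = ToyDict()
print(state.ys["foo"])  # state.ys['foo']
```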

These addresses can now be used inside sample, observe and param statements. These named data structures even provide in-place methods that alias Pyro statements. For example:

>>> import torch
>>> import pyro.distributions as dist
>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)
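
The aliasing idea is that an in-place method passes the object's own address where pyro.sample would take a string name. The toy below illustrates only that idea; record_sample, TRACE and Placeholder are hypothetical stand-ins (a dict instead of a Pyro trace, a callable instead of a distribution).

```python
import random

# Hypothetical stand-in for a Pyro trace: site name -> value.
TRACE = {}


def record_sample(name, draw, obs=None):
    """Toy replacement for pyro.sample: store one value per unique site name."""
    if name in TRACE:
        raise ValueError(f"duplicate site name {name!r}")
    value = obs if obs is not None else draw()
    TRACE[name] = value
    return value


class Placeholder:
    """Toy placeholder whose sample_ uses its address as the site name."""

    def __init__(self, address):
        self.address = address

    def __str__(self):
        return self.address

    def sample_(self, draw, obs=None):
        return record_sample(str(self), draw, obs=obs)


z = Placeholder("state.z").sample_(lambda: random.gauss(0.0, 1.0))
x = Placeholder("state.x").sample_(lambda: random.gauss(0.0, 1.0), obs=z)
print(sorted(TRACE))  # ['state.x', 'state.z']
```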

For deeper examples of how these can be used in model code, see the Tree Data and Mixture examples.

Authors: Fritz Obermeyer, Alexander Rush

class Object(name)

Bases: object

Object to hold immutable latent state.

This object can serve either as a container for nested latent state or as a placeholder to be replaced by a tensor via a named.sample, named.observe, or named.param statement. When used as a placeholder, Object objects take the place of strings in normal pyro.sample statements.

Parameters

name (str) – The name of the object.

Example:

from pyro.contrib.autoname import named

state = named.Object("state")
state.x = 0
state.ys = named.List()
state.zs = named.Dict()
state.a.b.c.d.e.f.g = 0  # Creates a chain of named.Objects.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) → torch.Tensor

Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Introduction to Pyro for a discussion.

Parameters
  • name – name of sample (for sample_, the address of this Object is used as the name)

  • fn – distribution class or function

  • obs – observed datum (optional; should only be used in context of inference) optionally specified in kwargs

  • obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

  • infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.
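
The obs_mask rule can be illustrated with plain lists. toy_masked_sample is a hypothetical helper, not Pyro code: it keeps observed entries where the mask is True, draws fresh values elsewhere (in the spirit of an element-wise where(mask, obs, imputed)), and reports the name of the extra latent site. Real masks are tensors broadcast against fn.batch_shape; the toy assumes equal-length lists.

```python
import random


def toy_masked_sample(name, obs, mask, draw):
    """Hypothetical toy of the obs_mask rule using lists instead of tensors.

    Where mask is True the observed value is kept; elsewhere a fresh value is
    drawn. Also returns the name of the extra latent site a guide must cover.
    """
    if len(obs) != len(mask):
        raise ValueError("obs and mask must have the same length")
    value = [o if m else draw() for o, m in zip(obs, mask)]
    return value, name + "_unobserved"


random.seed(0)
# The masked-out entry of obs is never read, so a NaN placeholder is fine here.
value, latent_site = toy_masked_sample(
    "y",
    obs=[1.0, float("nan"), 3.0],
    mask=[True, False, True],
    draw=lambda: random.gauss(0.0, 1.0),
)
print(value[0], value[2], latent_site)  # 1.0 3.0 y_unobserved
```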

Returns

sample

param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) → torch.Tensor

Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.
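
A parameter is initialized the first time its name is seen, and later statements reuse the stored value; this is also why a lazy init_tensor callable is evaluated only once. ToyParamStore is a hypothetical, dictionary-based sketch of that first-write-wins behaviour, not Pyro's param store, and it omits constraints entirely.

```python
class ToyParamStore:
    """Hypothetical toy of first-write-wins parameter storage."""

    def __init__(self):
        self._params = {}

    def param(self, name, init=None):
        if name not in self._params:
            if init is None:
                raise KeyError(f"parameter {name!r} has not been initialized")
            # A callable initializer is evaluated only when the name is first seen.
            self._params[name] = init() if callable(init) else init
        # Later calls return the stored value and ignore `init`.
        return self._params[name]


calls = []


def lazy_init():
    calls.append("called")
    return [0.0] * 100000


store = ToyParamStore()
first = store.param("state.loc", lazy_init)
second = store.param("state.loc", lazy_init)
print(len(calls), first is second)  # 1 True
```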

Parameters
  • name (str) – name of parameter (for param_, the address of this Object is used as the name)

  • init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which will only be evaluated on the initial statement.

  • constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.

  • event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of these are considered batch dimensions; if the param statement is inside a subsampled plate, the corresponding batch dimensions of the parameter are subsampled accordingly. If unspecified, all dimensions are considered event dimensions and no subsampling is performed.
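
The event_dim split is simple shape arithmetic. The helpers below are hypothetical and operate on plain tuples rather than tensors. subsampled_shape additionally assumes that a plate's dim is a negative index counted from the right of the batch shape, which is how Pyro plates are usually described, and shows which dimension would shrink.

```python
def split_param_shape(shape, event_dim=None):
    """Hypothetical helper: split a shape into (batch_shape, event_shape).

    event_dim=None treats every dimension as an event dimension.
    """
    shape = tuple(shape)
    if event_dim is None:
        event_dim = len(shape)
    if not 0 <= event_dim <= len(shape):
        raise ValueError("event_dim must be between 0 and len(shape)")
    cut = len(shape) - event_dim
    return shape[:cut], shape[cut:]


def subsampled_shape(shape, event_dim, plate_dim, subsample_size):
    """Shape after subsampling the batch dimension at plate_dim (negative, from the right of the batch shape)."""
    batch, event = split_param_shape(shape, event_dim)
    if not batch:
        return tuple(shape)  # only event dims: nothing is subsampled
    batch = list(batch)
    batch[plate_dim] = subsample_size  # raises IndexError if there is no such batch dim
    return tuple(batch) + event


print(split_param_shape((100, 3), event_dim=1))  # ((100,), (3,))
print(subsampled_shape((100, 3), event_dim=1, plate_dim=-1, subsample_size=10))  # (10, 3)
print(subsampled_shape((100, 3), event_dim=None, plate_dim=-1, subsample_size=10))  # (100, 3)
```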

Returns

A constrained parameter. The underlying unconstrained parameter is accessible via pyro.param(...).unconstrained(), where .unconstrained is a weakref attribute.

Return type

torch.Tensor

class List(name=None)

Bases: list

List-like object to hold immutable latent state.

This must either be given a name when constructed:

latent = named.List("root")

or must be immediately stored in a named.Object:

latent = named.Object("root")
latent.xs = named.List()  # Must be bound to an Object before use.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

add()

Append one new named.Object.

Returns

a new latent object at the end

Return type

named.Object

class Dict(name=None)

Bases: dict

Dict-like object to hold immutable latent state.

This must either be given a name when constructed:

latent = named.Dict("root")

or must be immediately stored in a named.Object:

latent = named.Object("root")
latent.xs = named.Dict()  # Must be bound to an Object before use.

Warning

This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.

Scoping

pyro.contrib.autoname.scoping contains the implementation of pyro.contrib.autoname.scope(), a tool for automatically appending a semantically meaningful prefix to names of sample sites.

class NameCountMessenger

Bases: pyro.poutine.messenger.Messenger

NameCountMessenger is the implementation of pyro.contrib.autoname.name_count()

class ScopeMessenger(prefix=None, inner=None)

Bases: pyro.poutine.messenger.Messenger

ScopeMessenger is the implementation of pyro.contrib.autoname.scope()
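
Both classes are Messenger subclasses, that is, effect handlers that intercept the messages emitted by Pyro primitives. The toy below is a hypothetical, heavily simplified model of that pattern and is not the pyro.poutine API: ToyHandler, ToyScope, ToyNameCount, HANDLERS and emit_sample are illustrative names. Each handler on a stack may rewrite a message dict, and the message visits the innermost handler first, so inner prefixes are applied before outer ones and an outermost counter sees the fully prefixed name.

```python
from collections import Counter

HANDLERS = []  # active toy handlers, outermost first


class ToyHandler:
    """Hypothetical base class: a context manager that pushes itself on HANDLERS."""

    def __enter__(self):
        HANDLERS.append(self)
        return self

    def __exit__(self, *exc_info):
        HANDLERS.pop()
        return False

    def process(self, msg):
        pass


class ToyScope(ToyHandler):
    def __init__(self, prefix):
        self.prefix = prefix

    def process(self, msg):
        msg["name"] = f"{self.prefix}/{msg['name']}"


class ToyNameCount(ToyHandler):
    def __init__(self):
        self.counts = Counter()

    def process(self, msg):
        count = self.counts[msg["name"]]
        self.counts[msg["name"]] += 1
        if count:
            msg["name"] = f"{msg['name']}__{count}"


def emit_sample(name):
    msg = {"type": "sample", "name": name}
    for handler in reversed(HANDLERS):  # innermost handler sees the message first
        handler.process(msg)
    return msg["name"]


with ToyNameCount():
    with ToyScope("b"), ToyScope("a"):
        names = [emit_sample("x") for _ in range(2)]
print(names)  # ['b/a/x', 'b/a/x__1']
```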
