Automatic Name Generation
The pyro.contrib.autoname module provides tools for automatically generating unique, semantically meaningful names for sample sites.
- scope(fn=None, prefix=None, inner=None)
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
prefix – a string to prepend to sample names (optional if fn is provided)
inner – switch to determine where duplicate name counters appear
- Returns
fn decorated with a ScopeMessenger
scope prepends a prefix followed by a / to the name at a Pyro sample site. It works much like TensorFlow’s name_scope and variable_scope, and can be used as a context manager, a decorator, or a higher-order function. scope is useful for aligning compositional models with guides or data.
Example:
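>>> import pyro
>>> import pyro.distributions as dist
>>> from pyro import poutine
>>> from pyro.contrib.autoname import scope, name_count
>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()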
Example:
>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
Scopes compose as expected, with outer scopes appearing before inner scopes in names:
>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()
When used as a decorator or higher-order function, scope will use the name of the input function as the prefix if no user-specified prefix is provided.
Example:
>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
- name_count(fn=None)
name_count is a simple autonaming scheme that appends a suffix “__” plus a counter to any name that appears multiple times in an execution. Only duplicate instances of a name get a suffix; the first instance is not modified.
Example:
>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()
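On a flat sequence of site names, the counting rule amounts to the following sketch. toy_name_count is a hypothetical helper, not part of Pyro, and it assumes names arrive in execution order.

```python
from collections import Counter
from typing import List


def toy_name_count(names: List[str]) -> List[str]:
    """Toy sketch: keep the first occurrence, suffix later ones with '__k'."""
    seen: Counter = Counter()
    renamed = []
    for name in names:
        k = seen[name]  # how many times this name has already appeared
        renamed.append(name if k == 0 else f"{name}__{k}")
        seen[name] += 1
    return renamed


print(toy_name_count(["x", "x", "x", "y"]))  # ['x', 'x__1', 'x__2', 'y']
```

The toy ignores the corner case where a user-chosen name already looks like x__1; a real implementation would need to keep incrementing until it finds an unused name.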
name_count also composes with scope() by adding a suffix to duplicate scope entrances:
Example:
>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()
Example:
>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()
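The two examples above differ only in whether the loop sits outside or inside the scope. The hypothetical ToyNamer below reproduces both documented outcomes under one assumption: every full path handed out is remembered, and entering a scope counts as using the scope's own path. Pyro's NameCountMessenger may track names differently; this is only a model of the results.

```python
from contextlib import contextmanager
from typing import Iterator, List, Set


class ToyNamer:
    """Toy sketch of name_count combined with scope (not Pyro's code)."""

    def __init__(self) -> None:
        self.used: Set[str] = set()     # every full path handed out so far
        self.prefixes: List[str] = []   # resolved scope prefixes, outermost first

    def _unique(self, path: str) -> str:
        candidate, k = path, 0
        while candidate in self.used:   # find the next free '__k' suffix
            k += 1
            candidate = f"{path}__{k}"
        self.used.add(candidate)
        return candidate

    @contextmanager
    def scope(self, prefix: str) -> Iterator[None]:
        # Entering a scope "uses" its full path, so re-entry renames the prefix.
        full = self._unique("/".join(self.prefixes + [prefix]))
        self.prefixes.append(full.rsplit("/", 1)[-1])  # keep the last segment
        try:
            yield
        finally:
            self.prefixes.pop()

    def sample(self, name: str) -> str:
        return self._unique("/".join(self.prefixes + [name]))


# Loop outside the scope: the scope prefix gets the counter.
outer = ToyNamer()
names_outer = []
for _ in range(3):
    with outer.scope("a"):
        names_outer.append(outer.sample("x"))
print(names_outer)  # ['a/x', 'a__1/x', 'a__2/x']

# Loop inside the scope: the site name gets the counter.
inner = ToyNamer()
with inner.scope("a"):
    names_inner = [inner.sample("x") for _ in range(3)]
print(names_inner)  # ['a/x', 'a/x__1', 'a/x__2']
```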
- autoname(fn=None, name=None)
Convenient wrapper of AutonameMessenger.
Assign unique names to random variables.
For a new variable, use its declared name if given; otherwise, use the distribution name:
sample("x", dist.Bernoulli(0.5))  # -> x
sample(dist.Bernoulli(0.5))  # -> Bernoulli
For repeated variable names, append the counter as a suffix:
sample(dist.Bernoulli(0.5))  # -> Bernoulli
sample(dist.Bernoulli(0.5))  # -> Bernoulli1
sample(dist.Bernoulli(0.5))  # -> Bernoulli2
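The leaf-naming rule above can be sketched as follows. Both Bernoulli (a bare stand-in class) and toy_autoname are hypothetical; the sketch only assumes what the examples show: a declared name wins, otherwise the distribution's class name is used, and repeats get a bare counter suffix (unlike the “__” counters used for scopes).

```python
from typing import Dict, List, Optional, Tuple


class Bernoulli:
    """Hypothetical stand-in for a distribution class; only its name matters."""


def toy_autoname(sites: List[Tuple[Optional[str], object]]) -> List[str]:
    """Toy sketch: declared name, else class name; repeats get 1, 2, ..."""
    counts: Dict[str, int] = {}
    names = []
    for declared, dist_obj in sites:
        base = declared if declared is not None else type(dist_obj).__name__
        k = counts.get(base, 0)
        names.append(base if k == 0 else f"{base}{k}")
        counts[base] = k + 1
    return names


sites = [("x", Bernoulli()), (None, Bernoulli()), (None, Bernoulli()), (None, Bernoulli())]
print(toy_autoname(sites))  # ['x', 'Bernoulli', 'Bernoulli1', 'Bernoulli2']
```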
Functions and iterators can be used as a name scope:
@autoname
def f1():
    sample(dist.Bernoulli(0.5))
@autoname
def f2():
    f1()  # -> f2/f1/Bernoulli
    f1()  # -> f2/f1__1/Bernoulli
    sample(dist.Bernoulli(0.5))  # -> f2/Bernoulli
@autoname(name="model")
def f3():
    for i in autoname(range(3), name="time"):
        # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli
        sample(dist.Bernoulli(0.5))
        # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli
        f1()
Or scopes can be added using the with statement:
def f4():
    with autoname(name="prefix"):
        f1()  # -> prefix/f1/Bernoulli
        f1()  # -> prefix/f1__1/Bernoulli
        sample(dist.Bernoulli(0.5))  # -> prefix/Bernoulli
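The scope behavior in the f2 and f4 examples can be modeled with a small stack. ToyAutonameScopes is a hypothetical class, not Pyro's AutonameMessenger; it assumes only what the examples show: a scope re-entered under the same parent gets a “__k” suffix, and a site is named by joining the active scopes with its leaf name.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Tuple


class ToyAutonameScopes:
    """Toy sketch of autoname-style scopes (not Pyro's implementation)."""

    def __init__(self) -> None:
        self.stack: List[str] = []                  # resolved scope names
        self.entries: Dict[Tuple[str, str], int] = {}  # (parent path, name) -> count
        self.leaves: Dict[str, int] = {}            # full leaf path -> count

    @contextmanager
    def scope(self, name: str) -> Iterator[None]:
        key = ("/".join(self.stack), name)
        k = self.entries.get(key, 0)
        self.entries[key] = k + 1
        self.stack.append(name if k == 0 else f"{name}__{k}")
        try:
            yield
        finally:
            self.stack.pop()

    def sample(self, dist_name: str) -> str:
        path = "/".join(self.stack + [dist_name])
        k = self.leaves.get(path, 0)
        self.leaves[path] = k + 1
        return path if k == 0 else f"{path}{k}"  # bare counter for leaves


namer = ToyAutonameScopes()
recorded = []


def f1():
    with namer.scope("f1"):
        recorded.append(namer.sample("Bernoulli"))


def f2():
    with namer.scope("f2"):
        f1()
        f1()
        recorded.append(namer.sample("Bernoulli"))


f2()
print(recorded)  # ['f2/f1/Bernoulli', 'f2/f1__1/Bernoulli', 'f2/Bernoulli']
```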
- sample(*args)
- sample(name: str, fn, *args, **kwargs)
- sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)
Named Data Structures
The pyro.contrib.named module is a thin syntactic layer on top of Pyro. It allows Pyro models to be written to look like programs operating on Python data structures like latent.x.sample_(...), rather than programs with string-labeled statements like x = pyro.sample("x", ...).
This module provides three container data structures: named.Object, named.List, and named.Dict. These data structures are intended to be nested in each other. Together they track the address of each piece of data in each data structure, so that this address can be used as a Pyro site. For example:
>>> state = named.Object("state")
>>> print(str(state))
state
>>> z = state.x.y.z # z is just a placeholder.
>>> print(str(z))
state.x.y.z
>>> state.xs = named.List() # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]
>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']
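The way these address strings are formed can be sketched with a tiny placeholder class. ToyAddress is hypothetical and is not pyro.contrib.named: it builds a fresh placeholder on every attribute or item access and stores nothing, so it only illustrates how the printed addresses above arise.

```python
class ToyAddress:
    """Toy sketch: attribute/item access returns a placeholder whose string
    form is the access path (no data is stored)."""

    def __init__(self, path: str) -> None:
        self._path = path

    def __getattr__(self, attr: str) -> "ToyAddress":
        # Only called for attributes that do not already exist on the instance.
        if attr.startswith("_"):
            raise AttributeError(attr)
        return ToyAddress(f"{self._path}.{attr}")

    def __getitem__(self, key) -> "ToyAddress":
        return ToyAddress(f"{self._path}[{key!r}]")

    def __str__(self) -> str:
        return self._path


state = ToyAddress("state")
print(state.x.y.z)      # state.x.y.z
print(state.xs[0])      # state.xs[0]
print(state.ys["foo"])  # state.ys['foo']
```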
These addresses can now be used inside sample, observe and param statements. These named data structures even provide in-place methods that alias Pyro statements. For example:
>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)
For deeper examples of how these can be used in model code, see the Tree Data and Mixture examples.
Authors: Fritz Obermeyer, Alexander Rush
- class Object(name)
Bases: object
Object to hold immutable latent state.
This object can serve either as a container for nested latent state or as a placeholder to be replaced by a tensor via a named.sample, named.observe, or named.param statement. When used as a placeholder, Object objects take the place of strings in normal pyro.sample statements.
- Parameters
name (str) – The name of the object.
Example:
state = named.Object("state")
state.x = 0
state.ys = named.List()
state.zs = named.Dict()
state.a.b.c.d.e.f.g = 0  # Creates a chain of named.Objects.
Warning
This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.
- sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.Tensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) → torch.Tensor
Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Introduction to Pyro for a discussion.
- Parameters
name – name of sample
fn – distribution class or function
obs – observed datum (optional; should only be used in context of inference) optionally specified in kwargs
obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.
infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.
- Returns
sample
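The obs_mask semantics described above can be illustrated element-wise in plain Python. toy_masked_value is a hypothetical helper, and draw() is an assumed stand-in for sampling from fn; real Pyro works on broadcast tensors and records the imputed values at the extra name + "_unobserved" site, which this sketch does not model.

```python
import random
from typing import Callable, List, Sequence


def toy_masked_value(
    obs: Sequence[float], mask: Sequence[bool], draw: Callable[[], float]
) -> List[float]:
    """Toy sketch: keep obs where mask is True, otherwise impute via draw()."""
    return [o if m else draw() for o, m in zip(obs, mask)]


rng = random.Random(0)
value = toy_masked_value([1.0, 2.0, 3.0], [True, False, True], lambda: rng.gauss(0.0, 1.0))
print(value[0], value[2])  # 1.0 3.0  (the middle entry was imputed)
```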
- param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) → torch.Tensor
Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.
- Parameters
name (str) – name of parameter
init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which will only be evaluated on the initial statement.
constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.
event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of this will be considered batch dimensions; if the param statement is inside a subsampled plate, then corresponding batch dimensions of the parameter will be correspondingly subsampled. If unspecified, all dimensions will be considered event dims and no subsampling will be performed.
- Returns
A constrained parameter. The underlying unconstrained parameter is accessible via pyro.param(...).unconstrained(), where .unconstrained is a weakref attribute.
- Return type
torch.Tensor
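Two details of param_ can be illustrated without Pyro: a lazy init_tensor is evaluated only on the first statement for a given name, and event_dim splits a shape into batch and event dimensions. _toy_store, toy_param and split_batch_event are hypothetical names; the dict below is not Pyro's ParamStore, and plain lists stand in for tensors.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical toy parameter store keyed by name; NOT Pyro's ParamStore.
_toy_store: Dict[str, Any] = {}


def toy_param(name: str, init: Any = None) -> Any:
    """Return the stored value, evaluating a callable `init` only on first use."""
    if name not in _toy_store:
        _toy_store[name] = init() if callable(init) else init
    return _toy_store[name]


def split_batch_event(
    shape: Tuple[int, ...], event_dim: Optional[int] = None
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Split a shape into (batch dims, event dims): the rightmost `event_dim`
    dims are event dims, and None means every dim is an event dim."""
    if event_dim is None:
        event_dim = len(shape)
    cut = len(shape) - event_dim
    return shape[:cut], shape[cut:]


init_calls = []


def expensive_init():
    init_calls.append(1)  # record each evaluation
    return [0.0] * 5


toy_param("state.loc", expensive_init)
toy_param("state.loc", expensive_init)  # already stored: init is not called again
print(len(init_calls))                   # 1
print(split_batch_event((10, 3), 1))     # ((10,), (3,))
print(split_batch_event((10, 3)))        # ((), (10, 3))
```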
- class List(name=None)
Bases: list
List-like object to hold immutable latent state.
This must either be given a name when constructed:
latent = named.List("root")
or must be immediately stored in a named.Object:
latent = named.Object("root")
latent.xs = named.List()  # Must be bound to an Object before use.
Warning
This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.
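The naming requirement above (a name at construction, or binding to an Object) can be modeled as follows. ToyList and ToyObject are hypothetical classes, not named.List and named.Object: the unnamed list receives its address when assigned as an attribute, and add() refuses to run before that. The toy does not enforce the write-once rule.

```python
from typing import Optional


class ToyList(list):
    """Toy sketch of a named list: it needs a name before items can be added."""

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__()
        self.name = name

    def add(self) -> str:
        if self.name is None:
            raise RuntimeError("ToyList must be named or bound to a ToyObject first")
        address = f"{self.name}[{len(self)}]"
        self.append(address)
        return address


class ToyObject:
    """Toy container that names an unnamed ToyList when it is assigned to it."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __setattr__(self, attr: str, value: object) -> None:
        if isinstance(value, ToyList) and value.name is None:
            value.name = f"{self.name}.{attr}"  # bind the list to this address
        super().__setattr__(attr, value)


latent = ToyObject("root")
latent.xs = ToyList()        # unnamed list is bound to "root.xs"
print(latent.xs.add())       # root.xs[0]
print(latent.xs.add())       # root.xs[1]
print(ToyList("top").add())  # top[0]
```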
- class Dict(name=None)
Bases: dict
Dict-like object to hold immutable latent state.
This must either be given a name when constructed:
latent = named.Dict("root")
or must be immediately stored in a named.Object:
latent = named.Object("root")
latent.xs = named.Dict()  # Must be bound to an Object before use.
Warning
This data structure is write-once: data may be added but may not be mutated or removed. Trying to mutate this data structure may result in silent errors.
Scoping
pyro.contrib.autoname.scoping contains the implementation of pyro.contrib.autoname.scope(), a tool for automatically appending a semantically meaningful prefix to names of sample sites.
- class NameCountMessenger
Bases: pyro.poutine.messenger.Messenger
NameCountMessenger is the implementation of pyro.contrib.autoname.name_count()
- class ScopeMessenger(prefix=None, inner=None)
Bases: pyro.poutine.messenger.Messenger
ScopeMessenger is the implementation of pyro.contrib.autoname.scope()