Primitives

sample(name, fn, *args, **kwargs)[source]

Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Intro I and Intro II for a discussion.

Parameters:
  • name – name of sample
  • fn – distribution class or function
  • obs – observed datum, optionally specified in kwargs; should only be used in the context of inference
  • infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.
Returns:

a sample from fn

param(name, *args, **kwargs)[source]

Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.

Parameters:
  • name – name of parameter
Returns:

parameter

module(name, nn_module, update_module_params=False)[source]

Takes a torch.nn.Module and registers its parameters with the ParamStore. In conjunction with the ParamStore save() and load() functionality, this allows the user to save and load modules.

Parameters:
  • name (str) – name of module
  • nn_module (torch.nn.Module) – the module to be registered with Pyro
  • update_module_params – determines whether Parameters in the PyTorch module get overridden with the values found in the ParamStore (if any). Defaults to False
Returns:

torch.nn.Module

random_module(name, nn_module, prior, *args, **kwargs)[source]

Places a prior over the parameters of the module nn_module. Returns a distribution (callable) over nn.Modules which, when called, returns a sampled nn.Module.

See the Bayesian Regression tutorial for an example.

Parameters:
  • name (str) – name of pyro module
  • nn_module (torch.nn.Module) – the module to be registered with pyro
  • prior – pyro distribution, stochastic function, or python dict with parameter names as keys and respective distributions/stochastic functions as values.
Returns:

a callable which returns a sampled module

class irange(name, size, subsample_size=None, subsample=None, use_cuda=None)[source]

Non-vectorized version of iarange. See iarange for details.

Parameters:
  • name (str) – A name that will be used for this site in a Trace.
  • size (int) – The size of the collection being subsampled (like stop in builtin range()).
  • subsample_size (int) – Size of minibatches used in subsampling. Defaults to size.
  • subsample (Anything supporting len().) – Optional custom subsample for user-defined subsampling schemes. If specified, then subsample_size will be set to len(subsample).
  • use_cuda (bool) – Optional bool specifying whether to use cuda tensors for internal log_prob computations. Defaults to torch.Tensor.is_cuda.
Returns:

A reusable iterator yielding a sequence of integers.

Examples:

>>> for i in irange('data', 100, subsample_size=10):
...     if z[i]:  # Prevents vectorization.
...         obs = sample('obs_{}'.format(i), dist.Normal(loc, scale), obs=data[i])

See SVI Part II for an extended discussion.

class iarange(name, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None)[source]

Context manager for conditionally independent ranges of variables.

iarange is similar to torch.arange() in that it yields an array of indices by which other tensors can be indexed. iarange differs from torch.arange() in that it also informs inference algorithms that the variables being indexed are conditionally independent. To do this, iarange is provided as a context manager rather than a function, and users must guarantee that all computation within an iarange context is conditionally independent:

with iarange("name", size) as ind:
    # ...do conditionally independent stuff with ind...

Additionally, iarange can take advantage of the conditional independence assumptions by subsampling the indices and informing inference algorithms to scale various computed values. This is typically used to subsample minibatches of data:

with iarange("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100

By default subsample_size=None and this simply yields a torch.arange(0, size). If 0 < subsample_size <= size, this yields a single random batch of indices of size subsample_size and scales all log likelihood terms by size/subsample_size within this context.

Warning

This is only correct if all computation is conditionally independent within the context.

Parameters:
  • name (str) – A unique name to help inference algorithms match iarange sites between models and guides.
  • size (int) – Optional size of the collection being subsampled (like stop in builtin range).
  • subsample_size (int) – Size of minibatches used in subsampling. Defaults to size.
  • subsample (Anything supporting len().) – Optional custom subsample for user-defined subsampling schemes. If specified, then subsample_size will be set to len(subsample).
  • dim (int) – An optional dimension to use for this independence index. If specified, dim should be negative, i.e. should index from the right. If not specified, dim is set to the rightmost dim that is left of all enclosing iarange contexts.
  • use_cuda (bool) – Optional bool specifying whether to use cuda tensors for subsample and log_prob. Defaults to torch.Tensor.is_cuda.
Returns:

A reusable context manager yielding a single 1-dimensional torch.Tensor of indices.

Examples:

>>> # This version simply declares independence:
>>> with iarange('data'):
...     obs = sample('obs', dist.Normal(loc, scale), obs=data)
>>> # This version subsamples data in vectorized way:
>>> with iarange('data', 100, subsample_size=10) as ind:
...     obs = sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This wraps a user-defined subsampling method for use in pyro:
>>> ind = torch.randint(0, 100, (10,)).long() # custom subsample
>>> with iarange('data', 100, subsample=ind):
...     obs = sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This reuses two different independence contexts.
>>> x_axis = iarange('outer', 320, dim=-1)
>>> y_axis = iarange('inner', 200, dim=-2)
>>> with x_axis:
...     x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
>>> with y_axis:
...     y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
>>> with x_axis, y_axis:
...     xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))

See SVI Part II for an extended discussion.

get_param_store()[source]

Returns the global ParamStore.

clear_param_store()[source]

Clears the ParamStore. This is especially useful if you’re working in a REPL.

validation_enabled(*args, **kwds)[source]

Context manager for temporarily enabling or disabling validation checks.

Parameters:
  • is_validate (bool) – temporary validation check override (optional; defaults to True).
enable_validation(is_validate=True)[source]

Enable or disable validation checks in Pyro. Validation checks provide useful warnings and errors (e.g. NaN checks, validation of distribution arguments and support values) that help with debugging. Since some of these checks may be expensive, we recommend disabling them for mature models.

Parameters:
  • is_validate (bool) – whether to enable validation checks (optional; defaults to True).
compile(fn=None, **jit_options)[source]

Drop-in replacement for torch.jit.compile() that works with Pyro functions that call pyro.param().

The actual compilation artifact is stored in the compiled attribute of the output. Call diagnostic methods on this attribute.

Example:

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.ops.jit

def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))

@pyro.ops.jit.compile(nderivs=1)
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()