get_param_store() → pyro.params.param_store.ParamStoreDict

Returns the global ParamStoreDict.

clear_param_store() → None

Clears the global ParamStoreDict.

This is especially useful if you’re working in a REPL. We recommend calling this before each training loop (to avoid leaking parameters from past models), and before each unit test (to avoid leaking parameters across tests).

param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) → torch.Tensor

Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.

  • name (str) – name of parameter

  • init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which will only be evaluated on the initial statement.

  • constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.

  • event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of these are considered batch dimensions. If the param statement is inside a subsampled plate, the corresponding batch dimensions of the parameter are subsampled accordingly. If unspecified, all dimensions are considered event dims and no subsampling is performed.


Returns

A constrained parameter. The underlying unconstrained parameter is accessible via pyro.param(...).unconstrained(), where .unconstrained is a weakref attribute.

Return type

torch.Tensor


sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) → torch.Tensor

Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Introduction to Pyro for a discussion.

  • name – name of sample

  • fn – distribution class or function

  • obs – observed datum (optional; should only be used in the context of inference), optionally specified in kwargs

  • obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

  • infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.



factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) → None

Factor statement to add an arbitrary log probability factor to a probabilistic model.


When using factor statements in guides, you’ll need to specify whether the factor statement originated from fully reparametrized sampling (e.g. the Jacobian determinant of a transformation of a reparametrized variable) or from nonreparametrized sampling (e.g. discrete samples). For the fully reparametrized case, set has_rsample=True; for the nonreparametrized case, set has_rsample=False. This is needed only in guides, not in models.

  • name (str) – Name of the trivial sample

  • log_factor (torch.Tensor) – A possibly batched log probability factor.

  • has_rsample (bool) – Whether the log_factor arose from a fully reparametrized distribution. Defaults to False when used in models, but must be specified for use in guides.

deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) → torch.Tensor

Deterministic statement to add a Delta site with the given name and value to the trace. This is useful when we want to record values that are completely determined by their parents. For example:

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)


The site does not affect the model density. This currently converts to a sample() statement, but may change in the future.

  • name (str) – Name of the site.

  • value (torch.Tensor) – Value of the site.

  • event_dim (int) – Optional event dimension, defaults to value.ndim.

subsample(data: torch.Tensor, event_dim: int) → torch.Tensor

Subsampling statement to subsample data tensors based on enclosing plates.

This is typically called on arguments to model() when subsampling is performed automatically by plates by passing either the subsample or subsample_size kwarg. For example, the following are equivalent:

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...

  • data (Tensor) – A tensor of batched data.

  • event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.


Returns

A subsampled version of data.

Return type

torch.Tensor


class plate(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)

Bases: pyro.poutine.plate_messenger.PlateMessenger

Construct for conditionally independent sequences of variables.

plate can be used either sequentially as a generator or in parallel as a context manager (formerly irange and iarange, respectively).

Sequential plate is similar to range() in that it generates a sequence of values.

Vectorized plate is similar to torch.arange() in that it yields an array of indices by which other tensors can be indexed. plate differs from torch.arange() in that it also informs inference algorithms that the variables being indexed are conditionally independent. To do this, plate is provided as a context manager rather than a function, and users must guarantee that all computation within a plate context is conditionally independent:

with pyro.plate("name", size) as ind:
    # conditionally independent stuff with ind...

Additionally, plate can take advantage of the conditional independence assumptions by subsampling the indices and informing inference algorithms to scale various computed values. This is typically used to subsample minibatches of data:

with pyro.plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100

By default subsample_size is None and this simply yields torch.arange(0, size). If 0 < subsample_size <= size, this yields a single random batch of indices of size subsample_size and scales all log likelihood terms by size/subsample_size within this context.


This is only correct if all computation is conditionally independent within the context.

  • name (str) – A unique name to help inference algorithms match plate sites between models and guides.

  • size (int) – Optional size of the collection being subsampled (like stop in builtin range).

  • subsample_size (int) – Size of minibatches used in subsampling. Defaults to size.

  • subsample (Anything supporting len().) – Optional custom subsample for user-defined subsampling schemes. If specified, then subsample_size will be set to len(subsample).

  • dim (int) – An optional dimension to use for this independence index. If specified, dim should be negative, i.e. should index from the right. If not specified, dim is set to the rightmost dim that is left of all enclosing plate contexts.

  • use_cuda (bool) – DEPRECATED, use the device arg instead. Optional bool specifying whether to use cuda tensors for subsample and log_prob. Defaults to torch.Tensor.is_cuda.

  • device (str) – Optional keyword specifying which device to place the results of subsample and log_prob on. By default, results are placed on the same device as the default tensor.


Returns

A reusable context manager yielding a single 1-dimensional torch.Tensor of indices.


>>> # This version declares sequential independence and subsamples data:
>>> for i in pyro.plate('data', 100, subsample_size=10):
...     if z[i]:  # Control flow in this example prevents vectorization.
...         obs = pyro.sample(f'obs_{i}', dist.Normal(loc, scale),
...                           obs=data[i])
>>> # This version declares vectorized independence:
>>> with pyro.plate('data'):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
>>> # This version subsamples data in vectorized way:
>>> with pyro.plate('data', 100, subsample_size=10) as ind:
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This wraps a user-defined subsampling method for use in pyro:
>>> ind = torch.randint(0, 100, (10,)).long() # custom subsample
>>> with pyro.plate('data', 100, subsample=ind):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This reuses two different independence contexts.
>>> x_axis = pyro.plate('outer', 320, dim=-1)
>>> y_axis = pyro.plate('inner', 200, dim=-2)
>>> with x_axis:
...     x_noise = pyro.sample("x_noise", dist.Normal(loc, scale))
...     assert x_noise.shape == (320,)
>>> with y_axis:
...     y_noise = pyro.sample("y_noise", dist.Normal(loc, scale))
...     assert y_noise.shape == (200, 1)
>>> with x_axis, y_axis:
...     xy_noise = pyro.sample("xy_noise", dist.Normal(loc, scale))
...     assert xy_noise.shape == (200, 320)

See SVI Part II for an extended discussion.

plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = -1) → Iterator[None]

Create a contiguous stack of plates with dimensions:

rightmost_dim - len(sizes) + 1, ..., rightmost_dim

  • prefix (str) – Name prefix for plates.

  • sizes (iterable) – An iterable of plate sizes.

  • rightmost_dim (int) – The rightmost dim, counting from the right.

module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) → torch.nn.modules.module.Module

Registers all parameters of a torch.nn.Module with Pyro’s param_store. In conjunction with the ParamStoreDict save() and load() functionality, this allows the user to save and load modules.


Consider instead using PyroModule, a newer alternative to pyro.module() that has better support for: jitting, serving in C++, and converting parameters to random variables. For details see the Modules Tutorial.

  • name (str) – name of module

  • nn_module (torch.nn.Module) – the module to be registered with Pyro

  • update_module_params (bool) – determines whether Parameters in the PyTorch module get overridden with the values found in the ParamStore (if any). Defaults to False.



random_module(name, nn_module, prior, *args, **kwargs)


The random_module primitive is deprecated and will be removed in a future release. Use PyroModule instead to create Bayesian modules from torch.nn.Module instances. See the Bayesian Regression tutorial for an example.

DEPRECATED Places a prior over the parameters of the module nn_module. Returns a distribution (callable) over nn.Modules, which upon calling returns a sampled nn.Module.

  • name (str) – name of pyro module

  • nn_module (torch.nn.Module) – the module to be registered with pyro

  • prior – pyro distribution, stochastic function, or python dict with parameter names as keys and respective distributions/stochastic functions as values.


Returns

A callable that returns a sampled module.

barrier(data: torch.Tensor) → torch.Tensor

EXPERIMENTAL Ensures all values in data are ground, rather than lazy funsor values. This is useful in combination with pyro.poutine.collapse().

enable_validation(is_validate: bool = True) → None

Enable or disable validation checks in Pyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution arguments and support values, detecting incorrect use of ELBO and MCMC. Since some of these checks may be expensive, you may want to disable validation of mature models to speed up inference.

The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Equivalently, the default behavior depends on Python’s global __debug__ value via pyro.enable_validation(__debug__).

Validation is temporarily disabled during jit compilation, for all inference algorithms that support the PyTorch jit. We recommend developing models with non-jitted inference algorithms to ease debugging, then optionally moving to jitted inference once a model is correct.


is_validate (bool) – (optional; defaults to True) whether to enable validation checks.

validation_enabled(is_validate: bool = True) → Iterator[None]

Context manager that is useful when temporarily enabling/disabling validation checks.


is_validate (bool) – (optional; defaults to True) temporary validation check override.

trace(fn=None, ignore_warnings=False, jit_options=None)

Lazy replacement for torch.jit.trace() that works with Pyro functions that call pyro.param().

The actual compilation artifact is stored in the compiled attribute of the output. Call diagnostic methods on this attribute.


import torch
import pyro
import pyro.distributions as dist
import pyro.ops.jit
from torch.distributions import constraints

def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))

@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()

  • fn (callable) – The function to be traced.

  • ignore_warnings (bool) – Whether to ignore jit warnings.

  • jit_options (dict) – Optional dict of options to pass to torch.jit.trace(), e.g. {"optimize": False}.