# Neural Networks

The module pyro.nn provides implementations of neural network modules that are useful in the context of deep probabilistic programming.

## Pyro Modules

Pyro includes a class PyroModule, a subclass of torch.nn.Module, whose attributes can be modified by Pyro effects. To create a poutine-aware attribute, use either the PyroParam struct or the PyroSample struct:

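```python
import torch
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroParam, PyroSample

my_module = PyroModule()
my_module.x = PyroParam(torch.tensor(1.), constraint=dist.constraints.positive)
my_module.y = PyroSample(dist.Normal(0, 1))
```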

class PyroParam(init_value=None, constraint=constraints.real, event_dim=None)

Declares a Pyro-managed learnable attribute of a PyroModule, similar to pyro.param.

This can be used either to set attributes of PyroModule instances:

```python
assert isinstance(my_module, PyroModule)
my_module.x = PyroParam(torch.zeros(4))                   # eager
my_module.y = PyroParam(lambda: torch.randn(4))           # lazy
my_module.z = PyroParam(torch.ones(4),                    # eager
                        constraint=constraints.positive,
                        event_dim=1)
```


or EXPERIMENTALLY as a decorator on lazy initialization properties:

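```python
import torch
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam

class MyModule(PyroModule):
    @PyroParam
    def x(self):
        return torch.zeros(4)

    @PyroParam
    def y(self):
        return torch.randn(4)

    @PyroParam(constraint=constraints.real, event_dim=1)
    def z(self):
        return torch.ones(4)

    def forward(self):
        return self.x + self.y + self.z  # accessed like a @property
```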

Parameters:
• init_value (torch.Tensor or callable returning a torch.Tensor or None) – Either a tensor for eager initialization, a callable for lazy initialization, or None for use as a decorator.
• constraint (Constraint) – torch constraint, defaults to constraints.real.
• event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of these are considered batch dimensions; if the param statement is inside a subsampled plate, the corresponding batch dimensions of the parameter are subsampled accordingly. If unspecified, all dimensions are considered event dims and no subsampling is performed.
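
The batch/event split that event_dim describes is plain shape arithmetic. The following is a minimal plain-Python sketch of it, using nested lists instead of tensors. The helpers `split_shape` and `subsample_rows` are hypothetical, and for simplicity the sketch assumes the subsampled plate indexes the leftmost batch dimension.

```python
from typing import List, Optional, Sequence, Tuple


def split_shape(shape: Sequence[int], event_dim: Optional[int]) -> Tuple[tuple, tuple]:
    """Split a parameter shape into (batch_shape, event_shape).

    event_dim=None treats every dimension as an event dimension,
    so nothing can be subsampled.
    """
    shape = tuple(shape)
    if event_dim is None:
        return (), shape
    if not 0 <= event_dim <= len(shape):
        raise ValueError("event_dim must be between 0 and len(shape)")
    split = len(shape) - event_dim
    return shape[:split], shape[split:]


def subsample_rows(value: List[list], indices: Sequence[int]) -> List[list]:
    """Keep only the entries of the leftmost (batch) dimension listed in indices."""
    return [value[i] for i in indices]


# A (4, 3) parameter: 4 entries along a plate dimension, each a length-3 event.
param = [[float(10 * i + j) for j in range(3)] for i in range(4)]

batch_shape, event_shape = split_shape((4, 3), event_dim=1)
print(batch_shape, event_shape)       # (4,) (3,)
print(split_shape((4, 3), None))      # ((), (4, 3))
print(subsample_rows(param, [0, 2]))  # [[0.0, 1.0, 2.0], [20.0, 21.0, 22.0]]
```

With event_dim=1 only the leading dimension of size 4 can be subsampled; each kept entry stays a whole length-3 event.
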

class PyroSample(prior)

Declares a Pyro-managed random attribute of a PyroModule, similar to pyro.sample.

This can be used either to set attributes of PyroModule instances:

```python
from pyro.distributions import Normal

assert isinstance(my_module, PyroModule)
my_module.x = PyroSample(Normal(0, 1))                    # independent
my_module.y = PyroSample(lambda self: Normal(self.x, 1))  # dependent
```


or EXPERIMENTALLY as a decorator on lazy initialization methods:

```python
from pyro.distributions import Normal
from pyro.nn import PyroModule, PyroSample

class MyModule(PyroModule):
    @PyroSample
    def x(self):
        return Normal(0, 1)       # independent

    @PyroSample
    def y(self):
        return Normal(self.x, 1)  # dependent

    def forward(self):
        return self.y             # accessed like a @property
```

Parameters: prior – a distribution object, or a function that takes the PyroModule instance self as input and returns a distribution object.

class PyroModule(name='')

Bases: torch.nn.modules.module.Module

Subclass of torch.nn.Module whose attributes can be modified by Pyro effects. Attributes can be set using the helpers PyroParam and PyroSample, and methods can be decorated by pyro_method().

Parameters

To create a Pyro-managed parameter attribute, set that attribute using either torch.nn.Parameter (for unconstrained parameters) or PyroParam (for constrained parameters). Reading that attribute will then trigger a pyro.param statement. For example:

```python
import torch
from torch import nn
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam

# Create Pyro-managed parameter attributes.
my_module = PyroModule()
my_module.loc = nn.Parameter(torch.tensor(0.))
my_module.scale = PyroParam(torch.tensor(1.),
                            constraint=constraints.positive)
loc = my_module.loc  # Triggers a pyro.param statement.
scale = my_module.scale  # Triggers another pyro.param statement.
```


Note that, unlike normal torch.nn.Modules, PyroModules should not be registered with pyro.module statements. PyroModules can contain other PyroModules and normal torch.nn.Modules. Accessing a normal torch.nn.Module attribute of a PyroModule triggers a pyro.module statement. If multiple PyroModules appear in a single Pyro model or guide, they should be included in a single root PyroModule for that model.

PyroModules synchronize data with the param store at each setattr, getattr, and delattr event, based on the nested name of an attribute:

• Setting mod.x = x_init tries to read x from the param store. If a value is found in the param store, that value is copied into mod and x_init is ignored; otherwise x_init is copied into both mod and the param store.
• Reading mod.x tries to read x from the param store. If a value is found in the param store, that value is copied into mod; otherwise mod’s value is copied into the param store. Finally, mod and the param store agree on a single value to return.
• Deleting del mod.x removes a value from both mod and the param store.
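
The three synchronization rules can be modeled with a toy object whose "param store" is an ordinary dict. The sketch below is hypothetical: `ToySyncedModule` is not part of Pyro, and it ignores tensors, constraints, and nested submodules. It only mirrors the setattr/getattr/delattr behavior listed above, including the effect that two modules with the same name end up sharing data.

```python
class ToySyncedModule:
    """Hypothetical object mimicking the setattr/getattr/delattr rules above.

    `store` is a plain dict standing in for the global param store, and
    `name` plays the role of the module's name prefix.
    """

    def __init__(self, store, name=""):
        # Bypass our own __setattr__ for internal bookkeeping.
        object.__setattr__(self, "_store", store)
        object.__setattr__(self, "_name", name)
        object.__setattr__(self, "_local", {})

    def _fullname(self, attr):
        return f"{self._name}.{attr}" if self._name else attr

    def __setattr__(self, attr, init_value):
        key = self._fullname(attr)
        if key in self._store:          # store wins; init_value is ignored
            self._local[attr] = self._store[key]
        else:                           # copy init_value into both places
            self._local[attr] = init_value
            self._store[key] = init_value

    def __getattr__(self, attr):        # only reached when normal lookup fails
        key = self._fullname(attr)
        if key in self._store:          # store value is copied into the module
            self._local[attr] = self._store[key]
        elif attr in self._local:       # module value is copied into the store
            self._store[key] = self._local[attr]
        else:
            raise AttributeError(attr)
        return self._local[attr]

    def __delattr__(self, attr):        # remove from both places
        self._local.pop(attr, None)
        self._store.pop(self._fullname(attr), None)


store = {}
a = ToySyncedModule(store, name="net")
a.scale = 1.0
print(store)                 # {'net.scale': 1.0}

b = ToySyncedModule(store, name="net")  # same name, so same data
b.scale = 5.0                # ignored: the store already holds net.scale
print(b.scale)               # 1.0

store["net.scale"] = 2.0     # e.g. after loading values into the store
print(a.scale)               # 2.0

del a.scale
print("net.scale" in store)  # False
```
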

Note that two PyroModules with the same name will both synchronize with the global param store and thus contain the same data. When creating a PyroModule, then deleting it, then creating another with the same name, the latter will be populated with the former’s data from the param store. To avoid this persistence, call either pyro.clear_param_store() or clear() before deleting a PyroModule.

PyroModules can be saved and loaded either directly using torch.save() / torch.load() or indirectly using the param store’s save() / load(). Note that values loaded with torch.load() will be overridden by any values in the param store, so it is safest to call pyro.clear_param_store() before loading.

Samples

To create a Pyro-managed random attribute, set that attribute using the PyroSample helper, specifying a prior distribution. Reading that attribute will then trigger a pyro.sample statement. For example:

```python
import pyro.distributions as dist
from pyro.nn import PyroSample

# Create Pyro-managed random attributes.
my_module.x = PyroSample(dist.Normal(0, 1))
my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

# Sample the attributes.
x = my_module.x  # Triggers a pyro.sample statement.
y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.
```


Sampling is cached within each invocation of .__call__() or of a method decorated by pyro_method(). Because sample statements can appear only once in a Pyro trace, you should ensure that traced access to sample attributes is wrapped in a single invocation of .__call__() or of a method decorated by pyro_method().
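
The per-invocation caching idea can be illustrated without Pyro. In this hypothetical sketch, `ToyRandomModule` draws a random number when `x` is read, and the `toy_method` decorator (a stand-in for pyro_method, not its implementation) opens a cache for the duration of one call so that every read of `x` inside that call sees the same draw. Nested decorated calls reuse the outermost cache, which is discarded when the outermost call returns.

```python
import functools
import random


class ToyRandomModule:
    """Hypothetical module whose `x` attribute is a random draw."""

    def __init__(self):
        self._cache = None  # None means no invocation is active

    @property
    def x(self):
        if self._cache is None:         # no active invocation: draw afresh
            return random.gauss(0.0, 1.0)
        if "x" not in self._cache:      # first read in this invocation
            self._cache["x"] = random.gauss(0.0, 1.0)
        return self._cache["x"]


def toy_method(fn):
    """Hypothetical analogue of pyro_method: open a cache for one call."""
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        is_outermost = self._cache is None
        if is_outermost:
            self._cache = {}
        try:
            return fn(self, *args, **kwargs)
        finally:
            if is_outermost:            # drop the cache when the call ends
                self._cache = None
    return wrapper


class Model(ToyRandomModule):
    @toy_method
    def twice(self):
        return self.x, self.x           # both reads see the same draw

    @toy_method
    def nested(self):
        first, _ = self.twice()         # inner call shares the outer cache
        return first, self.x


model = Model()
a, b = model.twice()
print(a == b)  # True
c, d = model.nested()
print(c == d)  # True
```

Wrapping all reads in one decorated call is what lets each attribute correspond to a single sample site per trace.
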

To make an existing module probabilistic, you can create a subclass and overwrite some parameters with PyroSamples:

```python
import pyro.distributions as dist
from torch import nn
from pyro.nn import PyroModule, PyroSample

class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight = PyroSample(
            lambda self: dist.Normal(0, 1)
                             .expand([self.out_features,
                                      self.in_features])
                             .to_event(2))
```


Mixin classes

PyroModule can be used as a mixin class, and supports simple syntax for dynamically creating mixins, for example the following are equivalent:

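```python
from torch import nn
from pyro.nn import PyroModule

root = PyroModule()  # a parent module to hold the layer
m, n = 5, 3          # illustrative input and output sizes

# Version 1. create a named mixin class
class PyroLinear(nn.Linear, PyroModule):
    pass

root.linear = PyroLinear(m, n)

# Version 2. create a dynamic mixin class
root.linear = PyroModule[nn.Linear](m, n)
```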


This notation can be used recursively to create Bayesian modules, e.g.:

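```python
import pyro.distributions as dist
from torch import nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample

model = PyroModule[nn.Sequential](
    PyroModule[nn.Linear](28 * 28, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 100),
    PyroModule[nn.Sigmoid](),
    PyroModule[nn.Linear](100, 10),
)
assert isinstance(model, nn.Sequential)
assert isinstance(model, PyroModule)

# Now we can be Bayesian about weights in the first layer.
# nn.Linear stores its weight as (out_features, in_features) = (100, 28 * 28).
model[0].weight = PyroSample(
    prior=dist.Normal(0, 1).expand([100, 28 * 28]).to_event(2))
guide = AutoDiagonalNormal(model)
```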


Note that PyroModule[...] does not recursively mix in PyroModule to submodules of the input Module; hence we needed to wrap each submodule of the nn.Sequential above.
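
The `Base[...]` mixin syntax can be sketched in plain Python with a metaclass that builds a subclass of (Base, Mixin) on demand and caches it, so that repeated lookups return the same class and isinstance checks work. `MixinMeta`, `Mixin`, `Linear` and `Sequential` below are hypothetical stand-ins, not Pyro or PyTorch classes, and Pyro's own implementation may differ in detail. The example also shows the non-recursive behavior: only the layer that was explicitly wrapped gets the mixin.

```python
class MixinMeta(type):
    """Metaclass giving `Mixin[Base]` a cached dynamic subclass of (Base, Mixin)."""
    _cache = {}

    def __getitem__(cls, base):
        key = (cls, base)
        if key not in MixinMeta._cache:
            name = f"{cls.__name__}{base.__name__}"
            MixinMeta._cache[key] = MixinMeta(name, (base, cls), {})
        return MixinMeta._cache[key]


class Mixin(metaclass=MixinMeta):
    """Hypothetical stand-in for the mixin behavior of PyroModule."""


class Linear:
    def __init__(self, m, n):
        self.shape = (n, m)


class Sequential:
    def __init__(self, *layers):
        self.layers = list(layers)


# Only the outer container and the first layer are wrapped explicitly.
seq = Mixin[Sequential](Mixin[Linear](4, 3), Linear(3, 2))
print(isinstance(seq, Sequential), isinstance(seq, Mixin))  # True True
print(Mixin[Linear] is Mixin[Linear])                       # True (cached)
print([isinstance(layer, Mixin) for layer in seq.layers])   # [True, False]
```
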

Parameters: name (str) – Optional name for a root PyroModule. This is ignored in sub-PyroModules of another PyroModule.

add_module(name, module)

Adds a child module to the current module.

named_pyro_params(prefix='', recurse=True)

Returns an iterator over PyroModule parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
• prefix (str) – prefix to prepend to all parameter names.
• recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Returns: a generator which yields tuples containing the name and parameter.
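
The prefix and recurse semantics amount to a recursive generator over a module tree. Here is a plain-Python sketch over a hypothetical `Node` tree; `Node` and `named_params` are illustrative names, and the real method works on PyroModule instances and tensors rather than floats.

```python
from typing import Dict, Iterator, Optional, Tuple


class Node:
    """Hypothetical module tree: own parameters plus named child nodes."""

    def __init__(self, params: Dict[str, float],
                 children: Optional[Dict[str, "Node"]] = None):
        self.params = params
        self.children = children or {}


def named_params(node: Node, prefix: str = "",
                 recurse: bool = True) -> Iterator[Tuple[str, float]]:
    """Yield (dotted_name, value) pairs for node and, optionally, its children."""
    for name, value in node.params.items():
        yield (f"{prefix}.{name}" if prefix else name), value
    if recurse:
        for child_name, child in node.children.items():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            yield from named_params(child, child_prefix, recurse=True)


tree = Node({"scale": 1.0}, {"layer": Node({"weight": 0.5, "bias": 0.0})})
print(list(named_params(tree)))
# [('scale', 1.0), ('layer.weight', 0.5), ('layer.bias', 0.0)]
print(list(named_params(tree, prefix="net", recurse=False)))
# [('net.scale', 1.0)]
```
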

pyro_method(fn)

Decorator for top-level methods of a PyroModule to enable pyro effects and cache pyro.sample statements.

This should be applied to all public methods that read Pyro-managed attributes, but is not needed for .forward().

clear(mod)

Removes data from both a PyroModule and the param store.

Parameters: mod (PyroModule) – A module to clear.

to_pyro_module_(m, recurse=True)

Converts an ordinary torch.nn.Module instance to a PyroModule in-place.

This is useful for adding Pyro effects to third-party modules: no third-party code needs to be modified. For example:

```python
import pyro.distributions as dist
from torch import nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroModule, PyroSample, to_pyro_module_

model = nn.Sequential(
    nn.Linear(28 * 28, 100),
    nn.Sigmoid(),
    nn.Linear(100, 100),
    nn.Sigmoid(),
    nn.Linear(100, 10),
)
to_pyro_module_(model)
assert isinstance(model, PyroModule[nn.Sequential])
assert isinstance(model[0], PyroModule[nn.Linear])

# Now we can attempt to be fully Bayesian:
for m in model.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        setattr(m, name, PyroSample(prior=dist.Normal(0, 1)
                                              .expand(value.shape)
                                              .to_event(value.dim())))
guide = AutoDiagonalNormal(model)
```

Parameters:
• m (torch.nn.Module) – A module instance.
• recurse (bool) – Whether to convert submodules to PyroModules.
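
One way to convert an object "in place" in Python is to reassign its `__class__` to a dynamically created subclass that adds a mixin. The sketch below shows that idea on plain classes; `Marker`, `convert_`, `Layer` and `Stack` are hypothetical, and this is not Pyro's code. The sketch makes a fresh subclass for every object, whereas the isinstance checks in the example above indicate that PyroModule[...] returns the same class for the same base.

```python
class Marker:
    """Hypothetical mixin standing in for PyroModule."""


def convert_(obj, recurse=True):
    """Give obj (and optionally its children) a Marker mixin class, in place."""
    if not isinstance(obj, Marker):
        cls = type(obj)
        # Same object, new class: a dynamic subclass of (original class, Marker).
        obj.__class__ = type(f"Marker{cls.__name__}", (cls, Marker), {})
    if recurse:
        for child in getattr(obj, "children", []):
            convert_(child, recurse=True)


class Layer:
    def __init__(self, size):
        self.size = size


class Stack:
    def __init__(self, *children):
        self.children = list(children)


model = Stack(Layer(3), Layer(2))
first = model.children[0]
convert_(model)
print(isinstance(model, Stack), isinstance(model, Marker))  # True True
print(model.children[0] is first)                          # True
print(type(model.children[1]).__name__)                    # MarkerLayer
```

Because the object identity is preserved, any existing references (for example from third-party code) keep pointing at the converted module.
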

## AutoRegressiveNN

class AutoRegressiveNN(input_dim, hidden_dims, param_dims=[1, 1], permutation=None, skip_connections=False, nonlinearity=ReLU())

An implementation of a MADE-like auto-regressive neural network.

Example usage:

```python
>>> import torch
>>> from pyro.nn import AutoRegressiveNN
>>> x = torch.randn(100, 10)
>>> arn = AutoRegressiveNN(10, [50], param_dims=[1])
>>> p = arn(x)  # 1 parameter of size (100, 10)
>>> arn = AutoRegressiveNN(10, [50], param_dims=[1, 1])
>>> m, s = arn(x)  # 2 parameters of size (100, 10)
>>> arn = AutoRegressiveNN(10, [50], param_dims=[1, 5, 3])
>>> a, b, c = arn(x)  # 3 parameters of sizes (100, 1, 10), (100, 5, 10), (100, 3, 10)
```

Parameters:
• input_dim (int) – the dimensionality of the input variable.
• hidden_dims (list[int]) – the dimensionality of the hidden units per layer.
• param_dims (list[int]) – shape the output into parameters of dimension (p_n, input_dim) for p_n in param_dims when p_n > 1 and dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension (input_dim), which is useful for inverse autoregressive flow.
• permutation (torch.LongTensor) – an optional permutation that is applied to the inputs and controls the order of the autoregressive factorization. In particular, for the identity permutation the autoregressive structure is such that the Jacobian is upper triangular. By default this is chosen at random.
• skip_connections (bool) – whether to add skip connections from the input to the output.
• nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.
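
The shape comments in the example above follow a simple rule, encoded here by a hypothetical plain-Python helper (`arn_output_shapes` is not a Pyro function). Note that, per the [1, 5, 3] example, a parameter with p_n == 1 keeps a singleton axis when other entries of param_dims are larger; it is only flattened to (batch, input_dim) when every entry is 1.

```python
from typing import List, Sequence, Tuple


def arn_output_shapes(batch: int, input_dim: int,
                      param_dims: Sequence[int]) -> List[Tuple[int, ...]]:
    """Output shapes implied by the AutoRegressiveNN example comments."""
    if all(p == 1 for p in param_dims):
        # Every parameter has p_n == 1: each one is (batch, input_dim).
        return [(batch, input_dim)] * len(param_dims)
    # Otherwise each parameter keeps its own p_n axis, even when p_n == 1.
    return [(batch, p, input_dim) for p in param_dims]


print(arn_output_shapes(100, 10, [1]))        # [(100, 10)]
print(arn_output_shapes(100, 10, [1, 1]))     # [(100, 10), (100, 10)]
print(arn_output_shapes(100, 10, [1, 5, 3]))  # [(100, 1, 10), (100, 5, 10), (100, 3, 10)]
```
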

Reference:

MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle
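
The masking idea behind MADE can be sketched in plain Python. Each input gets a degree from its position in the permutation; a hidden unit may only see inputs (or previous hidden units) whose degree is at most its own, and an output may only see hidden units of strictly smaller degree. As a result, output i depends only on inputs that come earlier in the order. The helper names are hypothetical, hidden degrees are assigned round-robin for determinism (a simplification, not how the paper or Pyro assign them), and only the masks are built, with no weights.

```python
from typing import List, Sequence

Mask = List[List[int]]  # rows = units of a layer, columns = that layer's inputs


def made_masks(input_dim: int, hidden_dims: Sequence[int],
               permutation: Sequence[int]) -> List[Mask]:
    """Binary MADE-style masks; permutation[k] is the input placed k-th."""
    degree = [0] * input_dim
    for k, i in enumerate(permutation):
        degree[i] = k + 1
    # Round-robin hidden degrees over 1..input_dim-1 (requires input_dim >= 2).
    hidden = [[1 + j % (input_dim - 1) for j in range(h)] for h in hidden_dims]
    masks = [[[int(h >= d) for d in degree] for h in hidden[0]]]
    for prev, cur in zip(hidden, hidden[1:]):
        masks.append([[int(h >= p) for p in prev] for h in cur])
    # Output units: strict inequality, so an output never sees its own input.
    masks.append([[int(d > h) for h in hidden[-1]] for d in degree])
    return masks


def connectivity(masks: List[Mask]) -> Mask:
    """conn[o][i] == 1 iff some masked path links input i to output o."""
    conn = masks[0]
    for m in masks[1:]:
        conn = [[int(any(row[k] and conn[k][i] for k in range(len(row))))
                 for i in range(len(conn[0]))] for row in m]
    return conn


perm = [2, 0, 1]  # input 2 comes first, then input 0, then input 1
for row in connectivity(made_masks(3, [4, 4], perm)):
    print(row)
# [0, 0, 1]
# [1, 0, 1]
# [0, 0, 0]
```

Here output 2 belongs to the first variable in the order and so depends on no inputs, while output 1 (last in the order) depends on inputs 0 and 2.
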

forward(x)

The forward method

## DenseNN

class DenseNN(input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=ReLU())

An implementation of a simple dense feedforward network, for use in, e.g., some conditional flows such as pyro.distributions.transforms.ConditionalPlanarFlow and other unconditional flows such as pyro.distributions.transforms.AffineCoupling that do not require an autoregressive network.

Example usage:

```python
>>> import torch
>>> from pyro.nn import DenseNN
>>> input_dim = 10
>>> context_dim = 5
>>> z = torch.rand(100, context_dim)
>>> dnn = DenseNN(context_dim, [50], param_dims=[1, input_dim, input_dim])
>>> a, b, c = dnn(z)  # parameters of size (100, 1), (100, 10), (100, 10)
```

Parameters:
• input_dim (int) – the dimensionality of the input.
• hidden_dims (list[int]) – the dimensionality of the hidden units per layer.
• param_dims (list[int]) – shape the output into parameters of dimension (p_n,) for p_n in param_dims when p_n > 1 and dimension () when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension ().
• nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.
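
Assuming, for this sketch, that the final layer emits sum(param_dims) numbers per example laid out consecutively, splitting them into one chunk per entry of param_dims reproduces the sizes in the example above: with input_dim = 10 and param_dims = [1, 10, 10], chunks of sizes 1, 10 and 10. The helper `split_params` is hypothetical plain Python and works on a single example's output rather than a batch.

```python
from typing import List, Sequence


def split_params(flat: Sequence[float], param_dims: Sequence[int]) -> List[List[float]]:
    """Cut one flat output vector into consecutive chunks of sizes param_dims."""
    if len(flat) != sum(param_dims):
        raise ValueError("output size must equal sum(param_dims)")
    chunks, start = [], 0
    for p in param_dims:
        chunks.append(list(flat[start:start + p]))
        start += p
    return chunks


input_dim = 10
a, b, c = split_params([float(k) for k in range(21)], [1, input_dim, input_dim])
print(len(a), len(b), len(c))  # 1 10 10
print(a, c[0])                 # [0.0] 11.0
```
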

forward(x)

The forward method

## ConditionalAutoRegressiveNN

class ConditionalAutoRegressiveNN(input_dim, context_dim, hidden_dims, param_dims=[1, 1], permutation=None, skip_connections=False, nonlinearity=ReLU())

Bases: torch.nn.modules.module.Module

An implementation of a MADE-like auto-regressive neural network that also takes an additional context variable as input. (See Reference [2], Section 3.3, for an explanation of how the conditional MADE architecture works.)

Example usage:

```python
>>> import torch
>>> from pyro.nn import ConditionalAutoRegressiveNN
>>> x = torch.randn(100, 10)
>>> y = torch.randn(100, 5)
>>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1])
>>> p = arn(x, context=y)  # 1 parameter of size (100, 10)
>>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 1])
>>> m, s = arn(x, context=y)  # 2 parameters of size (100, 10)
>>> arn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 5, 3])
>>> a, b, c = arn(x, context=y)  # 3 parameters of sizes (100, 1, 10), (100, 5, 10), (100, 3, 10)
```

Parameters:
• input_dim (int) – the dimensionality of the input variable.
• context_dim (int) – the dimensionality of the context variable.
• hidden_dims (list[int]) – the dimensionality of the hidden units per layer.
• param_dims (list[int]) – shape the output into parameters of dimension (p_n, input_dim) for p_n in param_dims when p_n > 1 and dimension (input_dim) when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension (input_dim), which is useful for inverse autoregressive flow.
• permutation (torch.LongTensor) – an optional permutation that is applied to the inputs and controls the order of the autoregressive factorization. In particular, for the identity permutation the autoregressive structure is such that the Jacobian is upper triangular. By default this is chosen at random.
• skip_connections (bool) – whether to add skip connections from the input to the output.
• nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.

Reference:

1. MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle

2. Inference Networks for Sequential Monte Carlo in Graphical Models [arXiv:1602.06701] Brooks Paige, Frank Wood
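
One way to realize the conditional MADE idea is to give every context input degree 0 and let hidden-unit degrees start at 0. Every output can then reach the context through degree-0 hidden units, while the autoregressive ordering of the main inputs is unchanged. The sketch below is plain Python with hypothetical helper names; it uses round-robin hidden degrees for determinism and is an illustration of the idea, not Pyro's implementation.

```python
from typing import List, Sequence

Mask = List[List[int]]  # rows = units of a layer, columns = that layer's inputs


def cond_made_masks(input_dim: int, context_dim: int, hidden_dims: Sequence[int],
                    permutation: Sequence[int]) -> List[Mask]:
    """MADE masks whose first-layer inputs are [context..., x...]."""
    degree = [0] * input_dim
    for k, i in enumerate(permutation):
        degree[i] = k + 1
    in_degree = [0] * context_dim + degree  # context units get degree 0
    # Round-robin hidden degrees over 0..input_dim-1 (a deterministic simplification).
    hidden = [[j % input_dim for j in range(h)] for h in hidden_dims]
    masks = [[[int(h >= d) for d in in_degree] for h in hidden[0]]]
    for prev, cur in zip(hidden, hidden[1:]):
        masks.append([[int(h >= p) for p in prev] for h in cur])
    masks.append([[int(d > h) for h in hidden[-1]] for d in degree])
    return masks


def connectivity(masks: List[Mask]) -> Mask:
    """conn[o][j] == 1 iff some masked path links input column j to output o."""
    conn = masks[0]
    for m in masks[1:]:
        conn = [[int(any(row[k] and conn[k][j] for k in range(len(row))))
                 for j in range(len(conn[0]))] for row in m]
    return conn


# Columns are [c0, c1, x0, x1, x2]; identity permutation over x.
for row in connectivity(cond_made_masks(3, 2, [6, 6], [0, 1, 2])):
    print(row)
# [1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
```

Even the first output in the order, which sees none of x, depends on the whole context.
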

forward(x, context=None)

The forward method

get_permutation()

Get the permutation applied to the inputs (by default this is chosen at random)

## ConditionalDenseNN

class ConditionalDenseNN(input_dim, context_dim, hidden_dims, param_dims=[1, 1], nonlinearity=ReLU())

Bases: torch.nn.modules.module.Module

An implementation of a simple dense feedforward network taking a context variable, for use in, e.g., some conditional flows such as pyro.distributions.transforms.ConditionalAffineCoupling.

Example usage:

```python
>>> import torch
>>> from pyro.nn import ConditionalDenseNN
>>> input_dim = 10
>>> context_dim = 5
>>> x = torch.rand(100, input_dim)
>>> z = torch.rand(100, context_dim)
>>> dnn = ConditionalDenseNN(input_dim, context_dim, [50], param_dims=[1, input_dim, input_dim])
>>> a, b, c = dnn(x, context=z)  # parameters of size (100, 1), (100, 10), (100, 10)
```

Parameters:
• input_dim (int) – the dimensionality of the input.
• context_dim (int) – the dimensionality of the context variable.
• hidden_dims (list[int]) – the dimensionality of the hidden units per layer.
• param_dims (list[int]) – shape the output into parameters of dimension (p_n,) for p_n in param_dims when p_n > 1 and dimension () when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension ().
• nonlinearity (torch.nn.Module) – the nonlinearity to use in the feedforward network, such as torch.nn.ReLU(). Note that no nonlinearity is applied to the final network output, so the output is an unbounded real number.

forward(x, context)

The forward method