# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Sequence, Union

import torch
class ConditionalDenseNN(torch.nn.Module):
    """
    An implementation of a simple dense feedforward network taking a context variable, for use in, e.g.,
    some conditional flows such as :class:`pyro.distributions.transforms.ConditionalAffineCoupling`.

    Example usage:

    >>> input_dim = 10
    >>> context_dim = 5
    >>> x = torch.rand(100, input_dim)
    >>> z = torch.rand(100, context_dim)
    >>> nn = ConditionalDenseNN(input_dim, context_dim, [50], param_dims=[1, input_dim, input_dim])
    >>> a, b, c = nn(x, context=z)  # parameters of size (100, 1), (100, 10), (100, 10)

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param context_dim: the dimensionality of the context variable
    :type context_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer
    :type hidden_dims: list[int]
    :param param_dims: shape the output into parameters of dimension (p_n,) for p_n in param_dims
        when p_n > 1 and dimension () when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension ().
    :type param_dims: list[int]
    :param nonlinearity: The nonlinearity to use in the feedforward network such as torch.nn.ReLU(). Note that no
        nonlinearity is applied to the final network output, so the output is an unbounded real number.
    :type nonlinearity: torch.nn.Module
    """

    def __init__(
        self,
        input_dim: int,
        context_dim: int,
        hidden_dims: List[int],
        param_dims: Optional[List[int]] = None,
        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
    ) -> None:
        super().__init__()

        # Avoid the shared-mutable-default pitfall; the effective default is [1, 1],
        # i.e. two scalar parameters, exactly as before.
        if param_dims is None:
            param_dims = [1, 1]

        self.input_dim = input_dim
        self.context_dim = context_dim
        self.hidden_dims = hidden_dims
        self.param_dims = param_dims
        self.count_params = len(param_dims)
        self.output_multiplier = sum(param_dims)

        # Calculate the indices on the output corresponding to each parameter
        ends = torch.cumsum(torch.tensor(param_dims), dim=0)
        starts = torch.cat((torch.zeros(1).type_as(ends), ends[:-1]))
        self.param_slices = [slice(s.item(), e.item()) for s, e in zip(starts, ends)]

        # Create masked layers
        layers = [torch.nn.Linear(input_dim + context_dim, hidden_dims[0])]
        for i in range(1, len(hidden_dims)):
            layers.append(torch.nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
        layers.append(torch.nn.Linear(hidden_dims[-1], self.output_multiplier))
        self.layers = torch.nn.ModuleList(layers)

        # Save the nonlinearity
        self.f = nonlinearity

    def forward(
        self, x: torch.Tensor, context: torch.Tensor
    ) -> Union[Sequence[torch.Tensor], torch.Tensor]:
        """
        Concatenate the context onto the input and run the feedforward network.

        :param x: input tensor whose final dimension is ``input_dim``
        :param context: conditioning variable whose final dimension is ``context_dim``;
            its batch shape must be broadcastable over ``x``'s batch shape
        :return: a single tensor when ``param_dims`` has one entry, otherwise a tuple
            of tensors, one per entry of ``param_dims``
        """
        # We must be able to broadcast the size of the context over the input
        context = context.expand(x.size()[:-1] + (context.size(-1),))
        x = torch.cat([context, x], dim=-1)
        return self._forward(x)

    def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
        """
        The forward method
        """
        h = x
        for layer in self.layers[:-1]:
            h = self.f(layer(h))
        h = self.layers[-1](h)

        # Shape the output, squeezing the parameter dimension if all ones
        if self.output_multiplier == 1:
            return h
        else:
            h = h.reshape(list(x.size()[:-1]) + [self.output_multiplier])

            if self.count_params == 1:
                return h
            else:
                return tuple([h[..., s] for s in self.param_slices])
class DenseNN(ConditionalDenseNN):
    """
    An implementation of a simple dense feedforward network, for use in, e.g., some conditional flows such as
    :class:`pyro.distributions.transforms.ConditionalPlanarFlow` and other unconditional flows such as
    :class:`pyro.distributions.transforms.AffineCoupling` that do not require an autoregressive network.

    Example usage:

    >>> input_dim = 10
    >>> context_dim = 5
    >>> z = torch.rand(100, context_dim)
    >>> nn = DenseNN(context_dim, [50], param_dims=[1, input_dim, input_dim])
    >>> a, b, c = nn(z)  # parameters of size (100, 1), (100, 10), (100, 10)

    :param input_dim: the dimensionality of the input
    :type input_dim: int
    :param hidden_dims: the dimensionality of the hidden units per layer
    :type hidden_dims: list[int]
    :param param_dims: shape the output into parameters of dimension (p_n,) for p_n in param_dims
        when p_n > 1 and dimension () when p_n == 1. The default is [1, 1], i.e. output two parameters of dimension ().
    :type param_dims: list[int]
    :param nonlinearity: The nonlinearity to use in the feedforward network such as torch.nn.ReLU(). Note that no
        nonlinearity is applied to the final network output, so the output is an unbounded real number.
    :type nonlinearity: torch.nn.Module
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        param_dims: Optional[List[int]] = None,
        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
    ) -> None:
        # Avoid the shared-mutable-default pitfall; the effective default is [1, 1].
        if param_dims is None:
            param_dims = [1, 1]
        # An unconditional network is just a conditional one with an empty
        # (zero-dimensional) context.
        super().__init__(
            input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity
        )

    def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:  # type: ignore[override]
        """
        Run the feedforward network on ``x`` alone (no context).

        :param x: input tensor whose final dimension is ``input_dim``
        :return: a single tensor when ``param_dims`` has one entry, otherwise a tuple
            of tensors, one per entry of ``param_dims``
        """
        return self._forward(x)