# Source code for pyro.nn.auto_reg_nn

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F

"""
Samples the indices assigned to hidden units during the construction of MADE masks

:param input_dim: the dimensionality of the input variable
:type input_dim: int
:param hidden_dim: the dimensionality of the hidden layer
:type hidden_dim: int
:param simple: True to space fractional indices by rounding to nearest int, false round randomly
:type simple: bool
"""
indices = torch.linspace(1, input_dim, steps=hidden_dim, device='cpu').to(torch.Tensor().device)
if simple:
# Simple procedure tries to space fractional indices evenly by rounding to nearest int
else:
# "Non-simple" procedure creates fractional indices evenly then rounds at random
ints = indices.floor()
ints += torch.bernoulli(indices - ints)
return ints

def create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier):
    """
    Creates (conditional) MADE masks for the hidden layers and the skip connection.

    :param input_dim: the dimensionality of the input variable
    :type input_dim: int
    :param context_dim: the dimensionality of the variable that is conditioned on (for conditional densities)
    :type context_dim: int
    :param hidden_dims: the dimensionality of the hidden layers(s)
    :type hidden_dims: list[int]
    :param permutation: the order of the input variables
    :type permutation: torch.LongTensor
    :param output_dim_multiplier: tiles the output (e.g. for when a separate mean and scale parameter are desired)
    :type output_dim_multiplier: int
    :return: a pair (masks, mask_skip) where ``masks`` is a list of masks for the
        input-to-hidden, hidden-to-hidden, and hidden-to-output layers, and
        ``mask_skip`` is the mask for the direct input-to-output skip connection
    """
    # Create mask indices for input, hidden layers, and final layer
    # We use 0 to refer to the elements of the variable being conditioned on,
    # and range(1:(D_latent+1)) for the input variable
    var_index = torch.empty(permutation.shape, dtype=torch.get_default_dtype())
    var_index[permutation] = torch.arange(input_dim, dtype=torch.get_default_dtype())

    # Create the indices that are assigned to the neurons
    input_indices = torch.cat((torch.zeros(context_dim), 1 + var_index))

    # For conditional MADE, introduce a 0 index that all the conditioned variables are connected to
    # as per Paige and Wood (2016) (see below)
    if context_dim > 0:
        hidden_indices = [sample_mask_indices(input_dim, h) - 1 for h in hidden_dims]
    else:
        hidden_indices = [sample_mask_indices(input_dim - 1, h) for h in hidden_dims]

    output_indices = (var_index + 1).repeat(output_dim_multiplier)

    # Create mask from input to output for the skips connections.
    # Strict inequality enforces the autoregressive property: output k may only
    # see inputs with a strictly smaller index.
    mask_skip = (output_indices.unsqueeze(-1) > input_indices.unsqueeze(0)).type_as(var_index)

    # Create mask from input to first hidden layer, and between subsequent hidden layers.
    # Non-strict inequality here lets a hidden unit pass on information about its own index.
    masks = [(hidden_indices[0].unsqueeze(-1) >= input_indices.unsqueeze(0)).type_as(var_index)]
    for i in range(1, len(hidden_dims)):
        masks.append((hidden_indices[i].unsqueeze(-1) >= hidden_indices[i - 1].unsqueeze(0)).type_as(var_index))

    # Create mask from last hidden layer to output layer (strict, as for the skip mask)
    masks.append((output_indices.unsqueeze(-1) > hidden_indices[-1].unsqueeze(0)).type_as(var_index))

    return masks, mask_skip

"""
A linear mapping with a given mask on the weights (arbitrary bias)

:param in_features: the number of input features
:type in_features: int
:param out_features: the number of output features
:type out_features: int
:param mask: the mask to apply to the in_features x out_features weight matrix
:param bias: whether or not MaskedLinear should include a bias term. defaults to True
:type bias: bool
"""

def __init__(self, in_features, out_features, mask, bias=True):
super().__init__(in_features, out_features, bias)

def forward(self, _input):
"""
the forward method that does the masked linear computation and returns the result
"""