Source code for pyro.distributions.transforms.block_autoregressive

# This implementation is adapted in part from https://github.com/nicola-decao/BNAF under the MIT license.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import constraints

from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.transforms.neural_autoregressive import ELUTransform, LeakyReLUTransform, TanhTransform
from pyro.distributions.util import copy_docs_from

eps = 1e-8


def log_matrix_product(A, B):
    """
    Computes the matrix product of two matrices in log-space, returning the result in log-space.
    This is useful for applying the chain rule to the log-Jacobian terms of composed layers.
    """
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)
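
# A minimal doctest-style sanity check of the identity above (illustrative only; the
# shapes and values are arbitrary): for matrices with positive entries, exponentiating
# the log-space product recovers the ordinary matrix product.
#
#     >>> A, B = torch.rand(3, 4) + 0.1, torch.rand(4, 5) + 0.1
#     >>> torch.allclose(log_matrix_product(A.log(), B.log()).exp(), A @ B)
#     True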


@copy_docs_from(TransformModule)
class BlockAutoregressive(TransformModule):
    """
    An implementation of Block Neural Autoregressive Flow (block-NAF) (De Cao et al., 2019) bijective transform.
    Block-NAF uses a similar transformation to deep dense NAF, building the autoregressive NN into the structure
    of the transform, in a sense.

    Together with :class:`~pyro.distributions.TransformedDistribution` this provides a way to create richer
    variational approximations.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> naf = BlockAutoregressive(input_dim=10)
    >>> pyro.module("my_naf", naf)  # doctest: +SKIP
    >>> naf_dist = dist.TransformedDistribution(base_dist, [naf])
    >>> naf_dist.sample()  # doctest: +SKIP
        tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
                0.1389, -0.4629,  0.0986])

    The inverse operation is not implemented. This would require numerical inversion, e.g., using a root finding
    method, and is left as a possibility for a future implementation.

    :param input_dim: The dimensionality of the input and output variables.
    :type input_dim: int
    :param hidden_factors: Hidden layer i has hidden_factors[i] hidden units per input dimension. This
        corresponds to both :math:`a` and :math:`b` in De Cao et al. (2019). The elements of hidden_factors
        must be integers.
    :type hidden_factors: list
    :param activation: Activation function to use. One of 'ELU', 'LeakyReLU', 'sigmoid', or 'tanh'.
    :type activation: string
    :param residual: Type of residual connections to use. Choices are None, "normal" for
        :math:`\\mathbf{y}+f(\\mathbf{y})`, and "gated" for :math:`\\alpha\\mathbf{y} + (1 - \\alpha)f(\\mathbf{y})`
        for learnable parameter :math:`\\alpha`.
    :type residual: string

    References:

    Block Neural Autoregressive Flow [arXiv:1904.04676]
    Nicola De Cao, Ivan Titov, Wilker Aziz

    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    event_dim = 1
    autoregressive = True

    def __init__(self, input_dim, hidden_factors=[8, 8], activation='tanh', residual=None):
        super(BlockAutoregressive, self).__init__(cache_size=1)

        if any([h < 1 for h in hidden_factors]):
            raise ValueError('Hidden factors, {}, must all be >= 1'.format(hidden_factors))

        if residual not in [None, 'normal', 'gated']:
            raise ValueError('Invalid value {} for keyword argument "residual"'.format(residual))

        # Mix in activation function methods
        name_to_mixin = {
            'ELU': ELUTransform,
            'LeakyReLU': LeakyReLUTransform,
            'sigmoid': torch.distributions.transforms.SigmoidTransform,
            'tanh': TanhTransform}
        if activation not in name_to_mixin:
            raise ValueError('Invalid activation function "{}"'.format(activation))
        self.T = name_to_mixin[activation]()

        # Initialize modules for each layer in transform
        self.residual = residual
        self.input_dim = input_dim
        self.layers = nn.ModuleList([MaskedBlockLinear(input_dim, input_dim * hidden_factors[0], input_dim)])
        for idx in range(1, len(hidden_factors)):
            self.layers.append(MaskedBlockLinear(
                input_dim * hidden_factors[idx - 1], input_dim * hidden_factors[idx], input_dim))
        self.layers.append(MaskedBlockLinear(input_dim * hidden_factors[-1], input_dim, input_dim))
        self._cached_logDetJ = None

        if residual == 'gated':
            self.gate = torch.nn.Parameter(torch.nn.init.normal_(torch.Tensor(1)))

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from the base distribution (or the
        output of a previous transform)
        """
        y = x
        for idx in range(len(self.layers)):
            pre_activation, dy_dx = self.layers[idx](y.unsqueeze(-1))
            if idx == 0:
                # First layer: apply the activation and start accumulating the log-Jacobian
                y = self.T(pre_activation)
                J_act = self.T.log_abs_det_jacobian(pre_activation.view(
                    *(list(x.size()) + [-1, 1])), y.view(*(list(x.size()) + [-1, 1])))
                logDetJ = dy_dx + J_act
            elif idx < len(self.layers) - 1:
                # Hidden layers: chain the Jacobian terms in log-space via log_matrix_product
                y = self.T(pre_activation)
                J_act = self.T.log_abs_det_jacobian(pre_activation.view(
                    *(list(x.size()) + [-1, 1])), y.view(*(list(x.size()) + [-1, 1])))
                logDetJ = log_matrix_product(dy_dx, logDetJ) + J_act
            else:
                # Final layer: no activation is applied
                y = pre_activation
                logDetJ = log_matrix_product(dy_dx, logDetJ)

        self._cached_logDetJ = logDetJ.squeeze(-1).squeeze(-1)

        if self.residual == 'normal':
            y = y + x
            self._cached_logDetJ = F.softplus(self._cached_logDetJ)
        elif self.residual == 'gated':
            # y = sigmoid(gate) * x + (1 - sigmoid(gate)) * f(x), with the log-Jacobian adjusted accordingly
            y = self.gate.sigmoid() * x + (1. - self.gate.sigmoid()) * y
            term1 = torch.log(self.gate.sigmoid() + eps)
            log1p_gate = torch.log1p(eps - self.gate.sigmoid())
            log_gate = torch.log(self.gate.sigmoid() + eps)
            term2 = F.softplus(log1p_gate - log_gate + self._cached_logDetJ)
            self._cached_logDetJ = term1 + term2

        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values `y`;
        rather it assumes `y` is the result of a previously computed application of the bijector to some `x`
        (which was cached on the forward call)
        """
        raise KeyError("BlockAutoregressive object expected to find key in intermediates cache but didn't")

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log absolute determinant of the Jacobian, using the value cached on the forward call
        """
        return self._cached_logDetJ.sum(-1)
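
# An illustrative sketch of the caching contract described in _call/_inverse above (the
# dimensions chosen here are arbitrary): log_abs_det_jacobian returns the value cached
# on the most recent forward call, summed over the event dimension.
#
#     >>> bnaf = BlockAutoregressive(input_dim=5, hidden_factors=[4, 4])
#     >>> x = torch.randn(5)
#     >>> y = bnaf(x)
#     >>> bnaf.log_abs_det_jacobian(x, y).shape
#     torch.Size([])
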
class MaskedBlockLinear(torch.nn.Module):
    """
    Module that implements a linear layer with block matrices with positive diagonal blocks.
    Moreover, it uses Weight Normalization (https://arxiv.org/abs/1602.07868) for stability.
    """

    def __init__(self, in_features, out_features, dim, bias=True):
        super(MaskedBlockLinear, self).__init__()
        self.in_features, self.out_features, self.dim = in_features, out_features, dim

        weight = torch.zeros(out_features, in_features)

        # Fill in non-zero entries of block weight matrix, going from top
        # to bottom.
        for i in range(dim):
            weight[i * out_features // dim:(i + 1) * out_features // dim,
                   0:(i + 1) * in_features // dim] = torch.nn.init.xavier_uniform_(
                torch.Tensor(out_features // dim, (i + 1) * in_features // dim))

        self._weight = torch.nn.Parameter(weight)
        self._diag_weight = torch.nn.Parameter(torch.nn.init.uniform_(torch.Tensor(out_features, 1)).log())

        self.bias = torch.nn.Parameter(
            torch.nn.init.uniform_(torch.Tensor(out_features),
                                   -1 / math.sqrt(out_features),
                                   1 / math.sqrt(out_features))) if bias else 0

        # Diagonal block mask
        mask_d = torch.eye(dim).unsqueeze(-1).repeat(1, out_features // dim, in_features // dim).view(
            out_features, in_features)
        self.register_buffer('mask_d', mask_d)

        # Off-diagonal block mask for lower triangular weight matrix
        mask_o = torch.tril(torch.ones(dim, dim), diagonal=-1).unsqueeze(-1)
        mask_o = mask_o.repeat(1, out_features // dim, in_features // dim).view(out_features, in_features)
        self.register_buffer('mask_o', mask_o)

    def get_weights(self):
        """
        Computes the weight matrix using masks and weight normalization. It also computes the log of
        the diagonal blocks of the weight matrix.
        """
        # Form block weight matrix, making sure it's positive on diagonal!
        w = torch.exp(self._weight) * self.mask_d + self._weight * self.mask_o

        # Sum is taken over columns, i.e. one norm per row
        w_squared_norm = (w ** 2).sum(-1, keepdim=True)

        # Effect of multiplication and division is that each row is normalized and rescaled
        w = self._diag_weight.exp() * w / (w_squared_norm.sqrt() + eps)

        # Taking the effect of weight normalization into account in calculating the log-gradient is
        # straightforward! Instead of differentiating, e.g., d(W_1 x)/dx, we have, roughly speaking,
        # d(g_1 W_1 / (W_1^T W_1)^0.5 x)/dx, and taking the log gives the right-hand side below:
        wpl = self._diag_weight + self._weight - 0.5 * torch.log(w_squared_norm + eps)

        return w, wpl[self.mask_d.bool()].view(self.dim, self.out_features // self.dim, self.in_features // self.dim)

    def forward(self, x):
        """
        Parameters
        ----------
        x : ``torch.Tensor``, required.
            The input tensor, with a trailing singleton dimension for the matrix-vector product.

        Returns
        -------
        The output tensor and the log diagonal blocks of the weight matrix, which the caller combines
        with the partial log-Jacobian of previous transformations.
        """
        w, wpl = self.get_weights()
        return (torch.matmul(w, x) + self.bias.unsqueeze(-1)).squeeze(-1), wpl
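
# An illustrative check of the structure MaskedBlockLinear guarantees (the sizes below are
# arbitrary): the normalized weight matrix returned by get_weights() is block lower-triangular
# with strictly positive diagonal blocks, which is what keeps each dy_i/dx_i positive in the
# flow above.
#
#     >>> layer = MaskedBlockLinear(in_features=3, out_features=6, dim=3)
#     >>> w, wpl = layer.get_weights()
#     >>> w.shape, wpl.shape
#     (torch.Size([6, 3]), torch.Size([3, 2, 1]))
#     >>> bool((w * layer.mask_d > 0).sum() == 6)  # diagonal blocks are positive
#     True
#     >>> bool((w * (1 - layer.mask_d - layer.mask_o) == 0).all())  # upper blocks stay zero
#     True
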
def block_autoregressive(input_dim, **kwargs):
    """
    A helper function to create a :class:`~pyro.distributions.transforms.BlockAutoregressive` object for
    consistency with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_factors: Hidden layer i has hidden_factors[i] hidden units per input dimension. This
        corresponds to both :math:`a` and :math:`b` in De Cao et al. (2019). The elements of hidden_factors
        must be integers.
    :type hidden_factors: list
    :param activation: Activation function to use. One of 'ELU', 'LeakyReLU', 'sigmoid', or 'tanh'.
    :type activation: string
    :param residual: Type of residual connections to use. Choices are None, "normal" for
        :math:`\\mathbf{y}+f(\\mathbf{y})`, and "gated" for :math:`\\alpha\\mathbf{y} + (1 - \\alpha)f(\\mathbf{y})`
        for learnable parameter :math:`\\alpha`.
    :type residual: string
    """
    return BlockAutoregressive(input_dim, **kwargs)
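
# A short end-to-end sketch of how the helper composes with a TransformedDistribution
# (the sizes, hidden_factors and residual choice below are arbitrary, and `dist` refers to
# pyro.distributions as in the docstring example above): sampling caches the (x, y) pair,
# so log_prob on a freshly drawn sample reuses the cached log-Jacobian.
#
#     >>> base_dist = dist.Normal(torch.zeros(4), torch.ones(4))
#     >>> transform = block_autoregressive(4, hidden_factors=[4, 4], residual='gated')
#     >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
#     >>> x = flow_dist.rsample(torch.Size([7]))  # doctest: +SKIP
#     >>> flow_dist.log_prob(x).shape  # doctest: +SKIP
#     torch.Size([7])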