Source code for pyro.contrib.bnn.hidden_layer

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions.utils import lazy_property
import torch.nn.functional as F

from pyro.contrib.bnn.utils import adjoin_ones_vector
from pyro.distributions.torch_distribution import TorchDistribution

class HiddenLayer(TorchDistribution):
    r"""
    This distribution is a basic building block in a Bayesian neural network.
    It represents a single hidden layer, i.e. an affine transformation applied
    to a set of inputs `X` followed by a non-linearity. The uncertainty in the
    weights is encoded in a Normal variational distribution specified by the
    parameters `A_scale` and `A_mean`. The so-called 'local reparameterization
    trick' is used to reduce variance (see reference below). In effect, this
    means the weights are never sampled directly; instead one samples in
    pre-activation space (i.e. before the non-linearity is applied). Since the
    weights are never directly sampled, when this distribution is used within
    the context of variational inference, care must be taken to correctly scale
    the KL divergence term that corresponds to the weight matrix. This term is
    folded into the `log_prob` method of this distribution.

    In effect, this distribution encodes the following generative process:

    A ~ Normal(A_mean, A_scale)
    output ~ non_linearity(XA)

    :param torch.Tensor X: B x D dimensional mini-batch of inputs
    :param torch.Tensor A_mean: D x H dimensional tensor specifying the weight
        mean
    :param torch.Tensor A_scale: D x H dimensional tensor (diagonal covariance
        matrix) specifying the weight uncertainty. Must have the same shape as
        `A_mean`.
    :param callable non_linearity: a callable that specifies the
        non-linearity used. defaults to ReLU.
    :param float KL_factor: scaling factor for the KL divergence.
        prototypically this is equal to the size of the mini-batch divided by
        the size of the whole dataset. defaults to `1.0`.
    :param A_prior_scale: the prior over the weights is assumed to be normal
        with mean zero and scale factor `A_prior_scale`. default value is 1.0.
    :type A_prior_scale: float or torch.Tensor
    :param bool include_hidden_bias: controls whether the activations should
        be augmented with a 1, which can be used to incorporate bias terms.
        defaults to `True`.
    :param bool weight_space_sampling: if `True`, the weights are sampled
        directly instead of using the local reparameterization trick. this is
        only intended to be used for internal testing. defaults to `False`.
    :raises AssertionError: if the leading dimension of `A_mean` does not
        match the trailing dimension of `X`, or if `non_linearity` is not
        callable.
    :raises NotImplementedError: if `A_scale` is not 2-dimensional.
    :raises ValueError: if `A_scale` and `A_mean` have different shapes.

    Reference:

    Kingma, Diederik P., Tim Salimans, and Max Welling.
    "Variational dropout and the local reparameterization trick."
    Advances in Neural Information Processing Systems. 2015.
    """
    has_rsample = True

    def __init__(self, X=None, A_mean=None, A_scale=None, non_linearity=F.relu,
                 KL_factor=1.0, A_prior_scale=1.0, include_hidden_bias=True,
                 weight_space_sampling=False):
        self.X = X
        self.dim_X = X.size(-1)
        self.dim_H = A_mean.size(-1)
        assert A_mean.size(0) == self.dim_X, \
            "The dimensions of X and A_mean and A_scale must match accordingly; see documentation"
        self.A_mean = A_mean
        self.A_scale = A_scale
        self.non_linearity = non_linearity
        assert callable(non_linearity), "non_linearity must be callable"
        if A_scale.dim() != 2:
            raise NotImplementedError("A_scale must be 2-dimensional")
        # The KL term subtracts dim_X * dim_H, i.e. it assumes one scale per
        # weight. A broadcastable-but-smaller A_scale would silently produce
        # a wrong KL, so reject any shape mismatch up front.
        if A_scale.shape != A_mean.shape:
            raise ValueError(
                "A_scale must have the same shape as A_mean; got {} and {}".format(
                    tuple(A_scale.shape), tuple(A_mean.shape)))
        self.KL_factor = KL_factor
        self.A_prior_scale = A_prior_scale
        self.weight_space_sampling = weight_space_sampling
        self.include_hidden_bias = include_hidden_bias

    def log_prob(self, value):
        """
        Return the negated, scaled KL divergence of the weight distribution.

        `value` is ignored: the result is always `-KL_factor * KL`, which is
        how the weight-matrix KL term is folded into the objective.
        """
        return -self.KL_factor * self.KL

    @lazy_property
    def KL(self):
        """
        KL divergence between Normal(A_mean, A_scale) and the zero-mean Normal
        prior with scale `A_prior_scale`, summed over all weights. Computed
        once per instance and cached by `lazy_property`.
        """
        # Per-weight KL is 0.5 * (m^2/p^2 + s^2/p^2 - 1 - 2*log(s/p)); the
        # "- 1" terms are accumulated as dim_X * dim_H.
        KL_A = torch.pow(self.A_mean / self.A_prior_scale, 2.0).sum()
        KL_A -= self.dim_X * self.dim_H
        KL_A += torch.pow(self.A_scale / self.A_prior_scale, 2.0).sum()
        KL_A -= 2.0 * torch.log(self.A_scale / self.A_prior_scale).sum()
        return 0.5 * KL_A

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw reparameterized samples of the layer's activations.

        :param torch.Size sample_shape: leading shape of the noise that is
            drawn; defaults to an empty shape (a single sample).
        :return: the non-linearity applied to the sampled pre-activations,
            passed through `adjoin_ones_vector` when `include_hidden_bias`
            is set.
        """
        # note: weight space sampling is only meant for testing
        if self.weight_space_sampling:
            # sample the weight matrix itself, then apply the affine map
            A = self.A_mean + torch.randn(sample_shape + self.A_scale.shape).type_as(self.A_mean) * self.A_scale
            activation = torch.matmul(self.X, A)
        else:
            # local reparameterization: sample the pre-activations directly
            # from their Normal distribution with mean X @ A_mean and
            # variance sum_d X_d^2 * A_scale_dh^2
            _mean = torch.matmul(self.X, self.A_mean)
            X_sqr = torch.pow(self.X, 2.0).unsqueeze(-1)
            A_scale_sqr = torch.pow(self.A_scale, 2.0)
            _std = (X_sqr * A_scale_sqr).sum(-2).sqrt()
            activation = _mean + torch.randn(sample_shape + _std.shape).type_as(_std) * _std

        # apply non-linearity
        activation = self.non_linearity(activation)

        # add 1 element to activations
        if self.include_hidden_bias:
            activation = adjoin_ones_vector(activation)
        return activation