Bayesian Neural Networks


class HiddenLayer(X=None, A_mean=None, A_scale=None, non_linearity=<function relu>, KL_factor=1.0, A_prior_scale=1.0, include_hidden_bias=True, weight_space_sampling=False)

This distribution is a basic building block in a Bayesian neural network. It represents a single hidden layer, i.e. an affine transformation applied to a set of inputs X followed by a non-linearity. The uncertainty in the weights is encoded in a Normal variational distribution specified by the parameters A_mean and A_scale. The so-called ‘local reparameterization trick’ is used to reduce the variance of the stochastic gradient estimates (see the reference below). In effect, the weights are never sampled directly; instead, one samples in pre-activation space (i.e. before the non-linearity is applied). Because the weights are never sampled directly, care must be taken, when this distribution is used for variational inference, to correctly scale the KL divergence term that corresponds to the weight matrix. This term is folded into the log_prob method of this distribution.
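The sampling step described above can be sketched as follows. This is a minimal NumPy stand-in for the torch code, not the library's implementation; the function name local_reparam_sample and its arguments are hypothetical. It relies on the fact that, for fixed inputs, each pre-activation (XA)_bh is a weighted sum of independent Gaussian weights, so it is itself Gaussian with mean X @ A_mean and variance (X**2) @ (A_scale**2), and can be sampled without ever drawing A.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def local_reparam_sample(X, A_mean, A_scale, rng, non_linearity=relu):
    """Sample one hidden-layer output without ever sampling the weights A.

    X: (B, D) inputs; A_mean, A_scale: (D, H) weight means and standard deviations.
    """
    # Moments of the Gaussian pre-activation X @ A for A ~ Normal(A_mean, A_scale).
    pre_mean = X @ A_mean
    pre_std = np.sqrt((X ** 2) @ (A_scale ** 2))
    # Independent noise for every (example, hidden unit) pair: shape (B, H).
    eps = rng.standard_normal(pre_mean.shape)
    return non_linearity(pre_mean + pre_std * eps)


rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))       # B=4 examples, D=3 inputs
A_mean = rng.standard_normal((3, 5))  # H=5 hidden units
A_scale = np.full((3, 5), 0.1)
h = local_reparam_sample(X, A_mean, A_scale, rng)
print(h.shape)  # (4, 5)
```

Because a fresh noise value is drawn for every example in the mini-batch, the examples no longer share a single weight sample; per Kingma et al., this is what lowers the variance of the gradient estimator. Setting A_scale to zero makes pre_std zero and recovers the deterministic layer non_linearity(X @ A_mean).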

In effect, this distribution encodes the following generative process:

$$A \sim \mathrm{Normal}(A_{\mathrm{mean}}, A_{\mathrm{scale}})$$
$$\mathrm{output} = \mathrm{non\_linearity}(XA)$$
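The generative process above can also be simulated literally, drawing A first and then computing the output, which corresponds to sampling in weight space rather than in pre-activation space. The following NumPy sketch (hypothetical function name weight_space_sample; a stand-in, not the library code) uses many weight draws with an identity non-linearity to check numerically that the pre-activations have mean X @ A_mean and variance (X**2) @ (A_scale**2), the two moments the local reparameterization trick samples from directly.

```python
import numpy as np


def weight_space_sample(X, A_mean, A_scale, rng, non_linearity=lambda z: np.maximum(z, 0.0)):
    """Follow the generative process literally: draw A, then apply the non-linearity to X @ A."""
    A = A_mean + A_scale * rng.standard_normal(A_mean.shape)  # A ~ Normal(A_mean, A_scale), elementwise
    return non_linearity(X @ A)


def identity(z):
    return z


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 3))
A_mean = rng.standard_normal((3, 4))
A_scale = rng.uniform(0.1, 0.5, size=(3, 4))

# Pre-activations collected over many independent weight draws: shape (20000, 2, 4).
draws = np.stack([weight_space_sample(X, A_mean, A_scale, rng, identity) for _ in range(20000)])

# Largest deviation of the Monte Carlo moments from the closed-form moments of X @ A.
print(np.abs(draws.mean(axis=0) - X @ A_mean).max())
print(np.abs(draws.var(axis=0) / ((X ** 2) @ (A_scale ** 2)) - 1.0).max())
```

Both printed deviations should be small and shrink as the number of draws grows; this per-draw weight sampling is also noticeably more expensive, since it needs a full D x H weight sample for every draw.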

Parameters:
  • X (torch.Tensor) – B x D mini-batch of inputs
  • A_mean (torch.Tensor) – D x H tensor specifying the weight means
  • A_scale (torch.Tensor) – D x H tensor of per-weight scales (standard deviations, i.e. a Normal with diagonal covariance) specifying the weight uncertainty
  • non_linearity (callable) – a callable that specifies the non-linearity used. Defaults to ReLU.
  • KL_factor (float) – scaling factor for the KL divergence. Typically this equals the size of the mini-batch divided by the size of the whole dataset. Defaults to 1.0.
  • A_prior_scale (float or torch.Tensor) – the prior over the weights is assumed to be Normal with mean zero and scale A_prior_scale. Defaults to 1.0.
  • include_hidden_bias (bool) – controls whether each activation vector is augmented with a constant 1, which can be used to incorporate bias terms. Defaults to True.
  • weight_space_sampling (bool) – if True, the weights are sampled directly and the local reparameterization trick is not used. This is only intended for internal testing. Defaults to False.
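To illustrate what KL_factor, A_prior_scale and include_hidden_bias control, here is a NumPy sketch. It assumes (this is not stated above) that the folded-in term is the closed-form KL divergence between the factorized Normal over the D x H weights and the zero-mean Normal prior, i.e. log(s/σ) + (σ² + μ²)/(2s²) − 1/2 per weight with prior scale s, summed over all weights and multiplied by KL_factor. The helper names gaussian_kl, kl_penalty and add_bias_column are hypothetical; add_bias_column appends a 1 to each activation vector so that the next layer's weight matrix can carry an extra row acting as biases.

```python
import numpy as np


def gaussian_kl(A_mean, A_scale, prior_scale=1.0):
    """KL( Normal(A_mean, A_scale) || Normal(0, prior_scale) ), summed over all weights."""
    var_ratio = (A_scale / prior_scale) ** 2
    return 0.5 * np.sum(var_ratio + (A_mean / prior_scale) ** 2 - 1.0 - np.log(var_ratio))


def kl_penalty(A_mean, A_scale, prior_scale=1.0, KL_factor=1.0):
    """Scale the KL term, e.g. by mini-batch size / dataset size."""
    return KL_factor * gaussian_kl(A_mean, A_scale, prior_scale)


def add_bias_column(activation):
    """Append a constant 1 to every activation vector: shape (..., H) -> (..., H + 1)."""
    ones = np.ones(activation.shape[:-1] + (1,))
    return np.concatenate([activation, ones], axis=-1)


print(gaussian_kl(np.zeros((3, 4)), np.ones((3, 4))))                     # 0.0: q equals the prior
print(kl_penalty(np.ones((3, 4)), np.ones((3, 4)), KL_factor=32 / 1024))  # 0.1875 = 12 * 0.5 * 32/1024
print(add_bias_column(np.zeros((2, 5))).shape)                            # (2, 6)
```

With mini-batches of 32 drawn from a dataset of 1024 examples, KL_factor = 32/1024 means the 32 mini-batches of one epoch together contribute the full KL divergence exactly once, matching the single KL term of the full-data objective.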


Kingma, Diederik P., Tim Salimans, and Max Welling. “Variational dropout and the local reparameterization trick.” Advances in Neural Information Processing Systems. 2015.