Source code for pyro.distributions.transforms.householder

import math
import warnings

import torch
import torch.nn as nn
from torch.distributions import constraints

from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import copy_docs_from


@copy_docs_from(TransformModule)
class Householder(TransformModule):
    """
    Represents multiple applications of the Householder bijective
    transformation. A single Householder transformation takes the form,

        :math:`\\mathbf{y} = (I - 2*\\frac{\\mathbf{u}\\mathbf{u}^T}{||\\mathbf{u}||^2})\\mathbf{x}`

    where :math:`\\mathbf{x}` are the inputs, :math:`\\mathbf{y}` are the outputs,
    and the learnable parameters are :math:`\\mathbf{u}\\in\\mathbb{R}^D` for input
    dimension :math:`D`.

    The transformation represents the reflection of :math:`\\mathbf{x}` through the
    plane passing through the origin with normal :math:`\\mathbf{u}`.

    :math:`D` applications of this transformation are able to transform
    i.i.d. standard Gaussian noise into a Gaussian variable with an arbitrary
    covariance matrix. With :math:`K<D` transformations, one is able to
    approximate a full-rank Gaussian distribution using a linear transformation
    of rank :math:`K`.

    Together with :class:`~pyro.distributions.TransformedDistribution` this
    provides a way to create richer variational approximations.

    Example usage:

    >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
    >>> transform = Householder(10, count_transforms=5)
    >>> pyro.module("my_transform", transform) # doctest: +SKIP
    >>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
    >>> flow_dist.sample()  # doctest: +SKIP
        tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
                0.1389, -0.4629,  0.0986])

    :param input_dim: the dimension of the input (and output) variable.
    :type input_dim: int
    :param count_transforms: number of applications of Householder
        transformation to apply.
    :type count_transforms: int
    :raises ValueError: if ``count_transforms`` is less than 1.

    References:

    Improving Variational Auto-Encoders using Householder Flow, [arXiv:1611.09630]
    Tomczak, J. M., & Welling, M.

    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    event_dim = 1
    volume_preserving = True

    def __init__(self, input_dim, count_transforms=1):
        super(Householder, self).__init__(cache_size=1)

        self.input_dim = input_dim
        if count_transforms < 1:
            raise ValueError('Number of Householder transforms, {}, is less than 1!'.format(count_transforms))
        elif count_transforms > input_dim:
            # More reflections than dimensions is allowed, but flagged to the user.
            # Adjacent literals are joined at compile time, so ``.format`` applies
            # to the whole message.
            warnings.warn(
                "Number of Householder transforms, {}, is greater than input dimension {}, which is an "
                "over-parametrization!".format(count_transforms, input_dim))
        self.count_transforms = count_transforms

        # One un-normalized reflection normal per row; the values are filled in
        # by reset_parameters() immediately below.
        self.u_unnormed = nn.Parameter(torch.empty(count_transforms, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initializes ``u_unnormed`` in place with samples from
        ``Uniform(-stdv, stdv)`` where ``stdv = 1 / sqrt(input dimension)``.
        """
        stdv = 1. / math.sqrt(self.u_unnormed.size(-1))
        nn.init.uniform_(self.u_unnormed, -stdv, stdv)

    def u(self):
        """
        Constructs the normalized vectors for the Householder transform by
        dividing each row of ``u_unnormed`` by its L2 norm.
        """
        norm = torch.norm(self.u_unnormed, p=2, dim=-1, keepdim=True)
        return torch.div(self.u_unnormed, norm)

    @staticmethod
    def _reflect(z, normal):
        """
        Applies a single Householder reflection, ``z - 2 * (normal . z) * normal``,
        along the last dimension of ``z``. ``normal`` is taken from :meth:`u`.
        """
        projection = (normal * z).sum(dim=-1, keepdim=True) * normal
        return z - 2. * projection

    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x=>y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution` `x` is a sample from
        the base distribution (or the output of a previous transform)
        """
        y = x
        u = self.u()
        for idx in range(self.count_transforms):
            y = self._reflect(y, u[idx])
        return y

    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. The Householder transformation, H, is "involutory," i.e.
        H^2 = I. If you reflect a point around a plane, then the same operation
        will reflect it back
        """
        x = y
        u = self.u()
        # NOTE: Need to apply transforms in reverse order as forward operation!
        for jdx in reversed(range(self.count_transforms)):
            x = self._reflect(x, u[jdx])
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Calculates the log of the absolute value of the determinant of the
        Jacobian. Householder flow is measure preserving, so
        :math:`\\log(|detJ|) = 0`; the result is a tensor of zeros with the
        shape of ``x`` minus its last dimension.
        """
        return torch.zeros(x.size()[:-1], dtype=x.dtype, layout=x.layout, device=x.device)
def householder(input_dim, count_transforms=None):
    """
    Convenience constructor returning a
    :class:`~pyro.distributions.transforms.Householder` instance, provided for
    consistency with other helpers.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param count_transforms: number of applications of Householder
        transformation to apply. When ``None``, ``input_dim // 2 + 1`` is used.
    :type count_transforms: int

    """
    num_reflections = input_dim // 2 + 1 if count_transforms is None else count_transforms
    return Householder(input_dim, count_transforms=num_reflections)