# Source code for pyro.distributions.transforms.planar

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math

import torch
import torch.nn as nn
from torch.distributions import constraints
import torch.nn.functional as F

from torch.distributions import Transform
from pyro.distributions.conditional import ConditionalTransformModule
from pyro.distributions.torch_transform import TransformModule
from pyro.distributions.util import copy_docs_from
from pyro.nn import DenseNN

@copy_docs_from(Transform)
class ConditionedPlanar(Transform):
domain = constraints.real
codomain = constraints.real
bijective = True
event_dim = 1

def __init__(self, bias=None, u=None, w=None):
super().__init__(cache_size=1)
self.bias = bias
self.u = u
self.w = w
self._cached_logDetJ = None

# This method ensures that torch(u_hat, w) > -1, required for invertibility
def u_hat(self, u, w):
alpha = torch.matmul(u.unsqueeze(-2), w.unsqueeze(-1)).squeeze(-1)
a_prime = -1 + F.softplus(alpha)
return u + (a_prime - alpha) * w.div(w.pow(2).sum(dim=-1, keepdim=True))

def _call(self, x):
"""
:param x: the input into the bijection
:type x: torch.Tensor
Invokes the bijection x => y; in the prototypical context of a
:class:~pyro.distributions.TransformedDistribution x is a sample from the base distribution (or the output
of a previous transform)
"""

# x ~ (batch_size, dim_size, 1)
# w ~ (batch_size, 1, dim_size)
# bias ~ (batch_size, 1)
act = torch.tanh(torch.matmul(self.w.unsqueeze(-2), x.unsqueeze(-1)).squeeze(-1) + self.bias)
u_hat = self.u_hat(self.u, self.w)
y = x + u_hat * act

psi_z = (1. - act.pow(2)) * self.w
self._cached_logDetJ = torch.log(
torch.abs(1 + torch.matmul(psi_z.unsqueeze(-2), u_hat.unsqueeze(-1)).squeeze(-1).squeeze(-1)))

return y

def _inverse(self, y):
"""
:param y: the output of the bijection
:type y: torch.Tensor
Inverts y => x. As noted above, this implementation is incapable of inverting arbitrary values
y; rather it assumes y is the result of a previously computed application of the bijector
to some x (which was cached on the forward call)
"""

raise KeyError("ConditionalPlanar object expected to find key in intermediates cache but didn't")

def log_abs_det_jacobian(self, x, y):
"""
Calculates the elementwise determinant of the log Jacobian
"""
return self._cached_logDetJ

[docs]@copy_docs_from(ConditionedPlanar)
class Planar(ConditionedPlanar, TransformModule):
"""
A 'planar' bijective transform with equation,

:math:\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{z}+b)

where :math:\\mathbf{x} are the inputs, :math:\\mathbf{y} are the outputs, and the learnable parameters
are :math:b\\in\\mathbb{R}, :math:\\mathbf{u}\\in\\mathbb{R}^D, :math:\\mathbf{w}\\in\\mathbb{R}^D for input
dimension :math:D. For this to be an invertible transformation, the condition
:math:\\mathbf{w}^T\\mathbf{u}>-1 is enforced.

Together with :class:~pyro.distributions.TransformedDistribution this provides a way to create richer
variational approximations.

Example usage:

>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> transform = Planar(10)
>>> pyro.module("my_transform", transform)  # doctest: +SKIP
>>> flow_dist = dist.TransformedDistribution(base_dist, [transform])
>>> flow_dist.sample()  # doctest: +SKIP
tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
0.1389, -0.4629,  0.0986])

The inverse of this transform does not possess an analytical solution and is left unimplemented. However,
the inverse is cached when the forward operation is called during sampling, and so samples drawn using
the planar transform can be scored.

:param input_dim: the dimension of the input (and output) variable.
:type input_dim: int

References:

Variational Inference with Normalizing Flows [arXiv:1505.05770]
Danilo Jimenez Rezende, Shakir Mohamed

"""

domain = constraints.real
codomain = constraints.real
bijective = True
event_dim = 1

def __init__(self, input_dim):
super().__init__()

self.bias = nn.Parameter(torch.Tensor(1,))
self.u = nn.Parameter(torch.Tensor(input_dim,))
self.w = nn.Parameter(torch.Tensor(input_dim,))
self.input_dim = input_dim
self.reset_parameters()

[docs]    def reset_parameters(self):
stdv = 1. / math.sqrt(self.u.size(0))
self.w.data.uniform_(-stdv, stdv)
self.u.data.uniform_(-stdv, stdv)
self.bias.data.zero_()

[docs]@copy_docs_from(ConditionalTransformModule)
class ConditionalPlanar(ConditionalTransformModule):
"""
A conditional 'planar' bijective transform using the equation,

:math:\\mathbf{y} = \\mathbf{x} + \\mathbf{u}\\tanh(\\mathbf{w}^T\\mathbf{z}+b)

where :math:\\mathbf{x} are the inputs with dimension :math:D, :math:\\mathbf{y} are the outputs,
and the pseudo-parameters :math:b\\in\\mathbb{R}, :math:\\mathbf{u}\\in\\mathbb{R}^D, and
:math:\\mathbf{w}\\in\\mathbb{R}^D are the output of a function, e.g. a NN, with input
:math:z\\in\\mathbb{R}^{M} representing the context variable to condition on. For this to be an
invertible transformation, the condition :math:\\mathbf{w}^T\\mathbf{u}>-1 is enforced.

Together with :class:~pyro.distributions.ConditionalTransformedDistribution this provides a way to create
richer variational approximations.

Example usage:

>>> from pyro.nn.dense_nn import DenseNN
>>> input_dim = 10
>>> context_dim = 5
>>> batch_size = 3
>>> base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
>>> hypernet = DenseNN(context_dim, [50, 50], param_dims=[1, input_dim, input_dim])
>>> transform = ConditionalPlanar(hypernet)
>>> z = torch.rand(batch_size, context_dim)
>>> flow_dist = dist.ConditionalTransformedDistribution(base_dist, [transform]).condition(z)
>>> flow_dist.sample(sample_shape=torch.Size([batch_size])) # doctest: +SKIP

The inverse of this transform does not possess an analytical solution and is left unimplemented. However,
the inverse is cached when the forward operation is called during sampling, and so samples drawn using
the planar transform can be scored.

:param nn: a function inputting the context variable and outputting a triplet of real-valued parameters
of dimensions :math:(1, D, D).
:type nn: callable

References:
Variational Inference with Normalizing Flows [arXiv:1505.05770]
Danilo Jimenez Rezende, Shakir Mohamed

"""

domain = constraints.real
codomain = constraints.real
bijective = True
event_dim = 1

def __init__(self, nn):
super().__init__()
self.nn = nn

[docs]    def condition(self, context):
bias, u, w = self.nn(context)
return ConditionedPlanar(bias, u, w)

[docs]def planar(input_dim):
"""
A helper function to create a :class:~pyro.distributions.transforms.Planar object for consistency with other
helpers.

:param input_dim: Dimension of input variable
:type input_dim: int

"""

return Planar(input_dim)

[docs]def conditional_planar(input_dim, context_dim, hidden_dims=None):
"""
A helper function to create a :class:~pyro.distributions.transforms.ConditionalPlanar object that takes care of
constructing a dense network with the correct input/output dimensions.

:param input_dim: Dimension of input variable
:type input_dim: int
:param context_dim: Dimension of context variable
:type context_dim: int
:param hidden_dims: The desired hidden dimensions of the dense network. Defaults
to using [input_dim * 10, input_dim * 10]
:type hidden_dims: list[int]

"""

if hidden_dims is None:
hidden_dims = [input_dim * 10, input_dim * 10]
hypernet = DenseNN(context_dim, hidden_dims, param_dims=[1, input_dim, input_dim])
return ConditionalPlanar(hypernet)