# Source code for pyro.ops.newton

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import torch

from pyro.ops.linalg import eig_3d, rinverse
from pyro.util import warn_if_nan

"""
Performs a Newton update step to minimize loss on a batch of variables,
optionally constraining to a trust region .

This is especially usful because the final solution of newton iteration
is differentiable wrt the inputs, even when all but the final x is
detached, due to this method's quadratic convergence . loss must be
twice-differentiable as a function of x. If loss is 2+d-times
differentiable, then the return value of this function is d-times
differentiable.

When loss is interpreted as a negative log probability density, then
the return values mode,cov of this function can be used to construct a
Laplace approximation MultivariateNormal(mode,cov).

.. warning:: Take care to detach the result of this function when used in
an optimization loop. If you forget to detach the result of this
function during optimization, then backprop will propagate through
the entire iteration process, and worse will compute two extra
derivatives for each step.

Example use inside a loop::

x = torch.zeros(1000, 2)  # arbitrary initial value
for step in range(100):
x = x.detach()          # block gradients through previous steps
x.requires_grad = True  # ensure loss is differentiable wrt x
loss = my_loss_function(x)
# the final x is still differentiable

 Yuan, Ya-xiang. Iciam. Vol. 99. 2000.
"A review of trust region algorithms for optimization."
ftp://ftp.cc.ac.cn/pub/yyx/papers/p995.pdf
 Christianson, Bruce. Optimization Methods and Software 3.4 (1994)
"Reverse accumulation and attractive fixed points."
http://uhra.herts.ac.uk/bitstream/handle/2299/4338/903839.pdf

:param torch.Tensor loss: A scalar function of x to be minimized.
:param torch.Tensor x: A dependent variable of shape (N, D)
where N is the batch size and D is a small number.
updated value mode of this function will be within
trust_radius of the input x.
:return: A pair (mode, cov) where mode is an updated tensor
of the same shape as the original value x, and cov is an
esitmate of the covariance DxD matrix with
cov.shape == x.shape[:-1] + (D,D).
:rtype: tuple
"""
if x.dim() < 1:
raise ValueError(
"Expected x to have at least one dimension, actual shape {}".format(x.shape)
)
dim = x.shape[-1]
if dim == 1:
elif dim == 2:
elif dim == 3:
else:
raise NotImplementedError("newton_step_nd is not implemented")

"""
Performs a Newton update step to minimize loss on a batch of 1-dimensional
variables, optionally regularizing to constrain to a trust region.

See :func:newton_step for details.

:param torch.Tensor loss: A scalar function of x to be minimized.
:param torch.Tensor x: A dependent variable with rightmost size of 1.
updated value mode of this function will be within
trust_radius of the input x.
:return: A pair (mode, cov) where mode is an updated tensor
of the same shape as the original value x, and cov is an
esitmate of the covariance 1x1 matrix with
cov.shape == x.shape[:-1] + (1,1).
:rtype: tuple
"""
if loss.shape != ():
raise ValueError(
"Expected loss to be a scalar, actual shape {}".format(loss.shape)
)
if x.dim() < 1 or x.shape[-1] != 1:
raise ValueError(
"Expected x to have rightmost size 1, actual shape {}".format(x.shape)
)

# compute derivatives
warn_if_nan(g, "g")
warn_if_nan(H, "H")
Hinv = H.clamp(min=1e-8).reciprocal()
dx = -g * Hinv
dx[~(dx == dx)] = 0

# apply update
x_new = x.detach() + dx
assert x_new.shape == x.shape
return x_new, Hinv.unsqueeze(-1)

"""
Performs a Newton update step to minimize loss on a batch of 2-dimensional
variables, optionally regularizing to constrain to a trust region.

See :func:newton_step for details.

:param torch.Tensor loss: A scalar function of x to be minimized.
:param torch.Tensor x: A dependent variable with rightmost size of 2.
updated value mode of this function will be within
trust_radius of the input x.
:return: A pair (mode, cov) where mode is an updated tensor
of the same shape as the original value x, and cov is an
esitmate of the covariance 2x2 matrix with
cov.shape == x.shape[:-1] + (2,2).
:rtype: tuple
"""
if loss.shape != ():
raise ValueError(
"Expected loss to be a scalar, actual shape {}".format(loss.shape)
)
if x.dim() < 1 or x.shape[-1] != 2:
raise ValueError(
"Expected x to have rightmost size 2, actual shape {}".format(x.shape)
)

# compute derivatives
H = torch.stack(
[
],
-1,
)
assert g.shape[-1:] == (2,)
assert H.shape[-2:] == (2, 2)
warn_if_nan(g, "g")
warn_if_nan(H, "H")

# regularize to keep update within ball of given trust_radius
detH = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0]
mean_eig = (H[..., 0, 0] + H[..., 1, 1]) / 2
min_eig = mean_eig - (mean_eig ** 2 - detH).clamp(min=0).sqrt()
regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(
min=1e-8
)
warn_if_nan(regularizer, "regularizer")
H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(
2, dtype=H.dtype, device=H.device
)

# compute newton update
Hinv = rinverse(H, sym=True)
warn_if_nan(Hinv, "Hinv")

# apply update
x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1)
assert x_new.shape == x.shape
return x_new, Hinv

"""
Performs a Newton update step to minimize loss on a batch of 3-dimensional
variables, optionally regularizing to constrain to a trust region.

See :func:newton_step for details.

:param torch.Tensor loss: A scalar function of x to be minimized.
:param torch.Tensor x: A dependent variable with rightmost size of 2.
updated value mode of this function will be within
trust_radius of the input x.
:return: A pair (mode, cov) where mode is an updated tensor
of the same shape as the original value x, and cov is an
esitmate of the covariance 3x3 matrix with
cov.shape == x.shape[:-1] + (3,3).
:rtype: tuple
"""
if loss.shape != ():
raise ValueError(
"Expected loss to be a scalar, actual shape {}".format(loss.shape)
)
if x.dim() < 1 or x.shape[-1] != 3:
raise ValueError(
"Expected x to have rightmost size 3, actual shape {}".format(x.shape)
)

# compute derivatives
H = torch.stack(
[
],
-1,
)
assert g.shape[-1:] == (3,)
assert H.shape[-2:] == (3, 3)
warn_if_nan(g, "g")
warn_if_nan(H, "H")

# regularize to keep update within ball of given trust_radius
# calculate eigenvalues of symmetric matrix
min_eig, _, _ = eig_3d(H)
regularizer = (g.pow(2).sum(-1).sqrt() / trust_radius - min_eig).clamp_(
min=1e-8
)
warn_if_nan(regularizer, "regularizer")
H = H + regularizer.unsqueeze(-1).unsqueeze(-1) * torch.eye(
3, dtype=H.dtype, device=H.device
)

# compute newton update
Hinv = rinverse(H, sym=True)
warn_if_nan(Hinv, "Hinv")

# apply update
x_new = x.detach() - (Hinv * g.unsqueeze(-2)).sum(-1)
assert x_new.shape == x.shape, "{} {}".format(x_new.shape, x.shape)
return x_new, Hinv