import torch
from torch.optim.optimizer import Optimizer

"""
Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise
update equation see equations 10 and 11 in reference [1].

References:
[1] 'Automatic Differentiation Variational Inference', Alp Kucukelbir,
Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
URL: https://arxiv.org/abs/1603.00788
[2] 'Lecture 6.5 RmsProp: Divide the gradient by a running average
of its recent magnitude', Tieleman, T. and Hinton, G.,
COURSERA: Neural Networks for Machine Learning.
[3] 'Adaptive subgradient methods for online learning and stochastic optimization',
Duchi, John, Hazan, E and Singer, Y.

Arguments:

:param params: iterable of parameters to optimize or dicts defining parameter groups
:param eta: sets the step size scale (optional; default: 1.0)
:type eta: float
:param t: momentum parameter (optional; default: 0.1)
:type t: float
:param delta: modulates the exponent that controls how the step size scales (optional; default: 1e-16)
:type delta: float
"""

def __init__(self, params, eta=1.0, delta=1.0e-16, t=0.1):
    """
    Sets up the optimizer and allocates the per-parameter state.

    :param params: iterable of parameters to optimize or dicts defining parameter groups
    :param eta: sets the step size scale (optional; default: 1.0)
    :type eta: float
    :param delta: modulates the exponent that controls how the step size scales
        (optional; default: 1e-16)
    :type delta: float
    :param t: momentum parameter (optional; default: 0.1)
    :type t: float
    """
    defaults = dict(eta=eta, delta=delta, t=t)
    # The base initializer is what creates ``self.param_groups`` and
    # ``self.state``; it has to run before the loop below reads them.
    Optimizer.__init__(self, params, defaults)

    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            # Per-parameter step counter, starting at zero.
            state['step'] = 0
            # Accumulator with the same shape/dtype/device as the parameter.
            state['sum'] = torch.zeros_like(p.data)

def share_memory(self):
    """
    Calls ``share_memory_()`` on the ``'sum'`` state tensor of every parameter.

    Walks each parameter group in ``self.param_groups`` and, for each
    parameter ``p`` in the group, invokes ``share_memory_()`` on the tensor
    stored under ``self.state[p]['sum']``.
    """
    for group in self.param_groups:
        for p in group['params']:
            self.state[p]['sum'].share_memory_()

[docs]    def step(self, closure=None):
"""
Performs a single optimization step.

:param closure: A (optional) closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
continue

raise NotImplementedError

state = self.state[p]
state['step'] += 1
if state['step'] == 1:
# if first step, initialize variance bit to grad^2