# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Optional
import torch
from torch.optim.optimizer import Optimizer
class AdagradRMSProp(Optimizer):
    """
    Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise
    update equation see equations 10 and 11 in reference [1].

    References:

    [1] 'Automatic Differentiation Variational Inference', Alp Kucukelbir,
    Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
    URL: https://arxiv.org/abs/1603.00788
    [2] 'Lecture 6.5 RmsProp: Divide the gradient by a running average
    of its recent magnitude', Tieleman, T. and Hinton, G.,
    COURSERA: Neural Networks for Machine Learning.
    [3] 'Adaptive subgradient methods for online learning and stochastic optimization',
    Duchi, John, Hazan, E and Singer, Y.

    Arguments:

    :param params: iterable of parameters to optimize or dicts defining parameter groups
    :param eta: sets the step size scale (optional; default: 1.0)
    :type eta: float
    :param t: momentum parameter (optional; default: 0.1)
    :type t: float
    :param delta: modulates the exponent that controls how the step size scales
        (optional; default: 1e-16)
    :type delta: float
    """

    def __init__(
        self, params, eta: float = 1.0, delta: float = 1.0e-16, t: float = 0.1
    ):
        defaults = dict(eta=eta, delta=delta, t=t)
        super().__init__(params, defaults)

        # Eagerly create per-parameter state so that share_memory() works
        # even before the first call to step().
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = 0
                state["sum"] = torch.zeros_like(p.data)

    def share_memory(self) -> None:
        """
        Moves the per-parameter running-average buffers into shared memory
        so the optimizer state can be shared across processes.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
        """
        Performs a single optimization step.

        :param closure: An optional closure that reevaluates the model and returns the loss.
        :return: the loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise NotImplementedError

                state = self.state[p]
                state["step"] += 1
                if state["step"] == 1:
                    # if first step, initialize variance bit to grad^2
                    state["sum"] = grad * grad
                else:
                    # exponential moving average of grad^2 with decay rate t
                    state["sum"] *= 1.0 - group["t"]
                    state["sum"] += group["t"] * grad * grad

                # step size decays as step^(-1/2 + delta); see eqs. 10-11 of [1]
                lr = group["eta"] * (state["step"] ** (-0.5 + group["delta"]))
                std = state["sum"].sqrt()
                # p <- p - lr * grad / (1 + sqrt(sum))
                p.data.addcdiv_(grad, 1.0 + std, value=-lr)

        return loss