Source code for pyro.optim.adagrad_rmsprop

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Optional

import torch
from torch.optim.optimizer import Optimizer

class AdagradRMSProp(Optimizer):
    """
    Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise
    update equation see equations 10 and 11 in reference [1].

    At step ``k`` (counted per parameter, starting at 1) with gradient ``g``::

        s_k   = g * g                          if k == 1
        s_k   = (1 - t) * s_{k-1} + t * g * g  otherwise
        lr_k  = eta * k ** (-0.5 + delta)
        param = param - lr_k * g / (1 + sqrt(s_k))

    References:
    [1] 'Automatic Differentiation Variational Inference', Alp Kucukelbir,
    Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
    URL: https://arxiv.org/abs/1603.00788
    [2] 'Lecture 6.5 RmsProp: Divide the gradient by a running average
    of its recent magnitude', Tieleman, T. and Hinton, G.,
    COURSERA: Neural Networks for Machine Learning.
    [3] 'Adaptive subgradient methods for online learning and stochastic
    optimization', Duchi, John, Hazan, E and Singer, Y.

    Arguments:

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups
    :param eta: sets the step size scale (optional; default: 1.0)
    :type eta: float
    :param t: momentum parameter, i.e. the weight given to the newest squared
        gradient in the running average (optional; default: 0.1)
    :type t: float
    :param delta: modulates the exponent that controls how the step size
        scales (optional; default: 1e-16)
    :type delta: float
    """

    def __init__(
        self, params, eta: float = 1.0, delta: float = 1.0e-16, t: float = 0.1
    ):
        defaults = dict(eta=eta, delta=delta, t=t)
        super().__init__(params, defaults)

        # share_memory() reads state["sum"], so the buffers are allocated
        # eagerly here rather than on the first call to step().
        for group in self.param_groups:
            for p in group["params"]:
                self._ensure_state(p)

    def _ensure_state(self, p) -> dict:
        """Return the state dict for ``p``, creating ``step``/``sum`` if absent."""
        state = self.state[p]
        if not state:
            # Reached for parameters registered after construction
            # (e.g. through ``add_param_group``).
            state["step"] = 0
            state["sum"] = torch.zeros_like(p.data)
        return state

    def share_memory(self) -> None:
        """
        Moves the running squared-gradient buffer of every parameter into
        shared memory (via ``Tensor.share_memory_``).
        """
        for group in self.param_groups:
            for p in group["params"]:
                self._ensure_state(p)["sum"].share_memory_()

    def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
        """
        Performs a single optimization step.

        :param closure: A (optional) closure that reevaluates the model and
            returns the loss.
        :returns: the value returned by ``closure``, or ``None`` if no closure
            was given.
        :raises NotImplementedError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data

                if grad.is_sparse:
                    raise NotImplementedError(
                        "AdagradRMSProp does not support sparse gradients"
                    )

                state = self._ensure_state(p)
                state["step"] += 1
                sq_avg = state["sum"]

                if state["step"] == 1:
                    # if first step, initialize variance bit to grad^2.
                    # Written in place so that a buffer previously moved to
                    # shared memory is updated rather than replaced.
                    sq_avg.copy_(grad * grad)
                else:
                    sq_avg.mul_(1.0 - group["t"])
                    sq_avg.add_(group["t"] * grad * grad)

                lr = group["eta"] * (state["step"] ** (-0.5 + group["delta"]))
                std = sq_avg.sqrt()
                # param <- param - lr * grad / (1 + std)
                p.data.addcdiv_(grad, 1.0 + std, value=-lr)

        return loss