# Source code for pyro.optim.adagrad_rmsprop

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.optim.optimizer import Optimizer

class AdagradRMSProp(Optimizer):
    """
    Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise
    update equation see equations 10 and 11 in reference [1].

    For every parameter ``p`` with gradient ``g``, at step ``i`` (counting
    from 1) the optimizer maintains a running average ``s`` of the squared
    gradient::

        s_1 = g * g
        s_i = (1 - t) * s_{i-1} + t * g * g        for i > 1

    and then applies the update::

        p <- p - eta * i ** (-0.5 + delta) * g / (1 + sqrt(s_i))

    References:

    [1] 'Automatic Differentiation Variational Inference',
    Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
    URL: https://arxiv.org/abs/1603.00788

    [2] 'Lecture 6.5 RmsProp: Divide the gradient by a running average of its recent magnitude',
    Tieleman, T. and Hinton, G.,
    COURSERA: Neural Networks for Machine Learning.

    [3] 'Adaptive subgradient methods for online learning and stochastic optimization',
    Duchi, John, Hazan, E and Singer, Y.

    Arguments:

    :param params: iterable of parameters to optimize or dicts defining parameter groups
    :param eta: sets the step size scale (optional; default: 1.0)
    :type eta: float
    :param t: momentum parameter: the weight given to the newest squared
        gradient in the running average ``s`` (optional; default: 0.1)
    :type t: float
    :param delta: modulates the exponent that controls how the step size
        scales with the step count (optional; default: 1e-16)
    :type delta: float

    .. note:: Sparse gradients are not supported.
    """

    def __init__(self, params, eta=1.0, delta=1.0e-16, t=0.1):
        defaults = dict(eta=eta, delta=delta, t=t)
        super().__init__(params, defaults)

        # Allocate state eagerly so that ``share_memory()`` can be called
        # right after construction, before the first ``step()``.
        for group in self.param_groups:
            for p in group['params']:
                self._init_state(p)

    def _init_state(self, p):
        """
        Return the state dict of ``p``, creating ``step`` and ``sum`` if absent.

        Parameters added later through :meth:`add_param_group` have no state
        yet, so this is also called lazily from :meth:`share_memory` and
        :meth:`step`.
        """
        state = self.state[p]
        if not state:
            state['step'] = 0
            state['sum'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        return state

    def share_memory(self):
        """
        Move the running squared-gradient averages into shared memory.

        :meth:`step` updates these buffers in place, so they stay shared
        across subsequent steps.
        """
        for group in self.param_groups:
            for p in group['params']:
                self._init_state(p)['sum'].share_memory_()

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        :param closure: A (optional) closure that reevaluates the model and returns the loss.
        :returns: the value returned by ``closure``, or ``None`` if no closure was given.
        :raises NotImplementedError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The parameter update below runs without autograd tracking, but
            # the closure must be able to build a graph and call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eta, delta, t = group['eta'], group['delta'], group['t']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise NotImplementedError(
                        "AdagradRMSProp does not support sparse gradients")

                state = self._init_state(p)
                state['step'] += 1
                # Update ``sum`` in place (never rebind it) so that a buffer
                # moved to shared memory by ``share_memory()`` stays shared.
                if state['step'] == 1:
                    # first step: initialize the variance estimate to grad^2
                    state['sum'].copy_(grad).mul_(grad)
                else:
                    state['sum'].mul_(1.0 - t).addcmul_(grad, grad, value=t)

                lr = eta * (state['step'] ** (-0.5 + delta))
                std = state['sum'].sqrt()
                p.addcdiv_(grad, 1.0 + std, value=-lr)

        return loss