Source code for pyro.optim.lr_scheduler

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Iterable, List, Optional, Union, ValuesView

from torch import Tensor

from pyro.optim.optim import PyroOptim


class PyroLRScheduler(PyroOptim):
    """
    A wrapper for :class:`~torch.optim.lr_scheduler` objects that adjusts learning rates
    for dynamically generated parameters.

    :param scheduler_constructor: a :class:`~torch.optim.lr_scheduler`
    :param optim_args: a dictionary that must contain the key ``'optimizer'``
        (the pytorch optimizer to construct) and the key ``'optim_args'``
        (the learning arguments handed to that optimizer). Every remaining
        entry is forwarded as a keyword argument to ``scheduler_constructor``.
        The dictionary is copied, so the caller's object is left unmodified.
    :param clip_args: a dictionary of clip_norm and/or clip_value args or a
        callable that returns such dictionaries.
    :raises KeyError: if ``optim_args`` lacks ``'optimizer'`` or ``'optim_args'``.

    Example::

        optimizer = torch.optim.SGD
        scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer,
                                              'optim_args': {'lr': 0.01},
                                              'gamma': 0.1})
        svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
        for i in range(epochs):
            for minibatch in DataLoader(dataset, batch_size):
                svi.step(minibatch)
            scheduler.step()
    """

    def __init__(
        self,
        scheduler_constructor,
        optim_args: Union[Dict],
        clip_args: Optional[Union[Dict]] = None,
    ):
        # pytorch scheduler
        self.pt_scheduler_constructor = scheduler_constructor
        # Pop from a shallow copy: popping from the caller's dict would strip
        # 'optimizer' / 'optim_args' out of it and make the same config
        # unusable for building a second scheduler.
        scheduler_kwargs = dict(optim_args)
        # torch optimizer
        pt_optim_constructor = scheduler_kwargs.pop("optimizer")
        # kwargs for the torch optimizer
        optim_kwargs = scheduler_kwargs.pop("optim_args")
        # whatever is left is passed to the scheduler constructor in _get_optim
        self.kwargs = scheduler_kwargs
        super().__init__(pt_optim_constructor, optim_kwargs, clip_args)

    def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
        """
        Delegate to the parent class's ``__call__`` with the same arguments.

        :param params: the parameters forwarded to the parent implementation
        """
        super().__call__(params, *args, **kwargs)

    def _get_optim(
        self, params: Union[Tensor, Iterable[Tensor], Iterable[Dict[Any, Any]]]
    ):
        """
        Build the parent class's optimizer for ``params`` and wrap it in a
        scheduler created by ``scheduler_constructor`` with the stored
        scheduler keyword arguments.

        :param params: the parameters forwarded to the parent ``_get_optim``
        :returns: the object returned by ``scheduler_constructor``
        """
        optim = super()._get_optim(params)
        return self.pt_scheduler_constructor(optim, **self.kwargs)

    def step(self, *args, **kwargs) -> None:
        """
        Takes the same arguments as the PyTorch scheduler
        (e.g. optional ``loss`` for ``ReduceLROnPlateau``)
        """
        # Each value in optim_objs is stepped with the same arguments.
        for scheduler in self.optim_objs.values():
            scheduler.step(*args, **kwargs)