# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
from pyro.optim import PyroOptim
from pyro.optim.lr_scheduler import PyroLRScheduler
__all__ = []


def _wrap_optimizer(optim_cls, name):
    """Build the exported constructor for one ``torch.optim`` optimizer class.

    :param optim_cls: optimizer class passed to ``PyroOptim`` as its first
        argument.
    :param str name: name the constructor is exported under.
    :returns: a function ``(optim_args, clip_args=None)`` that returns
        ``PyroOptim(optim_cls, optim_args, clip_args)``.
    """

    def _constructor(optim_args, clip_args=None):
        return PyroOptim(optim_cls, optim_args, clip_args)

    # Set both names so repr() and introspection report the exported name
    # rather than this inner helper's.
    _constructor.__name__ = name
    _constructor.__qualname__ = name
    _constructor.__doc__ = (
        "Wraps :class:`torch.optim.{}` with "
        ":class:`~pyro.optim.optim.PyroOptim`.".format(name)
    )
    return _constructor


# Programmatically load all optimizers from PyTorch: every class found in
# ``torch.optim`` that derives from ``Optimizer`` is re-exported here under
# the same name, wrapped so that calling it builds a ``PyroOptim``.
for _name, _Optim in torch.optim.__dict__.items():
    if not (isinstance(_Optim, type) and issubclass(_Optim, torch.optim.Optimizer)):
        continue
    if _Optim is torch.optim.Optimizer:
        # The base class itself is not exported.
        continue
    if _Optim is torch.optim.LBFGS:
        # XXX LBFGS is not supported for SVI yet
        continue

    # At module scope ``globals()`` is the module namespace; writing to it
    # explicitly avoids depending on ``locals()`` being writable.
    globals()[_name] = _wrap_optimizer(_Optim, _name)
    __all__.append(_name)
# Load all schedulers from PyTorch.
#
# Newer PyTorch releases expose ``LRScheduler`` as the scheduler base class
# and keep ``_LRScheduler`` only as a compatibility subclass of it, so a test
# against ``_LRScheduler`` would reject every concrete scheduler there.
# Prefer the public base and fall back to the private one on older releases.
_scheduler_base = getattr(torch.optim.lr_scheduler, "LRScheduler", None)
if _scheduler_base is None:
    _scheduler_base = torch.optim.lr_scheduler._LRScheduler


def _wrap_scheduler(scheduler_cls, name):
    """Build the exported constructor for one ``torch.optim.lr_scheduler`` class.

    :param scheduler_cls: scheduler class passed to ``PyroLRScheduler`` as its
        first argument.
    :param str name: name the constructor is exported under.
    :returns: a function ``(optim_args, clip_args=None)`` that returns
        ``PyroLRScheduler(scheduler_cls, optim_args, clip_args)``.
    """

    def _constructor(optim_args, clip_args=None):
        return PyroLRScheduler(scheduler_cls, optim_args, clip_args)

    # Set both names so repr() and introspection report the exported name
    # rather than this inner helper's.
    _constructor.__name__ = name
    _constructor.__qualname__ = name
    # Schedulers live in ``torch.optim.lr_scheduler``; the cross-reference
    # has to name that module to resolve.
    _constructor.__doc__ = (
        "Wraps :class:`torch.optim.lr_scheduler.{}` with ".format(name)
        + ":class:`~pyro.optim.lr_scheduler.PyroLRScheduler`."
    )
    return _constructor


for _name, _Scheduler in torch.optim.lr_scheduler.__dict__.items():
    if not isinstance(_Scheduler, type):
        continue
    # ``ReduceLROnPlateau`` is admitted by name so it is exported even on
    # releases where it does not derive from the scheduler base class.
    if not issubclass(_Scheduler, _scheduler_base) and _name != "ReduceLROnPlateau":
        continue

    # At module scope ``globals()`` is the module namespace; writing to it
    # explicitly avoids depending on ``locals()`` being writable.
    globals()[_name] = _wrap_scheduler(_Scheduler, _name)
    __all__.append(_name)