Source code for pyro.optim.pytorch_optimizers

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

from pyro.optim import PyroOptim
from pyro.optim.lr_scheduler import PyroLRScheduler

__all__ = []
# Programmatically load all optimizers from PyTorch.
for _name, _Optim in torch.optim.__dict__.items():
    # Keep only concrete subclasses of torch.optim.Optimizer.
    if not isinstance(_Optim, type):
        continue
    if not issubclass(_Optim, torch.optim.Optimizer):
        continue
    if _Optim is torch.optim.Optimizer:
        continue
    if _Optim is torch.optim.LBFGS:
        # XXX LBFGS is not supported for SVI yet, since its step()
        # requires a loss closure.
        continue

    # The outer lambda is invoked immediately so that each wrapper closes
    # over its own optimizer class rather than the loop variable.
    _PyroOptim = (
        lambda _Optim: lambda optim_args, clip_args=None: PyroOptim(
            _Optim, optim_args, clip_args
        )
    )(_Optim)
    _PyroOptim.__name__ = _name
    _PyroOptim.__doc__ = "Wraps :class:`torch.optim.{}` with :class:`~pyro.optim.optim.PyroOptim`.".format(
        _name
    )

    # At module scope locals() is the module namespace, so this defines
    # the wrapper as a module-level attribute named after the optimizer.
    locals()[_name] = _PyroOptim
    __all__.append(_name)
    del _PyroOptim
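
# Example usage (a minimal sketch, not part of the Pyro source): each
# generated wrapper, e.g. ``Adam``, takes a dict of optimizer arguments
# and an optional ``clip_args`` dict, and is passed to SVI in place of a
# raw torch optimizer; ``model``, ``guide``, and ``data`` are assumed to
# be defined elsewhere:
#
#     from pyro.infer import SVI, Trace_ELBO
#     from pyro.optim import Adam
#
#     adam = Adam({"lr": 1e-3}, clip_args={"clip_norm": 10.0})
#     svi = SVI(model, guide, adam, loss=Trace_ELBO())
#     loss = svi.step(data)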

# Programmatically load all learning-rate schedulers from PyTorch.
for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
    if not isinstance(_Optim, type):
        continue
    # ReduceLROnPlateau is special-cased because it does not inherit
    # from _LRScheduler.
    if (
        not issubclass(_Optim, torch.optim.lr_scheduler._LRScheduler)
        and _name != "ReduceLROnPlateau"
    ):
        continue
    if _Optim is torch.optim.Optimizer:
        continue

    # Same immediately-invoked lambda trick as above, binding the current
    # scheduler class per iteration.
    _PyroOptim = (
        lambda _Optim: lambda optim_args, clip_args=None: PyroLRScheduler(
            _Optim, optim_args, clip_args
        )
    )(_Optim)
    _PyroOptim.__name__ = _name
    _PyroOptim.__doc__ = (
        "Wraps :class:`torch.optim.lr_scheduler.{}` with ".format(_name)
        + ":class:`~pyro.optim.lr_scheduler.PyroLRScheduler`."
    )

    locals()[_name] = _PyroOptim
    __all__.append(_name)
    del _PyroOptim
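
# Example usage (a minimal sketch, not part of the Pyro source): scheduler
# wrappers take the underlying torch optimizer class and its arguments
# under the ``optimizer`` and ``optim_args`` keys, plus any scheduler
# kwargs; ``model``, ``guide``, and ``batch`` are assumed to be defined
# elsewhere:
#
#     import torch
#     from pyro.infer import SVI, Trace_ELBO
#     from pyro.optim import ExponentialLR
#
#     scheduler = ExponentialLR(
#         {"optimizer": torch.optim.SGD, "optim_args": {"lr": 0.01}, "gamma": 0.1}
#     )
#     svi = SVI(model, guide, scheduler, loss=Trace_ELBO())
#     for epoch in range(10):
#         svi.step(batch)
#         scheduler.step()  # decay the learning rate once per epoch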