Source code for pyro.optim.horovod

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Union, ValuesView

from torch.optim import Optimizer

import pyro

from .optim import PyroOptim

class HorovodOptimizer(PyroOptim):
    r"""
    Distributed wrapper for a :class:`~pyro.optim.optim.PyroOptim` optimizer.

    This class wraps a ``PyroOptim`` object similar to the way
    :func:`horovod.torch.DistributedOptimizer` wraps a
    :class:`torch.optim.Optimizer`. The wrapped optimizer's constructor,
    optimizer args and clip args are reused. Each PyTorch optimizer built by
    that constructor is additionally wrapped in a
    :func:`horovod.torch.DistributedOptimizer`.

    .. note::

        This requires :mod:`horovod.torch` to be installed, e.g. via
        ``pip install pyro[horovod]``. For details see the Horovod
        installation documentation.

    :param pyro_optim: A Pyro optimizer instance.
    :type pyro_optim: ~pyro.optim.optim.PyroOptim
    :param \*\*horovod_kwargs: Extra parameters passed to
        :func:`horovod.torch.DistributedOptimizer`.
    """

    def __init__(self, pyro_optim: PyroOptim, **horovod_kwargs):
        # Resolve the param-store name lookup once. The names are handed to
        # Horovod as ``named_parameters``, presumably so that parameters are
        # labelled consistently across workers.
        param_name = pyro.get_param_store().param_name

        def optim_constructor(params, **pt_kwargs) -> Optimizer:
            # Imported lazily so that importing this module does not require
            # horovod to be installed.
            import horovod.torch as hvd  # type: ignore

            # Materialize once. ``params`` is consumed twice below: by the
            # wrapped constructor and when building ``named_parameters``.
            # A one-shot iterator would otherwise leave the second pass empty.
            params = list(params)
            pt_optim = pyro_optim.pt_optim_constructor(params, **pt_kwargs)  # type: ignore
            named_parameters = [(param_name(p), p) for p in params]
            hvd_optim = hvd.DistributedOptimizer(
                pt_optim,
                named_parameters=named_parameters,
                **horovod_kwargs,
            )
            return hvd_optim  # type: ignore

        super().__init__(
            optim_constructor, pyro_optim.pt_optim_args, pyro_optim.pt_clip_args
        )

    def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
        """
        Forward ``params`` to :class:`~pyro.optim.optim.PyroOptim`, sorted by
        param-store name.

        :param params: Parameters to process. They are sorted by the name the
            param store assigns them before being passed on.
        :param args: Forwarded unchanged to the base class ``__call__``.
        :param kwargs: Forwarded unchanged to the base class ``__call__``.
        """
        # Sort by name to ensure deterministic processing order.
        # NOTE(review): presumably every worker must visit parameters in the
        # same order so Horovod's collective operations line up; confirm.
        params = sorted(params, key=pyro.get_param_store().param_name)
        super().__call__(params, *args, **kwargs)