Optimization¶
The module pyro.optim provides support for optimization in Pyro. In particular it provides PyroOptim, which wraps PyTorch optimizers and manages optimizers for dynamically generated parameters (see the tutorial SVI Part I for a discussion). Pyro's custom optimization algorithms are also found here.
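For orientation, here is a minimal sketch of how a wrapped optimizer is typically used with SVI; the toy model, guide, and data below are illustrative only:

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        loc_q = pyro.param("loc_q", torch.tensor(0.0))
        scale_q = pyro.param("scale_q", torch.tensor(1.0),
                             constraint=dist.constraints.positive)
        pyro.sample("loc", dist.Normal(loc_q, scale_q))

    # optim_args is a plain dict here, so every dynamically created Adam
    # instance uses the same settings.
    optimizer = Adam({"lr": 0.01})
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    data = torch.randn(100) + 3.0
    for step in range(1000):
        svi.step(data)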
Pyro Optimizers¶
- is_scheduler(optimizer) bool [source]¶
Helper method to determine whether a PyTorch object is either a PyTorch optimizer (returns False) or an optimizer wrapped in an LRScheduler, e.g. a ReduceLROnPlateau or a subclass of _LRScheduler (returns True).
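A small illustration of the distinction, assuming is_scheduler can be imported from pyro.optim.optim where it is defined:

    import torch
    from pyro.optim.optim import is_scheduler

    params = [torch.nn.Parameter(torch.zeros(2))]
    adam = torch.optim.Adam(params, lr=0.1)
    step_lr = torch.optim.lr_scheduler.StepLR(adam, step_size=10)

    assert not is_scheduler(adam)  # a bare optimizer
    assert is_scheduler(step_lr)   # an optimizer wrapped in an LRScheduler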
- class PyroOptim(optim_constructor: Union[Callable, torch.optim.optimizer.Optimizer, Type[torch.optim.optimizer.Optimizer]], optim_args: Union[Dict, Callable[[...], Dict]], clip_args: Optional[Union[Dict, Callable[[...], Dict]]] = None)[source]¶
Bases:
object
A wrapper for torch.optim.Optimizer objects that helps with managing dynamically generated parameters.
- Parameters
optim_constructor – a torch.optim.Optimizer
optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries
clip_args – a dictionary of clip_norm and/or clip_value args or a callable that returns such dictionaries
- __call__(params: Union[List, ValuesView], *args, **kwargs) None [source]¶
- Parameters
params (an iterable of strings) – a list of parameters
Do an optimization step for each param in params. If a given param has never been seen before, initialize an optimizer for it.
- get_state() Dict [source]¶
Get state associated with all the optimizers in the form of a dictionary with key-value pairs (parameter name, optim state dicts)
- set_state(state_dict: Dict) None [source]¶
Set the state associated with all the optimizers using the state obtained from a previous call to get_state()
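For example, a PyroOptim can be built directly from a PyTorch optimizer class, and its per-parameter optimizer states can be checkpointed with get_state() and restored with set_state(); the filename below is arbitrary:

    import torch
    from pyro.optim import PyroOptim

    # optim_args configures each per-parameter Adam instance; clip_args
    # optionally enables gradient clipping.
    optimizer = PyroOptim(torch.optim.Adam, {"lr": 1e-3},
                          clip_args={"clip_norm": 10.0})

    # ... run SVI with this optimizer ...

    # Checkpoint the optimizer state, then restore it later.
    torch.save(optimizer.get_state(), "optim_state.pt")
    optimizer.set_state(torch.load("optim_state.pt"))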
- AdagradRMSProp(optim_args: Dict) pyro.optim.optim.PyroOptim [source]¶
Wraps pyro.optim.adagrad_rmsprop.AdagradRMSProp with PyroOptim.
- ClippedAdam(optim_args: Dict) pyro.optim.optim.PyroOptim [source]¶
Wraps pyro.optim.clipped_adam.ClippedAdam with PyroOptim.
- DCTAdam(optim_args: Dict) pyro.optim.optim.PyroOptim [source]¶
Wraps pyro.optim.dct_adam.DCTAdam with PyroOptim.
- class PyroLRScheduler(scheduler_constructor, optim_args: Dict, clip_args: Optional[Dict] = None)[source]¶
Bases:
pyro.optim.optim.PyroOptim
A wrapper for lr_scheduler objects that adjusts learning rates for dynamically generated parameters.
- Parameters
scheduler_constructor – an lr_scheduler
optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries. Must contain the key 'optimizer' with a PyTorch optimizer class as its value.
clip_args – a dictionary of clip_norm and/or clip_value args or a callable that returns such dictionaries.
Example:
    optimizer = torch.optim.SGD
    scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for i in range(epochs):
        for minibatch in DataLoader(dataset, batch_size):
            svi.step(minibatch)
        scheduler.step()
- class AdagradRMSProp(params, eta: float = 1.0, delta: float = 1e-16, t: float = 0.1)[source]¶
Bases:
torch.optim.optimizer.Optimizer
Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise update equation see equations 10 and 11 in reference [1].
References:
[1] 'Automatic Differentiation Variational Inference', Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei. URL: https://arxiv.org/abs/1603.00788
[2] 'Lecture 6.5 RmsProp: Divide the gradient by a running average of its recent magnitude', Tieleman, T. and Hinton, G., COURSERA: Neural Networks for Machine Learning.
[3] 'Adaptive subgradient methods for online learning and stochastic optimization', Duchi, John, Hazan, E. and Singer, Y.
- Parameters
params – iterable of parameters to optimize or dicts defining parameter groups
eta (float) – sets the step size scale (optional; default: 1.0)
t (float) – momentum parameter (optional; default: 0.1)
delta (float) – modulates the exponent that controls how the step size scales (optional; default: 1e-16)
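In practice this optimizer is usually constructed through the pyro.optim.AdagradRMSProp factory listed above, which forwards these arguments to each per-parameter instance; the values below simply spell out the defaults:

    from pyro.optim import AdagradRMSProp

    # eta scales the step size, t is the momentum parameter, and delta
    # modulates the step-size exponent, as documented above.
    optimizer = AdagradRMSProp({"eta": 1.0, "t": 0.1, "delta": 1e-16})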
- class ClippedAdam(params, lr: float = 0.001, betas: Tuple = (0.9, 0.999), eps: float = 1e-08, weight_decay=0, clip_norm: float = 10.0, lrd: float = 1.0)[source]¶
Bases:
torch.optim.optimizer.Optimizer
- Parameters
params – iterable of parameters to optimize or dicts defining parameter groups
lr – learning rate (default: 1e-3)
betas (Tuple) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay – weight decay (L2 penalty) (default: 0)
clip_norm – magnitude of norm to which gradients are clipped (default: 10.0)
lrd – rate at which learning rate decays (default: 1.0)
Small modification to the Adam algorithm implemented in torch.optim.Adam to include gradient clipping and learning rate decay.
Reference
Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba. https://arxiv.org/abs/1412.6980
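Since lrd multiplies the learning rate at every step, a common pattern is to derive it from a target total decay over a planned number of steps. A sketch, with the step count and target decay chosen arbitrarily for illustration:

    from pyro.optim import ClippedAdam

    num_steps = 10000
    initial_lr = 1e-3
    final_lr = 1e-4
    # Per-step multiplicative decay so that lr falls from initial_lr to
    # final_lr over num_steps steps: lr_t = initial_lr * lrd ** t.
    lrd = (final_lr / initial_lr) ** (1 / num_steps)

    optimizer = ClippedAdam({"lr": initial_lr, "lrd": lrd, "clip_norm": 10.0})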
- class HorovodOptimizer(pyro_optim: pyro.optim.optim.PyroOptim, **horovod_kwargs)[source]¶
Bases:
pyro.optim.optim.PyroOptim
Distributed wrapper for a PyroOptim optimizer.
This class wraps a PyroOptim object similar to the way horovod.torch.DistributedOptimizer() wraps a torch.optim.Optimizer.
Note
This requires horovod.torch to be installed, e.g. via pip install pyro[horovod]. For details see https://horovod.readthedocs.io/en/stable/install.html
- Parameters
pyro_optim – a Pyro optimizer instance
**horovod_kwargs – extra parameters passed to horovod.torch.DistributedOptimizer()
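A hedged sketch of typical usage, assuming horovod.torch is installed and initialized before constructing the wrapper; the Adam settings are arbitrary:

    import horovod.torch as hvd
    import pyro.optim

    hvd.init()  # initialize Horovod before building the distributed optimizer

    base_optim = pyro.optim.Adam({"lr": 1e-3})
    optimizer = pyro.optim.HorovodOptimizer(base_optim)
    # Use `optimizer` with SVI as usual; gradients are reduced across workers.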
PyTorch Optimizers¶
- Adadelta(optim_args, clip_args=None)¶
Wraps torch.optim.Adadelta with PyroOptim.
- Adagrad(optim_args, clip_args=None)¶
Wraps torch.optim.Adagrad with PyroOptim.
- Adam(optim_args, clip_args=None)¶
Wraps torch.optim.Adam with PyroOptim.
- AdamW(optim_args, clip_args=None)¶
Wraps torch.optim.AdamW with PyroOptim.
- SparseAdam(optim_args, clip_args=None)¶
Wraps torch.optim.SparseAdam with PyroOptim.
- Adamax(optim_args, clip_args=None)¶
Wraps torch.optim.Adamax with PyroOptim.
- ASGD(optim_args, clip_args=None)¶
Wraps torch.optim.ASGD with PyroOptim.
- SGD(optim_args, clip_args=None)¶
Wraps torch.optim.SGD with PyroOptim.
- RAdam(optim_args, clip_args=None)¶
Wraps torch.optim.RAdam with PyroOptim.
- Rprop(optim_args, clip_args=None)¶
Wraps torch.optim.Rprop with PyroOptim.
- RMSprop(optim_args, clip_args=None)¶
Wraps torch.optim.RMSprop with PyroOptim.
- NAdam(optim_args, clip_args=None)¶
Wraps torch.optim.NAdam with PyroOptim.
- LRScheduler(optim_args, clip_args=None)¶
Wraps torch.optim.LRScheduler with PyroLRScheduler.
- LambdaLR(optim_args, clip_args=None)¶
Wraps torch.optim.LambdaLR with PyroLRScheduler.
- MultiplicativeLR(optim_args, clip_args=None)¶
Wraps torch.optim.MultiplicativeLR with PyroLRScheduler.
- StepLR(optim_args, clip_args=None)¶
Wraps torch.optim.StepLR with PyroLRScheduler.
- MultiStepLR(optim_args, clip_args=None)¶
Wraps torch.optim.MultiStepLR with PyroLRScheduler.
- ConstantLR(optim_args, clip_args=None)¶
Wraps torch.optim.ConstantLR with PyroLRScheduler.
- LinearLR(optim_args, clip_args=None)¶
Wraps torch.optim.LinearLR with PyroLRScheduler.
- ExponentialLR(optim_args, clip_args=None)¶
Wraps torch.optim.ExponentialLR with PyroLRScheduler.
- SequentialLR(optim_args, clip_args=None)¶
Wraps torch.optim.SequentialLR with PyroLRScheduler.
- PolynomialLR(optim_args, clip_args=None)¶
Wraps torch.optim.PolynomialLR with PyroLRScheduler.
- CosineAnnealingLR(optim_args, clip_args=None)¶
Wraps torch.optim.CosineAnnealingLR with PyroLRScheduler.
- ChainedScheduler(optim_args, clip_args=None)¶
Wraps torch.optim.ChainedScheduler with PyroLRScheduler.
- ReduceLROnPlateau(optim_args, clip_args=None)¶
Wraps torch.optim.ReduceLROnPlateau with PyroLRScheduler.
- CyclicLR(optim_args, clip_args=None)¶
Wraps torch.optim.CyclicLR with PyroLRScheduler.
- CosineAnnealingWarmRestarts(optim_args, clip_args=None)¶
Wraps torch.optim.CosineAnnealingWarmRestarts with PyroLRScheduler.
- OneCycleLR(optim_args, clip_args=None)¶
Wraps torch.optim.OneCycleLR with PyroLRScheduler.
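All of these factories accept the same optim_args and clip_args as PyroOptim, so optim_args may also be a callable returning per-parameter settings, as discussed in the SVI Part I tutorial. A sketch, assuming the callable receives the parameter name; the names and rates are illustrative:

    from pyro.optim import Adam

    def per_param_args(param_name):
        # Hypothetical rule: use a larger learning rate for the guide's scale.
        if param_name == "scale_q":
            return {"lr": 0.01}
        return {"lr": 0.001}

    optimizer = Adam(per_param_args, clip_args={"clip_norm": 10.0})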
Higher-Order Optimizers¶
- class MultiOptimizer[source]¶
Bases:
object
Base class of optimizers that make use of higher-order derivatives.
Higher-order optimizers generally use torch.autograd.grad() rather than torch.Tensor.backward(), and therefore require a different interface from usual Pyro and PyTorch optimizers. In this interface, the step() method inputs a loss tensor to be differentiated, and backpropagation is triggered one or more times inside the optimizer.
Derived classes must implement step() to compute derivatives and update parameters in-place.
Example:
    tr = poutine.trace(model).get_trace(*args, **kwargs)
    loss = -tr.log_prob_sum()
    params = {name: site['value'].unconstrained()
              for name, site in tr.nodes.items()
              if site['type'] == 'param'}
    optim.step(loss, params)
- step(loss: torch.Tensor, params: Dict) None [source]¶
Performs an in-place optimization step on parameters given a differentiable loss tensor.
Note that this detaches the updated tensors.
- Parameters
loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.
params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.
- get_step(loss: torch.Tensor, params: Dict) Dict [source]¶
Computes an optimization step of parameters given a differentiable loss tensor, returning the updated values.
Note that this preserves derivatives on the updated tensors.
- Parameters
loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.
params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.
- Returns
A dictionary mapping param name to updated unconstrained value.
- Return type
dict
- class PyroMultiOptimizer(optim: pyro.optim.optim.PyroOptim)[source]¶
Bases:
pyro.optim.multi.MultiOptimizer
Facade to wrap PyroOptim objects in a MultiOptimizer interface.
- step(loss: torch.Tensor, params: Dict) None [source]¶
- class TorchMultiOptimizer(optim_constructor: torch.optim.optimizer.Optimizer, optim_args: Dict)[source]¶
Bases:
pyro.optim.multi.PyroMultiOptimizer
Facade to wrap Optimizer objects in a MultiOptimizer interface.
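For instance, an ordinary Adam update can be driven through the MultiOptimizer interface by either facade; the learning rate is arbitrary:

    import torch
    import pyro.optim
    from pyro.optim.multi import PyroMultiOptimizer, TorchMultiOptimizer

    # Two equivalent ways to obtain a MultiOptimizer backed by Adam.
    multi_optim_1 = PyroMultiOptimizer(pyro.optim.Adam({"lr": 0.01}))
    multi_optim_2 = TorchMultiOptimizer(torch.optim.Adam, {"lr": 0.01})

    # Either is then used as in the MultiOptimizer example above:
    #     multi_optim_1.step(loss, params)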
- class MixedMultiOptimizer(parts: List)[source]¶
Bases:
pyro.optim.multi.MultiOptimizer
Container class to combine different MultiOptimizer instances for different parameters.
- Parameters
parts (list) – A list of (names, optim) pairs, where each names is a list of parameter names and each optim is a MultiOptimizer or PyroOptim object to be used for the named parameters. Together the names should partition all desired parameters to optimize.
- Raises
ValueError – if any name is optimized by multiple optimizers.
- step(loss: torch.Tensor, params: Dict)[source]¶
- get_step(loss: torch.Tensor, params: Dict) Dict [source]¶
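A sketch of combining optimizers, assuming parameters named 'loc' and 'scale' should receive Newton updates while hypothetical parameters 'weight' and 'bias' use Adam:

    import pyro.optim
    from pyro.optim.multi import MixedMultiOptimizer, Newton

    # The name lists must together cover every parameter being optimized.
    optim = MixedMultiOptimizer([
        (["loc", "scale"], Newton(trust_radii={"loc": 1.0, "scale": 1.0})),
        (["weight", "bias"], pyro.optim.Adam({"lr": 0.01})),
    ])

    # Used as in the MultiOptimizer example above:
    #     optim.step(loss, params)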
- class Newton(trust_radii: Dict = {})[source]¶
Bases:
pyro.optim.multi.MultiOptimizer
Implementation of MultiOptimizer that performs a Newton update on batched low-dimensional variables, optionally regularizing via a per-parameter trust_radius. See newton_step() for details.
The result of get_step() will be differentiable, however the updated values from step() will be detached.
- Parameters
trust_radii (dict) – a dict mapping parameter name to radius of trust region. Missing names will use unregularized Newton update, equivalent to infinite trust radius.
- get_step(loss: torch.Tensor, params: Dict)[source]¶
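Because get_step() preserves derivatives, the returned values can themselves be differentiated. A minimal sketch on a toy quadratic loss, assuming a batched 1-dimensional parameter is acceptable to newton_step(); the name and trust radius are illustrative:

    import torch
    from pyro.optim.multi import Newton

    loc = torch.tensor([0.0], requires_grad=True)
    loss = ((loc - 3.0) ** 2).sum()  # quadratic, so one Newton step should land near 3.0

    optim = Newton(trust_radii={"loc": 10.0})
    updated = optim.get_step(loss, {"loc": loc})
    # updated["loc"] is close to 3.0 and remains differentiable with respect to loc.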