Optimization

The module pyro.optim provides support for optimization in Pyro. In particular, it provides PyroOptim, which is used to wrap PyTorch optimizers and manage optimizers for dynamically generated parameters (see the tutorial SVI Part I for a discussion). Custom optimization algorithms are also found here.

Pyro Optimizers

is_scheduler(optimizer) → bool

Helper method to determine whether a PyTorch object is a plain PyTorch optimizer (returns False) or an optimizer wrapped in an LRScheduler, e.g. a ReduceLROnPlateau or a subclass of _LRScheduler (returns True).

class PyroOptim(optim_constructor: Union[Callable, torch.optim.optimizer.Optimizer, Type[torch.optim.optimizer.Optimizer]], optim_args: Union[Dict, Callable[[...], Dict]], clip_args: Optional[Union[Dict, Callable[[...], Dict]]] = None)

Bases: object

A wrapper for torch.optim.Optimizer objects that helps with managing dynamically generated parameters.

Parameters
  • optim_constructor – a torch.optim.Optimizer class (or another callable that constructs an optimizer)

  • optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries

  • clip_args – a dictionary of clip_norm and/or clip_value args or a callable that returns such dictionaries
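The two clip_args keys act differently. The pure-Python sketch below assumes they follow the semantics of torch.nn.utils.clip_grad_norm_ (rescale all gradients together so their total L2 norm is at most clip_norm) and torch.nn.utils.clip_grad_value_ (clamp each element into [-clip_value, clip_value]). The function names are hypothetical, and gradients are plain lists of floats rather than tensors.

```python
import math


def clip_by_norm(grads, clip_norm):
    """Rescale the whole gradient vector so its L2 norm is at most clip_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_norm:
        return list(grads)
    scale = clip_norm / total_norm
    return [g * scale for g in grads]


def clip_by_value(grads, clip_value):
    """Clamp every gradient element into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


grads = [3.0, 4.0]                # L2 norm 5.0
print(clip_by_norm(grads, 1.0))   # same direction, norm shrunk to about 1.0
print(clip_by_value(grads, 1.0))  # [1.0, 1.0]
```

Norm clipping preserves the direction of the gradient, while value clipping can change it (here [3.0, 4.0] becomes [1.0, 1.0]).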

__call__(params: Union[List, ValuesView], *args, **kwargs) → None
Parameters

params (an iterable of torch.Tensor) – a list of (unconstrained) parameter tensors

Do an optimization step for each param in params. If a given param has never been seen before, initialize an optimizer for it.
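The lazy, per-parameter behavior of __call__ can be illustrated with a pure-Python sketch. LazyPerParamSGD is a hypothetical stand-in, not Pyro's implementation: parameters are keyed by name, plain SGD settings stand in for a wrapped torch.optim.Optimizer, and a callable optim_args is assumed to receive the parameter name (similar to the per-parameter callable pattern in the SVI Part I tutorial).

```python
class LazyPerParamSGD:
    """Hypothetical sketch: one SGD "optimizer" per parameter, created on first use."""

    def __init__(self, optim_args):
        # optim_args: a dict, or a callable mapping a parameter name to a dict
        self.optim_args = optim_args
        self.optim_objs = {}  # parameter name -> that parameter's optimizer settings

    def _get_optim_args(self, name):
        if callable(self.optim_args):
            return self.optim_args(name)
        return dict(self.optim_args)

    def __call__(self, params, grads):
        # params / grads: dicts mapping parameter name -> float value / gradient
        for name in params:
            if name not in self.optim_objs:
                # never seen before: initialize an optimizer for this parameter
                self.optim_objs[name] = self._get_optim_args(name)
            params[name] -= self.optim_objs[name]["lr"] * grads[name]


def per_param_lr(name):
    # larger step for "loc", smaller step for everything else
    return {"lr": 0.1} if name == "loc" else {"lr": 0.01}


optim = LazyPerParamSGD(per_param_lr)
params = {"loc": 1.0}
optim(params, {"loc": 1.0})                # creates an optimizer for "loc" only
params["scale"] = 2.0                      # a parameter that appears later
optim(params, {"loc": 1.0, "scale": 1.0})  # lazily creates one for "scale"
print(sorted(optim.optim_objs))            # ['loc', 'scale']
```

Creating optimizers on demand means parameters that first appear mid-training (for example, inside a guide) are picked up without rebuilding the optimizer.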

get_state() → Dict

Get the state associated with all the optimizers, as a dictionary mapping each parameter name to the corresponding optimizer state dict.

set_state(state_dict: Dict) → None

Set the state associated with all the optimizers using the state obtained from a previous call to get_state().

save(filename: str) → None
Parameters

filename (str) – file name to save to

Save optimizer state to disk.

load(filename: str, map_location=None) → None
Parameters
  • filename (str) – file name to load from

  • map_location (function, torch.device, string or a dict) – torch.load() map_location parameter

Load optimizer state from disk.

AdagradRMSProp(optim_args: Dict) → pyro.optim.optim.PyroOptim

Wraps pyro.optim.adagrad_rmsprop.AdagradRMSProp with PyroOptim.

ClippedAdam(optim_args: Dict) → pyro.optim.optim.PyroOptim

Wraps pyro.optim.clipped_adam.ClippedAdam with PyroOptim.

DCTAdam(optim_args: Dict) → pyro.optim.optim.PyroOptim

Wraps pyro.optim.dct_adam.DCTAdam with PyroOptim.

class PyroLRScheduler(scheduler_constructor, optim_args: Dict, clip_args: Optional[Dict] = None)

Bases: pyro.optim.optim.PyroOptim

A wrapper for lr_scheduler objects that adjusts learning rates for dynamically generated parameters.

Parameters
  • scheduler_constructor – an lr_scheduler class from torch.optim.lr_scheduler

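  • optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries. It must contain the key ‘optimizer’, whose value is a PyTorch optimizer class, and the key ‘optim_args’, holding that optimizer's arguments; the remaining entries (e.g. ‘gamma’) are passed to the scheduler.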

  • clip_args – a dictionary of clip_norm and/or clip_value args or a callable that returns such dictionaries.

Example:

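import torch
import pyro
from pyro.infer import SVI, TraceGraph_ELBO
from torch.utils.data import DataLoader

# model, guide, dataset, epochs and batch_size are assumed to be defined elsewhere
optimizer = torch.optim.SGD
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
for i in range(epochs):
    for minibatch in DataLoader(dataset, batch_size):
        svi.step(minibatch)
    scheduler.step()  # decay the learning rate once per epoch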
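With ExponentialLR, each call to scheduler.step() multiplies the learning rate by gamma (the standard torch.optim.lr_scheduler.ExponentialLR behavior), so with the settings in the example the rate is 0.01 · 0.1^n during epoch n. The schedule can be tabulated in plain Python; exponential_lr_schedule is a hypothetical helper, not part of Pyro or PyTorch.

```python
def exponential_lr_schedule(initial_lr, gamma, num_epochs):
    """Learning rate in effect during each epoch when the scheduler steps once per epoch."""
    lrs = []
    lr = initial_lr
    for _ in range(num_epochs):
        lrs.append(lr)  # rate used by every svi.step() call in this epoch
        lr *= gamma     # the end-of-epoch scheduler.step() decays it by gamma
    return lrs


# Settings from the example above: lr = 0.01, gamma = 0.1
for epoch, lr in enumerate(exponential_lr_schedule(0.01, 0.1, 4)):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

A gamma of 0.1 shrinks the rate tenfold per epoch; values closer to 1.0 give a gentler decay.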

__call__(params: Union[List, ValuesView], *args, **kwargs) → None

step(*args, **kwargs) → None

Takes the same arguments as the underlying PyTorch scheduler's step() (e.g. the monitored loss for ReduceLROnPlateau).

class AdagradRMSProp(params, eta: float = 1.0, delta: float = 1e-16, t: float = 0.1)

Bases: torch.optim.optimizer.Optimizer

Implements a mash-up of the Adagrad algorithm and RMSProp. For the precise update equation see equations 10 and 11 in reference [1].

References:
[1] ‘Automatic Differentiation Variational Inference’, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei. URL: https://arxiv.org/abs/1603.00788
[2] ‘Lecture 6.5 RmsProp: Divide the gradient by a running average of its recent magnitude’, T. Tieleman and G. Hinton. COURSERA: Neural Networks for Machine Learning.
[3] ‘Adaptive subgradient methods for online learning and stochastic optimization’, J. Duchi, E. Hazan and Y. Singer.

Parameters
  • params – iterable of parameters to optimize or dicts defining parameter groups

  • eta (float) – sets the step size scale (optional; default: 1.0)

  • t (float) – momentum parameter (optional; default: 0.1)

  • delta (float) – modulates the exponent that controls how the step size scales (optional; default: 1e-16)
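The update in equations 10 and 11 of [1] can be sketched for a single scalar parameter in pure Python. This is an illustrative sketch, not Pyro's source: it assumes that eta, delta and t play the roles of η, ε and α in [1], and that the paper's offset τ is fixed to 1. The function name adagrad_rmsprop_step is hypothetical.

```python
import math


def adagrad_rmsprop_step(param, grad, state, eta=1.0, delta=1e-16, t=0.1):
    """One update of a scalar parameter following eqs. 10-11 of [1].

    Assumed mapping: eta -> eta, delta -> epsilon, t -> alpha in [1]; tau assumed to be 1.
    """
    state["step"] = state.get("step", 0) + 1
    k = state["step"]
    if k == 1:
        # first step: initialize the running average with the squared gradient
        state["s"] = grad * grad
    else:
        # RMSProp-style exponentially weighted average of squared gradients
        state["s"] = t * grad * grad + (1.0 - t) * state["s"]
    # decaying step-size sequence eta * k^(-1/2 + delta), scaled by 1 / (1 + sqrt(s))
    rho = eta * k ** (-0.5 + delta) / (1.0 + math.sqrt(state["s"]))
    return param - rho * grad


# A few steps on loss(x) = (x - 1)^2, whose gradient is 2 * (x - 1)
state, x = {}, 0.0
for _ in range(3):
    x = adagrad_rmsprop_step(x, 2.0 * (x - 1.0), state)
print(state["step"], x)
```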

share_memory() → None

step(closure: Optional[Callable] = None) → Optional[Any]

Performs a single optimization step.

Parameters

closure – An optional closure that reevaluates the model and returns the loss.

class ClippedAdam(params, lr: float = 0.001, betas: Tuple = (0.9, 0.999), eps: float = 1e-08, weight_decay=0, clip_norm: float = 10.0, lrd: float = 1.0)

Bases: torch.optim.optimizer.Optimizer

Parameters
  • params – iterable of parameters to optimize or dicts defining parameter groups

  • lr – learning rate (default: 1e-3)

  • betas (Tuple) – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))

  • eps – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay – weight decay (L2 penalty) (default: 0)

  • clip_norm – magnitude of norm to which gradients are clipped (default: 10.0)

  • lrd – rate at which learning rate decays (default: 1.0)

A small modification of the Adam algorithm implemented in torch.optim.Adam that adds gradient clipping and learning rate decay.

Reference

Adam: A Method for Stochastic Optimization, Diederik P. Kingma, Jimmy Ba. https://arxiv.org/abs/1412.6980
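For a single scalar parameter, the two additions can be sketched on top of the textbook bias-corrected Adam update from the reference above. This is an illustrative sketch, not Pyro's source: it assumes the gradient is clipped before the moment updates and that lrd multiplies the learning rate once per step, and it omits weight_decay. The name clipped_adam_step is hypothetical.

```python
import math


def clipped_adam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, clip_norm=10.0, lrd=1.0):
    """One Adam update of a scalar parameter with gradient clipping and lr decay.

    Assumptions (illustrative, not Pyro's source): the learning rate is multiplied
    by lrd once per step, and the gradient is clipped before the moment updates.
    """
    beta1, beta2 = betas
    state["lr"] = state.get("lr", lr) * lrd
    state["step"] = state.get("step", 0) + 1
    # for a scalar, clipping to norm clip_norm is the same as clamping
    grad = max(-clip_norm, min(clip_norm, grad))
    # running averages of the gradient and its square
    state["m"] = beta1 * state.get("m", 0.0) + (1.0 - beta1) * grad
    state["v"] = beta2 * state.get("v", 0.0) + (1.0 - beta2) * grad * grad
    # bias-corrected moments, as in Kingma & Ba
    m_hat = state["m"] / (1.0 - beta1 ** state["step"])
    v_hat = state["v"] / (1.0 - beta2 ** state["step"])
    return param - state["lr"] * m_hat / (math.sqrt(v_hat) + eps)


state = {}
x = clipped_adam_step(0.0, 100.0, state, lr=0.1, clip_norm=10.0, lrd=0.5)
print(state["lr"], x)  # learning rate already decayed once
```

Under these assumptions the effective learning rate after k steps is lr · lrd^k, so lrd = 1.0 (the default) disables the decay.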

step(closure: Optional[Callable] = None) → Optional[Any]
Parameters

closure – An optional closure that reevaluates the model and returns the loss.

Performs a single optimization step.

class HorovodOptimizer(pyro_optim: pyro.optim.optim.PyroOptim, **horovod_kwargs)

Bases: pyro.optim.optim.PyroOptim

Distributed wrapper for a PyroOptim optimizer.

This class wraps a PyroOptim object similar to the way horovod.torch.DistributedOptimizer() wraps a torch.optim.Optimizer.

Note

This requires horovod.torch to be installed, e.g. via pip install pyro-ppl[horovod]. For details see https://horovod.readthedocs.io/en/stable/install.html

Parameters
  • pyro_optim – A Pyro optimizer instance.

  • **horovod_kwargs – Extra parameters passed to horovod.torch.DistributedOptimizer().

__call__(params: Union[List, ValuesView], *args, **kwargs) → None

PyTorch Optimizers

Adadelta(optim_args, clip_args=None)

Wraps torch.optim.Adadelta with PyroOptim.

Adagrad(optim_args, clip_args=None)

Wraps torch.optim.Adagrad with PyroOptim.

Adam(optim_args, clip_args=None)

Wraps torch.optim.Adam with PyroOptim.

AdamW(optim_args, clip_args=None)

Wraps torch.optim.AdamW with PyroOptim.

SparseAdam(optim_args, clip_args=None)

Wraps torch.optim.SparseAdam with PyroOptim.

Adamax(optim_args, clip_args=None)

Wraps torch.optim.Adamax with PyroOptim.

ASGD(optim_args, clip_args=None)

Wraps torch.optim.ASGD with PyroOptim.

SGD(optim_args, clip_args=None)

Wraps torch.optim.SGD with PyroOptim.

RAdam(optim_args, clip_args=None)

Wraps torch.optim.RAdam with PyroOptim.

Rprop(optim_args, clip_args=None)

Wraps torch.optim.Rprop with PyroOptim.

RMSprop(optim_args, clip_args=None)

Wraps torch.optim.RMSprop with PyroOptim.

NAdam(optim_args, clip_args=None)

Wraps torch.optim.NAdam with PyroOptim.

LRScheduler(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.LRScheduler with PyroLRScheduler.

LambdaLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.LambdaLR with PyroLRScheduler.

MultiplicativeLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.MultiplicativeLR with PyroLRScheduler.

StepLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.StepLR with PyroLRScheduler.

MultiStepLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.MultiStepLR with PyroLRScheduler.

ConstantLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.ConstantLR with PyroLRScheduler.

LinearLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.LinearLR with PyroLRScheduler.

ExponentialLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.ExponentialLR with PyroLRScheduler.

SequentialLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.SequentialLR with PyroLRScheduler.

PolynomialLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.PolynomialLR with PyroLRScheduler.

CosineAnnealingLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.CosineAnnealingLR with PyroLRScheduler.

ChainedScheduler(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.ChainedScheduler with PyroLRScheduler.

ReduceLROnPlateau(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.ReduceLROnPlateau with PyroLRScheduler.

CyclicLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.CyclicLR with PyroLRScheduler.

CosineAnnealingWarmRestarts(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.CosineAnnealingWarmRestarts with PyroLRScheduler.

OneCycleLR(optim_args, clip_args=None)

Wraps torch.optim.lr_scheduler.OneCycleLR with PyroLRScheduler.

Higher-Order Optimizers

class MultiOptimizer

Bases: object

Base class of optimizers that make use of higher-order derivatives.

Higher-order optimizers generally use torch.autograd.grad() rather than torch.Tensor.backward(), and therefore require a different interface from the usual Pyro and PyTorch optimizers. In this interface, the step() method takes as input a loss tensor to be differentiated, and backpropagation is triggered one or more times inside the optimizer.

Derived classes must implement step() to compute derivatives and update parameters in-place.

Example:

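import pyro.poutine as poutine

# model, args, kwargs and optim (a MultiOptimizer) are assumed to be defined elsewhere
tr = poutine.trace(model).get_trace(*args, **kwargs)
loss = -tr.log_prob_sum()
params = {name: site['value'].unconstrained()
          for name, site in tr.nodes.items()
          if site['type'] == 'param'}
optim.step(loss, params)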

step(loss: torch.Tensor, params: Dict) → None

Performs an in-place optimization step on parameters given a differentiable loss tensor.

Note that this detaches the updated tensors.

Parameters
  • loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.

  • params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.

get_step(loss: torch.Tensor, params: Dict) → Dict

Computes an optimization step of parameters given a differentiable loss tensor, returning the updated values.

Note that this preserves derivatives on the updated tensors.

Parameters
  • loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.

  • params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.

Returns

A dictionary mapping param name to updated unconstrained value.

Return type

dict

class PyroMultiOptimizer(optim: pyro.optim.optim.PyroOptim)

Bases: pyro.optim.multi.MultiOptimizer

Facade to wrap PyroOptim objects in a MultiOptimizer interface.

step(loss: torch.Tensor, params: Dict) → None

class TorchMultiOptimizer(optim_constructor: torch.optim.optimizer.Optimizer, optim_args: Dict)

Bases: pyro.optim.multi.PyroMultiOptimizer

Facade to wrap Optimizer objects in a MultiOptimizer interface.

class MixedMultiOptimizer(parts: List)

Bases: pyro.optim.multi.MultiOptimizer

Container class to combine different MultiOptimizer instances for different parameters.

Parameters

parts (list) – A list of (names, optim) pairs, where each names is a list of parameter names and each optim is a MultiOptimizer or PyroOptim object to be used for the named parameters. Together, the names should partition all parameters to be optimized.

Raises

ValueError – if any name is optimized by multiple optimizers.
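The partition requirement on parts can be checked before any optimization step. The helper below is a hypothetical pure-Python sketch of that validation, with optimizers replaced by placeholder strings; it is not Pyro's implementation, and it only detects names claimed twice (it cannot know which parameters are "desired").

```python
def check_partition(parts):
    """Map each parameter name to its optimizer, rejecting names claimed twice."""
    owner = {}
    for names, optim in parts:
        for name in names:
            if name in owner:
                raise ValueError(f"Parameter {name!r} is optimized by multiple optimizers")
            owner[name] = optim
    return owner


owner = check_partition([(["loc"], "newton"), (["scale", "weight"], "adam")])
print(owner["weight"])  # adam
```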

step(loss: torch.Tensor, params: Dict)

get_step(loss: torch.Tensor, params: Dict) → Dict

class Newton(trust_radii: Dict = {})

Bases: pyro.optim.multi.MultiOptimizer

Implementation of MultiOptimizer that performs a Newton update on batched low-dimensional variables, optionally regularizing via a per-parameter trust_radius. See newton_step() for details.

The result of get_step() will be differentiable; however, the updated values from step() will be detached.

Parameters

trust_radii (dict) – a dict mapping parameter name to radius of trust region. Missing names will use an unregularized Newton update, equivalent to an infinite trust radius.
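For intuition, an unregularized Newton step on a batch of independent one-dimensional variables moves each value by -g/h, where g and h are the first and second derivatives of its loss. The pure-Python sketch below takes those derivatives as given functions instead of computing them with torch.autograd.grad(), and it omits the trust-region regularization that trust_radii enables; newton_step_batched_1d is a hypothetical name, not Pyro's newton_step().

```python
def newton_step_batched_1d(values, grad_fn, hess_fn):
    """Unregularized Newton update x <- x - g(x) / h(x), applied to each value."""
    return [x - grad_fn(x) / hess_fn(x) for x in values]


# loss(x) = 0.5 * a * (x - c)**2, with gradient a * (x - c) and constant Hessian a
a, c = 4.0, 3.0


def grad_fn(x):
    return a * (x - c)


def hess_fn(x):
    return a


print(newton_step_batched_1d([0.0, 10.0, -2.5], grad_fn, hess_fn))  # [3.0, 3.0, 3.0]
```

For a quadratic loss, a single Newton step lands exactly on the minimizer from any starting point; a finite trust radius becomes useful when the loss is far from quadratic and the full step would overshoot.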

get_step(loss: torch.Tensor, params: Dict)