Optimization

The module pyro.optim provides support for optimization in Pyro. In particular, it provides PyroOptim, which wraps PyTorch optimizers and manages optimizers for dynamically generated parameters (see the tutorial SVI Part I for a discussion). Custom optimization algorithms are also found here.

Pyro Optimizers

class PyroOptim(optim_constructor, optim_args)[source]

Bases: object

A wrapper for torch.optim.Optimizer objects that helps with managing dynamically generated parameters.

Parameters:
  • optim_constructor – a torch.optim.Optimizer
  • optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries
__call__(params, *args, **kwargs)[source]
Parameters: params (an iterable of strings) – a list of parameters

Do an optimization step for each param in params. If a given param has never been seen before, initialize an optimizer for it.

get_state()[source]

Get the state associated with all the optimizers as a dictionary mapping each parameter name to that optimizer's state dict.

set_state(state_dict)[source]

Set the state associated with all the optimizers using the state obtained from a previous call to get_state().

save(filename)[source]
Parameters: filename – file name to save to

Save optimizer state to disk.

load(filename)[source]
Parameters: filename – file name to load from

Load optimizer state from disk.

AdagradRMSProp(optim_args)[source]

A wrapper for an optimizer that is a mash-up of Adagrad and RMSprop.

ClippedAdam(optim_args)[source]

A wrapper for a modification of the Adam optimization algorithm that supports gradient clipping.

class PyroLRScheduler(scheduler_constructor, optim_args)[source]

Bases: pyro.optim.optim.PyroOptim

A wrapper for lr_scheduler objects that adjusts learning rates for dynamically generated parameters.

Parameters:
  • scheduler_constructor – a lr_scheduler
  • optim_args – a dictionary of learning arguments for the optimizer or a callable that returns such dictionaries. Must contain the key ‘optimizer’, whose value is a PyTorch optimizer.

Example:

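import torch
import pyro.optim
from pyro.infer import SVI, TraceGraph_ELBO
from torch.utils.data import DataLoader

# model, guide, dataset, batch_size and epochs are assumed to be defined elsewhere.
optimizer = torch.optim.SGD
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer, 'optim_args': {'lr': 0.01}, 'gamma': 0.1})
svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
for i in range(epochs):
    for minibatch in DataLoader(dataset, batch_size):
        svi.step(minibatch)
    # Advance the learning-rate schedule once per epoch.
    scheduler.step(epoch=i)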
step(*args, **kwargs)[source]

Takes the same arguments as the PyTorch scheduler (an optional epoch kwarg, or the loss for ReduceLROnPlateau).

PyTorch Optimizers

Adamax(optim_args)

Wraps torch.optim.Adamax with PyroOptim.

Adagrad(optim_args)

Wraps torch.optim.Adagrad with PyroOptim.

SGD(optim_args)

Wraps torch.optim.SGD with PyroOptim.

Adam(optim_args)

Wraps torch.optim.Adam with PyroOptim.

Rprop(optim_args)

Wraps torch.optim.Rprop with PyroOptim.

ASGD(optim_args)

Wraps torch.optim.ASGD with PyroOptim.

RMSprop(optim_args)

Wraps torch.optim.RMSprop with PyroOptim.

SparseAdam(optim_args)

Wraps torch.optim.SparseAdam with PyroOptim.

Adadelta(optim_args)

Wraps torch.optim.Adadelta with PyroOptim.

MultiStepLR(optim_args)

Wraps torch.optim.lr_scheduler.MultiStepLR with PyroLRScheduler.

ReduceLROnPlateau(optim_args)

Wraps torch.optim.lr_scheduler.ReduceLROnPlateau with PyroLRScheduler.

StepLR(optim_args)

Wraps torch.optim.lr_scheduler.StepLR with PyroLRScheduler.

CosineAnnealingWarmRestarts(optim_args)

Wraps torch.optim.lr_scheduler.CosineAnnealingWarmRestarts with PyroLRScheduler.

CosineAnnealingLR(optim_args)

Wraps torch.optim.lr_scheduler.CosineAnnealingLR with PyroLRScheduler.

CyclicLR(optim_args)

Wraps torch.optim.lr_scheduler.CyclicLR with PyroLRScheduler.

LambdaLR(optim_args)

Wraps torch.optim.lr_scheduler.LambdaLR with PyroLRScheduler.

ExponentialLR(optim_args)

Wraps torch.optim.lr_scheduler.ExponentialLR with PyroLRScheduler.

Higher-Order Optimizers

class MultiOptimizer[source]

Bases: object

Base class of optimizers that make use of higher-order derivatives.

Higher-order optimizers generally use torch.autograd.grad() rather than torch.Tensor.backward(), and therefore require a different interface from usual Pyro and PyTorch optimizers. In this interface, the step() method inputs a loss tensor to be differentiated, and backpropagation is triggered one or more times inside the optimizer.

Derived classes must implement step() to compute derivatives and update parameters in-place.

Example:

tr = poutine.trace(model).get_trace(*args, **kwargs)
loss = -tr.log_prob_sum()
params = {name: site['value'].unconstrained()
          for name, site in tr.nodes.items()
          if site['type'] == 'param'}
optim.step(loss, params)
step(loss, params)[source]

Performs an in-place optimization step on parameters given a differentiable loss tensor.

Note that this detaches the updated tensors.

Parameters:
  • loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.
  • params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.
get_step(loss, params)[source]

Computes an optimization step of parameters given a differentiable loss tensor, returning the updated values.

Note that this preserves derivatives on the updated tensors.

Parameters:
  • loss (torch.Tensor) – A differentiable tensor to be minimized. Some optimizers require this to be differentiable multiple times.
  • params (dict) – A dictionary mapping param name to unconstrained value as stored in the param store.
Returns:

A dictionary mapping param name to updated unconstrained value.

Return type:

dict

class PyroMultiOptimizer(optim)[source]

Bases: pyro.optim.multi.MultiOptimizer

Facade to wrap PyroOptim objects in a MultiOptimizer interface.

step(loss, params)[source]
class TorchMultiOptimizer(optim_constructor, optim_args)[source]

Bases: pyro.optim.multi.PyroMultiOptimizer

Facade to wrap Optimizer objects in a MultiOptimizer interface.

class MixedMultiOptimizer(parts)[source]

Bases: pyro.optim.multi.MultiOptimizer

Container class to combine different MultiOptimizer instances for different parameters.

Parameters: parts (list) – A list of (names, optim) pairs, where each names is a list of parameter names, and each optim is a MultiOptimizer or PyroOptim object to be used for the named parameters. Together the names should partition up all desired parameters to optimize.
Raises: ValueError – if any name is optimized by multiple optimizers.
step(loss, params)[source]
get_step(loss, params)[source]
class Newton(trust_radii={})[source]

Bases: pyro.optim.multi.MultiOptimizer

Implementation of MultiOptimizer that performs a Newton update on batched low-dimensional variables, optionally regularizing via a per-parameter trust_radius. See newton_step() for details.

The result of get_step() will be differentiable; however, the updated values from step() will be detached.

Parameters: trust_radii (dict) – a dict mapping parameter name to radius of trust region. Missing names use an unregularized Newton update, equivalent to an infinite trust radius.
get_step(loss, params)[source]