Source code for pyro.optim.multi

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List

import torch

from pyro.ops.newton import newton_step
from pyro.optim.optim import PyroOptim


class MultiOptimizer:
    """
    Base class of optimizers that make use of higher-order derivatives.

    Higher-order optimizers generally use :func:`torch.autograd.grad` rather
    than :meth:`torch.Tensor.backward`, and therefore require a different
    interface from usual Pyro and PyTorch optimizers. In this interface,
    the :meth:`step` method inputs a ``loss`` tensor to be differentiated,
    and backpropagation is triggered one or more times inside the optimizer.

    Derived classes must implement :meth:`step` to compute derivatives and
    update parameters in-place, or :meth:`get_step` to return updated values
    while preserving derivatives.

    Example::

        tr = poutine.trace(model).get_trace(*args, **kwargs)
        loss = -tr.log_prob_sum()
        params = {name: site['value'].unconstrained()
                  for name, site in tr.nodes.items()
                  if site['type'] == 'param'}
        optim.step(loss, params)
    """

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        """
        Performs an in-place optimization step on parameters given a
        differentiable ``loss`` tensor.

        Note that this detaches the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        """
        updated_values = self.get_step(loss, params)
        for name, value in params.items():
            with torch.no_grad():
                # we need to detach because the updated value may depend on the old value
                value.copy_(updated_values[name].detach())

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        """
        Computes an optimization step of parameters given a differentiable
        ``loss`` tensor, returning the updated values.

        Note that this preserves derivatives on the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        :return: A dictionary mapping param name to updated unconstrained
            value.
        :rtype: dict
        """
        raise NotImplementedError
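

# Sketch of a MultiOptimizer subclass that implements only ``get_step``, here as
# plain gradient descent; the inherited ``step`` then works unchanged. The class
# name ``ExampleGradientDescent`` and the ``lr`` argument are illustrative and
# not part of Pyro. ``create_graph=True`` keeps the returned values
# differentiable, as ``get_step`` promises.
class ExampleGradientDescent(MultiOptimizer):
    def __init__(self, lr: float = 0.1) -> None:
        self.lr = lr

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        names = list(params)
        grads = torch.autograd.grad(
            loss, [params[name] for name in names], create_graph=True
        )
        # each updated value still depends on the old value, so it can be
        # differentiated again; step() will detach it before copying in-place
        return {name: params[name] - self.lr * g for name, g in zip(names, grads)}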


class PyroMultiOptimizer(MultiOptimizer):
    """
    Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objects in a
    :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim: PyroOptim) -> None:
        if not isinstance(optim, PyroOptim):
            raise TypeError(
                "Expected a PyroOptim object but got a {}".format(type(optim))
            )
        self.optim = optim

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        values = params.values()
        grads = torch.autograd.grad(loss, values, create_graph=True)  # type: ignore
        for x, g in zip(values, grads):
            x.grad = g
        self.optim(values)


class TorchMultiOptimizer(PyroMultiOptimizer):
    """
    Facade to wrap :class:`~torch.optim.Optimizer` objects in a
    :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict):
        optim = PyroOptim(optim_constructor, optim_args)
        super().__init__(optim)
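

# Usage sketch for the two facades above: wrap a stock PyTorch optimizer so it
# can be driven through the MultiOptimizer interface. The parameter name "z",
# the learning rate, and the quadratic loss are illustrative only.
def example_torch_multi_optimizer() -> None:
    z = torch.zeros(3, requires_grad=True)
    optim = TorchMultiOptimizer(torch.optim.Adam, {"lr": 0.1})
    for _ in range(100):
        loss = ((z - 1.0) ** 2).sum()  # differentiable scalar loss
        optim.step(loss, {"z": z})  # computes grads and updates z in-place via Adam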


class MixedMultiOptimizer(MultiOptimizer):
    """
    Container class to combine different :class:`MultiOptimizer` instances for
    different parameters.

    :param list parts: A list of ``(names, optim)`` pairs, where each ``names``
        is a list of parameter names, and each ``optim`` is a
        :class:`MultiOptimizer` or :class:`~pyro.optim.optim.PyroOptim` object
        to be used for the named parameters. Together the ``names`` should
        partition all desired parameters to optimize.
    :raises ValueError: if any name is optimized by multiple optimizers.
    """

    def __init__(self, parts: List) -> None:
        optim_dict: Dict = {}
        self.parts = []
        for names_part, optim in parts:
            if isinstance(optim, PyroOptim):
                optim = PyroMultiOptimizer(optim)
            for name in names_part:
                if name in optim_dict:
                    raise ValueError(
                        "Attempted to optimize parameter '{}' by two different optimizers: "
                        "{} vs {}".format(name, optim_dict[name], optim)
                    )
                optim_dict[name] = optim
            self.parts.append((names_part, optim))

    def step(self, loss: torch.Tensor, params: Dict):
        for names_part, optim in self.parts:
            optim.step(loss, {name: params[name] for name in names_part})

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        updated_values = {}
        for names_part, optim in self.parts:
            updated_values.update(
                optim.get_step(loss, {name: params[name] for name in names_part})
            )
        return updated_values
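

# Usage sketch for MixedMultiOptimizer: give two parameter groups different
# optimizers. Plain PyroOptim objects are wrapped in PyroMultiOptimizer
# automatically, per the constructor above. The names "loc" and "weights", the
# learning rates, and the loss are illustrative; the two loss terms are kept
# independent so each sub-optimizer only needs derivatives of its own parameters.
def example_mixed_multi_optimizer() -> None:
    loc = torch.zeros(2, requires_grad=True)
    weights = torch.randn(10, requires_grad=True)
    optim = MixedMultiOptimizer(
        [
            (["loc"], PyroOptim(torch.optim.SGD, {"lr": 0.5})),
            (["weights"], TorchMultiOptimizer(torch.optim.Adam, {"lr": 0.05})),
        ]
    )
    for _ in range(10):
        loss = ((loc - 1.0) ** 2).sum() + (weights**2).sum()
        optim.step(loss, {"loc": loc, "weights": weights})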


class Newton(MultiOptimizer):
    """
    Implementation of :class:`MultiOptimizer` that performs a Newton update
    on batched low-dimensional variables, optionally regularizing via a
    per-parameter ``trust_radius``. See :func:`~pyro.ops.newton.newton_step`
    for details.

    The result of :meth:`get_step` will be differentiable; however, the
    updated values from :meth:`step` will be detached.

    :param dict trust_radii: a dict mapping parameter name to radius of
        trust region. Missing names will use an unregularized Newton update,
        equivalent to infinite trust radius.
    """

    def __init__(self, trust_radii: Dict = {}):
        self.trust_radii = trust_radii

    def get_step(self, loss: torch.Tensor, params: Dict):
        updated_values = {}
        for name, value in params.items():
            trust_radius = self.trust_radii.get(name)  # type: ignore
            updated_value, cov = newton_step(loss, value, trust_radius)
            updated_values[name] = updated_value
        return updated_values
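

# Usage sketch for Newton: take one regularized Newton step on a batch of 2-D
# variables. The names, shapes, and trust radius are illustrative. Because the
# loss is quadratic, an unregularized Newton step would land exactly on the
# optimum; the trust radius only bounds how far each batch element may move.
def example_newton() -> None:
    x = torch.randn(5, 2, requires_grad=True)  # batch of five 2-D variables
    target = torch.randn(5, 2)
    optim = Newton(trust_radii={"x": 10.0})

    loss = ((x - target) ** 2).sum()  # scalar, twice-differentiable in x
    optim.step(loss, {"x": x})  # updates x in-place with detached values

    # get_step instead returns updated values that preserve derivatives,
    # so the update itself can be differentiated further if needed
    loss = ((x - target) ** 2).sum()
    updated = optim.get_step(loss, {"x": x})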