# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List

import torch

from pyro.ops.newton import newton_step
from pyro.optim.optim import PyroOptim


class MultiOptimizer:
    """
    Base class of optimizers that make use of higher-order derivatives.

    Higher-order optimizers generally use :func:`torch.autograd.grad` rather
    than :meth:`torch.Tensor.backward`, and therefore require a different
    interface from usual Pyro and PyTorch optimizers. In this interface,
    the :meth:`step` method inputs a ``loss`` tensor to be differentiated,
    and backpropagation is triggered one or more times inside the optimizer.

    Derived classes must implement :meth:`step` (or :meth:`get_step`, on which
    the default :meth:`step` relies) to compute derivatives and update
    parameters in-place.

    Example::

        tr = poutine.trace(model).get_trace(*args, **kwargs)
        loss = -tr.log_prob_sum()
        params = {name: site['value'].unconstrained()
                  for name, site in tr.nodes.items()
                  if site['type'] == 'param'}
        optim.step(loss, params)
    """

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        """
        Performs an in-place optimization step on parameters given a
        differentiable ``loss`` tensor.

        Note that this detaches the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        """
        updated_values = self.get_step(loss, params)
        for name, value in params.items():
            with torch.no_grad():
                # we need to detach because updated_value may depend on value
                value.copy_(updated_values[name].detach())

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        """
        Computes an optimization step of parameters given a differentiable
        ``loss`` tensor, returning the updated values.

        Note that this preserves derivatives on the updated tensors.

        :param torch.Tensor loss: A differentiable tensor to be minimized.
            Some optimizers require this to be differentiable multiple times.
        :param dict params: A dictionary mapping param name to unconstrained
            value as stored in the param store.
        :return: A dictionary mapping param name to updated unconstrained
            value.
        :rtype: dict
        """
        raise NotImplementedError
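

# A minimal sketch (not part of the original module) of a custom subclass: a
# plain gradient-descent update expressed through ``get_step``, so the returned
# values keep their dependence on ``loss`` while the inherited
# :meth:`MultiOptimizer.step` performs the detached in-place copy. The class
# name and the ``lr`` argument are illustrative assumptions.
class _ExampleGradientDescentMultiOptimizer(MultiOptimizer):
    def __init__(self, lr: float = 0.01):
        self.lr = lr

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        names = list(params.keys())
        # create_graph=True keeps the result differentiable, as get_step promises.
        grads = torch.autograd.grad(
            loss, [params[name] for name in names], create_graph=True
        )
        return {name: params[name] - self.lr * g for name, g in zip(names, grads)}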


class PyroMultiOptimizer(MultiOptimizer):
    """
    Facade to wrap :class:`~pyro.optim.optim.PyroOptim` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim: PyroOptim) -> None:
        if not isinstance(optim, PyroOptim):
            raise TypeError(
                "Expected a PyroOptim object but got a {}".format(type(optim))
            )
        self.optim = optim

    def step(self, loss: torch.Tensor, params: Dict) -> None:
        values = params.values()
        # create_graph=True keeps the autograd graph alive, so the same loss can
        # be differentiated again (e.g. by other parts of a MixedMultiOptimizer).
        grads = torch.autograd.grad(loss, values, create_graph=True)  # type: ignore
        # Populate .grad manually, then delegate the actual update to the
        # wrapped first-order optimizer.
        for x, g in zip(values, grads):
            x.grad = g
        self.optim(values)
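

# A small usage sketch, assuming a single unconstrained ``pyro.param`` site
# named "loc" (the name, loss, and learning rate are all hypothetical): wrap a
# :class:`~pyro.optim.optim.PyroOptim` such as ``pyro.optim.Adam`` and drive it
# through the MultiOptimizer interface.
def _example_pyro_multi_optimizer():
    import pyro
    from pyro.optim import Adam

    loc = pyro.param("loc", torch.tensor(0.0))
    loss = (loc - 2.0) ** 2  # any differentiable scalar loss
    optim = PyroMultiOptimizer(Adam({"lr": 0.1}))
    optim.step(loss, {"loc": pyro.param("loc").unconstrained()})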


class TorchMultiOptimizer(PyroMultiOptimizer):
    """
    Facade to wrap :class:`~torch.optim.Optimizer` objects
    in a :class:`MultiOptimizer` interface.
    """

    def __init__(self, optim_constructor: torch.optim.Optimizer, optim_args: Dict):
        optim = PyroOptim(optim_constructor, optim_args)
        super().__init__(optim)
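

# The same pattern, constructed directly from a ``torch.optim`` constructor and
# its argument dict; the parameter name "z" and the learning rate are
# illustrative assumptions.
def _example_torch_multi_optimizer():
    import pyro

    z = pyro.param("z", torch.tensor(1.0))
    loss = (z - 3.0) ** 2
    optim = TorchMultiOptimizer(torch.optim.SGD, {"lr": 0.1})
    optim.step(loss, {"z": pyro.param("z").unconstrained()})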


class MixedMultiOptimizer(MultiOptimizer):
    """
    Container class to combine different :class:`MultiOptimizer` instances for
    different parameters.

    :param list parts: A list of ``(names, optim)`` pairs, where each
        ``names`` is a list of parameter names and each ``optim`` is a
        :class:`MultiOptimizer` or :class:`~pyro.optim.optim.PyroOptim` object
        to be used for the named parameters. Together the ``names`` lists
        should partition all of the parameters you want to optimize.
    :raises ValueError: if any name is optimized by multiple optimizers.
    """

    def __init__(self, parts: List) -> None:
        optim_dict: Dict = {}
        self.parts = []
        for names_part, optim in parts:
            if isinstance(optim, PyroOptim):
                optim = PyroMultiOptimizer(optim)
            for name in names_part:
                if name in optim_dict:
                    raise ValueError(
                        "Attempted to optimize parameter '{}' by two different optimizers: "
                        "{} vs {}".format(name, optim_dict[name], optim)
                    )
                optim_dict[name] = optim
            self.parts.append((names_part, optim))

    def step(self, loss: torch.Tensor, params: Dict):
        for names_part, optim in self.parts:
            optim.step(loss, {name: params[name] for name in names_part})

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        updated_values = {}
        for names_part, optim in self.parts:
            updated_values.update(
                optim.get_step(loss, {name: params[name] for name in names_part})
            )
        return updated_values
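

# A usage sketch for combining optimizers: Adam for one group of parameters and
# SGD (via a plain torch constructor) for another. The parameter names "loc"
# and "scale", the loss, and the learning rates are hypothetical; note that a
# bare PyroOptim is accepted and wrapped in PyroMultiOptimizer automatically.
def _example_mixed_multi_optimizer():
    import pyro
    from pyro.optim import Adam

    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0))
    loss = (loc - 1.0) ** 2 + (scale - 2.0) ** 2
    optim = MixedMultiOptimizer([
        (["loc"], Adam({"lr": 0.1})),
        (["scale"], TorchMultiOptimizer(torch.optim.SGD, {"lr": 0.05})),
    ])
    params = {name: pyro.param(name).unconstrained() for name in ["loc", "scale"]}
    optim.step(loss, params)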


class Newton(MultiOptimizer):
    """
    Implementation of :class:`MultiOptimizer` that performs a Newton update
    on batched low-dimensional variables, optionally regularizing via a
    per-parameter ``trust_radius``. See :func:`~pyro.ops.newton.newton_step`
    for details.

    The result of :meth:`get_step` will be differentiable; however, the
    updated values from :meth:`step` will be detached.

    :param dict trust_radii: a dict mapping parameter name to the radius of its
        trust region. Missing names will use an unregularized Newton update,
        equivalent to an infinite trust radius.
    """

    def __init__(self, trust_radii: Dict = {}):
        self.trust_radii = trust_radii

    def get_step(self, loss: torch.Tensor, params: Dict) -> Dict:
        updated_values = {}
        for name, value in params.items():
            # A missing trust radius means an unregularized Newton update.
            trust_radius = self.trust_radii.get(name)  # type: ignore
            # newton_step also returns a covariance estimate, unused here.
            updated_value, cov = newton_step(loss, value, trust_radius)
            updated_values[name] = updated_value
        return updated_values
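

# A usage sketch for ``Newton``: a 2-dimensional parameter with a quadratic,
# twice-differentiable loss, regularized by a per-parameter trust radius. The
# parameter name "xy", the target, and the trust radius are illustrative
# assumptions.
def _example_newton():
    import pyro

    xy = pyro.param("xy", torch.zeros(2))
    target = torch.tensor([1.0, -1.0])
    loss = ((xy - target) ** 2).sum()
    optim = Newton(trust_radii={"xy": 2.0})
    optim.step(loss, {"xy": pyro.param("xy").unconstrained()})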