# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import re
import warnings
import weakref
from contextlib import contextmanager
from typing import (
Callable,
Dict,
ItemsView,
Iterator,
KeysView,
Optional,
Tuple,
Union,
)
import torch
from torch.distributions import constraints, transform_to
from torch.serialization import MAP_LOCATION
from typing_extensions import TypedDict
class StateDict(TypedDict):
    params: Dict[str, torch.Tensor]
    constraints: Dict[str, constraints.Constraint]


class ParamStoreDict:
"""
Global store for parameters in Pyro. This is basically a key-value store.
The typical user interacts with the ParamStore primarily through the
primitive `pyro.param`.
See `Introduction <http://pyro.ai/examples/intro_long.html>`_ for further discussion
and `SVI Part I <http://pyro.ai/examples/svi_part_i.html>`_ for some examples.
Some things to bear in mind when using parameters in Pyro:
- parameters must be assigned unique names
- the `init_tensor` argument to `pyro.param` is only used the first time that a given (named)
parameter is registered with Pyro.
- for this reason, a user may need to use the `clear()` method if working in a REPL in order to
get the desired behavior. this method can also be invoked with `pyro.clear_param_store()`.
- the internal name of a parameter within a PyTorch `nn.Module` that has been registered with
Pyro is prepended with the Pyro name of the module. so nothing prevents the user from having
two different modules each of which contains a parameter named `weight`. by contrast, a user
can only have one top-level parameter named `weight` (outside of any module).
- parameters can be saved and loaded from disk using `save` and `load`.
- in general parameters are associated with both *constrained* and *unconstrained* values. for
example, under the hood a parameter that is constrained to be positive is represented as an
unconstrained tensor in log space.
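
    A minimal usage sketch (the parameter name, shape, and initial value are illustrative)::

        import pyro
        import torch
        from torch.distributions import constraints

        # Register a positive parameter; the init tensor is only used on the first call.
        scale = pyro.param("scale", torch.ones(3), constraint=constraints.positive)

        # `scale` lives in the constrained (positive) space; the tensor actually being
        # optimized is its unconstrained (log-space) counterpart.
        unconstrained_scale = scale.unconstrained()

        # Clear all parameters, e.g. before re-running a model in a REPL.
        pyro.clear_param_store()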
"""
# -------------------------------------------------------------------------------
# New dict-like interface
    def __init__(self) -> None:
        """
        Initialize the ParamStore data structures.
        """
        # dictionary from param name to unconstrained param tensor
        self._params: Dict[str, torch.Tensor] = {}
        # dictionary from unconstrained param tensor to param name
        self._param_to_name: Dict[torch.Tensor, str] = {}
        # dictionary from param name to constraint object
        self._constraints: Dict[str, constraints.Constraint] = {}
    def clear(self) -> None:
"""
Clear the ParamStore
"""
self._params = {}
self._param_to_name = {}
self._constraints = {}
    def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Iterate over ``(name, constrained_param)`` pairs. Note that `constrained_param` is
in the constrained (i.e. user-facing) space.
"""
for name in self._params:
yield name, self[name]
    def keys(self) -> KeysView[str]:
"""
Iterate over param names.
"""
return self._params.keys()
    def values(self) -> Iterator[torch.Tensor]:
"""
Iterate over constrained parameter values.
"""
for name, constrained_param in self.items():
yield constrained_param
def __bool__(self) -> bool:
return bool(self._params)
def __len__(self) -> int:
return len(self._params)
def __contains__(self, name: str) -> bool:
return name in self._params
def __iter__(self) -> Iterator[str]:
"""
Iterate over param names.
"""
return iter(self.keys())
    def __delitem__(self, name: str) -> None:
"""
Remove a parameter from the param store.
"""
unconstrained_value = self._params.pop(name)
self._param_to_name.pop(unconstrained_value)
self._constraints.pop(name)
def __getitem__(self, name: str) -> torch.Tensor:
"""
Get the *constrained* value of a named parameter.
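
        A usage sketch (assuming ``param_store = pyro.get_param_store()`` and that a
        parameter named ``"scale"`` has already been registered with a positive constraint)::

            scale = param_store["scale"]  # constrained (positive) value
            unconstrained_scale = scale.unconstrained()  # underlying unconstrained tensor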
"""
unconstrained_value = self._params[name]
# compute the constrained value
constraint = self._constraints[name]
constrained_value: torch.Tensor = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value) # type: ignore[attr-defined]
return constrained_value
def __setitem__(self, name: str, new_constrained_value: torch.Tensor) -> None:
"""
Set the constrained value of an existing parameter, or the value of a
        new *unconstrained* parameter. To declare a new parameter with a
        constraint, use :meth:`setdefault`.
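
        A sketch (assuming ``import torch``, ``param_store = pyro.get_param_store()``, and
        an already-registered positively-constrained parameter ``"scale"``)::

            param_store["scale"] = torch.tensor([1.0, 2.0])  # re-set existing constrained value
            param_store["loc"] = torch.zeros(2)  # declare a new, unconstrained parameter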
"""
# store constraint, defaulting to unconstrained
constraint = self._constraints.setdefault(name, constraints.real)
# compute the unconstrained value
with torch.no_grad():
# FIXME should we .detach() the new_constrained_value?
unconstrained_value = transform_to(constraint).inv(new_constrained_value)
unconstrained_value = unconstrained_value.contiguous()
unconstrained_value.requires_grad_(True)
# store a bidirectional mapping between name and unconstrained tensor
self._params[name] = unconstrained_value
self._param_to_name[unconstrained_value] = name
    def setdefault(
self,
name: str,
init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]],
constraint: constraints.Constraint = constraints.real,
) -> torch.Tensor:
"""
        Retrieve a *constrained* parameter value from the ``ParamStoreDict`` if it exists,
        otherwise set the initial value. Note that this is a little fancier than
        :meth:`dict.setdefault`.

        If the parameter already exists, ``init_constrained_value`` will be ignored. To avoid
        expensive creation of ``init_constrained_value`` you can wrap it in a ``lambda`` that
        will only be evaluated if the parameter does not already exist::

            param_store.setdefault("foo",
                                   lambda: (0.001 * torch.randn(1000, 1000)).exp(),
                                   constraint=constraints.positive)

        :param str name: parameter name
        :param init_constrained_value: initial constrained value
        :type init_constrained_value: torch.Tensor or callable returning a torch.Tensor
        :param constraint: torch constraint object
        :type constraint: ~torch.distributions.constraints.Constraint
        :returns: constrained parameter value
        :rtype: torch.Tensor
"""
if name not in self._params:
# set the constraint
self._constraints[name] = constraint
# evaluate the lazy value
if callable(init_constrained_value):
init_constrained_value = init_constrained_value()
# set the initial value
self[name] = init_constrained_value
# get the param, which is guaranteed to exist
return self[name]
# -------------------------------------------------------------------------------
# Old non-dict interface
    def named_parameters(self) -> ItemsView[str, torch.Tensor]:
"""
Returns an iterator over ``(name, unconstrained_value)`` tuples for
each parameter in the ParamStore. Note that, in the event the parameter is constrained,
`unconstrained_value` is in the unconstrained space implicitly used by the constraint.
"""
return self._params.items()
    def get_all_param_names(self) -> KeysView[str]:
warnings.warn(
"ParamStore.get_all_param_names() is deprecated; use .keys() instead.",
DeprecationWarning,
)
return self.keys()
    def replace_param(
self, param_name: str, new_param: torch.Tensor, old_param: torch.Tensor
) -> None:
warnings.warn(
"ParamStore.replace_param() is deprecated; use .__setitem__() instead.",
DeprecationWarning,
)
assert self._params[param_name] is old_param.unconstrained() # type: ignore[attr-defined]
self[param_name] = new_param
    def get_param(
self,
name: str,
init_tensor: Optional[torch.Tensor] = None,
constraint: constraints.Constraint = constraints.real,
event_dim: Optional[int] = None,
) -> torch.Tensor:
"""
        Get a parameter by name. If the parameter does not yet exist in the
        ParamStore, it will be created and stored (provided ``init_tensor`` is given).
        The Pyro primitive `pyro.param` dispatches to this method.
:param name: parameter name
:type name: str
:param init_tensor: initial tensor
:type init_tensor: torch.Tensor
:param constraint: torch constraint
:type constraint: torch.distributions.constraints.Constraint
:param int event_dim: (ignored)
:returns: parameter
:rtype: torch.Tensor
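
        Example (a sketch; ``"loc"`` is an illustrative name and ``param_store`` is assumed
        to be ``pyro.get_param_store()``)::

            loc = param_store.get_param("loc", torch.zeros(2))
            loc_again = param_store.get_param("loc")  # init_tensor is ignored once registered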
"""
if init_tensor is None:
return self[name]
else:
return self.setdefault(name, init_tensor, constraint)
    def match(self, name: str) -> Dict[str, torch.Tensor]:
        """
        Get all parameters whose names match the given regex. Only parameters that
        already exist in the ParamStore are returned.

        :param name: regular expression
        :type name: str
        :returns: dict mapping param name to constrained torch.Tensor
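
        Example (a sketch; the parameter names are illustrative)::

            # all parameters whose names start with "decoder"
            decoder_params = param_store.match("decoder")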
"""
pattern = re.compile(name)
        return {key: self[key] for key in self if pattern.match(key)}
    def param_name(self, p: torch.Tensor) -> Optional[str]:
        """
        Get the name of a parameter from its *unconstrained* value, or ``None`` if the
        tensor is not registered in the ParamStore.

        :param p: unconstrained parameter value
        :returns: parameter name
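
        Example (a sketch; ``"scale"`` is an illustrative, already-registered name)::

            unconstrained = param_store["scale"].unconstrained()
            assert param_store.param_name(unconstrained) == "scale"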
"""
return self._param_to_name.get(p)
# -------------------------------------------------------------------------------
# Persistence interface
    def get_state(self) -> StateDict:
"""
Get the ParamStore state.
"""
params = self._params.copy()
# Remove weakrefs in preparation for pickling.
for param in params.values():
param.__dict__.pop("unconstrained", None)
state: StateDict = {"params": params, "constraints": self._constraints.copy()}
return state
    def set_state(self, state: StateDict) -> None:
"""
Set the ParamStore state using state from a previous :meth:`get_state` call
"""
assert isinstance(state, dict), "malformed ParamStore state"
assert set(state.keys()) == set(
["params", "constraints"]
), "malformed ParamStore keys {}".format(state.keys())
for param_name, param in state["params"].items():
self._params[param_name] = param
self._param_to_name[param] = param_name
for param_name, constraint in state["constraints"].items():
if isinstance(constraint, type(constraints.real)):
# Work around lack of hash & equality comparison on constraints.
constraint = constraints.real
self._constraints[param_name] = constraint
    def save(self, filename: str) -> None:
"""
Save parameters to file
:param filename: file name to save to
:type filename: str
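
        Example of a save/load round trip (a sketch; the file name is illustrative)::

            pyro.get_param_store().save('saved_params.save')
            # ... later, possibly in a fresh process ...
            pyro.get_param_store().load('saved_params.save')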
"""
with open(filename, "wb") as output_file:
torch.save(self.get_state(), output_file)
    def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
        """
        Load parameters from file.

        .. note::

            If using :meth:`pyro.module` on parameters loaded from
            disk, be sure to set the ``update_module_params`` flag::

                pyro.get_param_store().load('saved_params.save')
                pyro.module('module', nn, update_module_params=True)

        :param filename: file name to load from
        :type filename: str
        :param map_location: specifies how to remap storage locations
        :type map_location: function, torch.device, string or a dict
"""
with open(filename, "rb") as input_file:
state = torch.load(input_file, map_location, weights_only=False)
self.set_state(state)
    @contextmanager
def scope(self, state: Optional[StateDict] = None) -> Iterator[StateDict]:
"""
        Context manager for using multiple parameter stores within the same process.

        This is a thin wrapper around :meth:`get_state`, :meth:`clear`, and
        :meth:`set_state`. For large models where memory space is limiting, you
        may want to instead manually use :meth:`save`, :meth:`clear`, and
        :meth:`load`.

        Example usage::

            param_store = pyro.get_param_store()

            # Train multiple models, while avoiding param name conflicts.
            with param_store.scope() as scope1:
                ...  # Train one model/guide pair.
            with param_store.scope() as scope2:
                ...  # Train another model/guide pair.

            # Now evaluate each, still avoiding name conflicts.
            with param_store.scope(scope1):  # loads the first model's scope
                ...  # Evaluate the first model.
            with param_store.scope(scope2):  # loads the second model's scope
                ...  # Evaluate the second model.
"""
if state is None:
state = {"params": {}, "constraints": {}}
old_state = self.get_state()
try:
self.clear()
self.set_state(state)
yield state
state.update(self.get_state())
finally:
self.clear()
self.set_state(old_state)
# used to create fully-formed param names, e.g. mymodule$$$mysubmodule.weight
_MODULE_NAMESPACE_DIVIDER = "$$$"
def param_with_module_name(pyro_name: str, param_name: str) -> str:
    return _MODULE_NAMESPACE_DIVIDER.join([pyro_name, param_name])


def module_from_param_with_module_name(param_name: str) -> str:
    return param_name.split(_MODULE_NAMESPACE_DIVIDER)[0]


def user_param_name(param_name: str) -> str:
    if _MODULE_NAMESPACE_DIVIDER in param_name:
        return param_name.split(_MODULE_NAMESPACE_DIVIDER)[1]
    return param_name


def normalize_param_name(name: str) -> str:
    return name.replace(_MODULE_NAMESPACE_DIVIDER, ".")
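
# Illustrative behavior of the helpers above (the module and param names are hypothetical):
#   param_with_module_name("mymodule", "weight")             -> "mymodule$$$weight"
#   module_from_param_with_module_name("mymodule$$$weight")  -> "mymodule"
#   user_param_name("mymodule$$$weight")                     -> "weight"
#   normalize_param_name("mymodule$$$weight")                -> "mymodule.weight"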