# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
Pyro includes a class :class:`~pyro.nn.module.PyroModule`, a subclass of
:class:`torch.nn.Module`, whose attributes can be modified by Pyro effects. To
create a poutine-aware attribute, use either the :class:`PyroParam` struct or
the :class:`PyroSample` struct::
my_module = PyroModule()
my_module.x = PyroParam(torch.tensor(1.), constraint=constraints.positive)
my_module.y = PyroSample(dist.Normal(0, 1))
"""
import functools
import inspect
import warnings
import weakref
# torch._jit_internal is a private API; it may be missing in some torch builds,
# so fall back to a no-op decorator rather than failing at import time.
try:
    from torch._jit_internal import _copy_to_script_wrapper
except ImportError:
    warnings.warn(
        "Cannot find torch._jit_internal._copy_to_script_wrapper", ImportWarning
    )

    # Fall back to trivial decorator.
    def _copy_to_script_wrapper(fn):
        return fn
from collections import OrderedDict
from dataclasses import dataclass
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
NamedTuple,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import torch
from torch.distributions import constraints, transform_to
from typing_extensions import Concatenate, ParamSpec
import pyro
import pyro.params.param_store
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE
# Default for the "module_local_params" setting registered with pyro.settings
# below; when enabled, each PyroModule keeps params in a local store scope.
_MODULE_LOCAL_PARAMS: bool = False

_P = ParamSpec("_P")  # parameters of a wrapped method (see pyro_method)
_T = TypeVar("_T")  # return type of a wrapped method
_PyroModule = TypeVar("_PyroModule", bound="PyroModule")

if TYPE_CHECKING:
    # Imported only for annotations, to avoid runtime import cycles.
    from pyro.distributions.torch_distribution import TorchDistributionMixin
    from pyro.params.param_store import StateDict
# Registers _MODULE_LOCAL_PARAMS with pyro.settings under the alias
# "module_local_params"; this validator runs on every assignment to the flag.
@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS")
def _validate_module_local_params(value: bool) -> None:
    assert isinstance(value, bool)
def _is_module_local_param_enabled() -> bool:
    # Read the current value of the "module_local_params" setting.
    return pyro.settings.get("module_local_params")  # type: ignore[no-any-return]
class PyroParam(NamedTuple):
    """
    Declares a Pyro-managed learnable attribute of a :class:`PyroModule`,
    similar to :func:`pyro.param <pyro.primitives.param>`.

    This can be used either to set attributes of :class:`PyroModule`
    instances::

        assert isinstance(my_module, PyroModule)
        my_module.x = PyroParam(torch.zeros(4))          # eager
        my_module.y = PyroParam(lambda: torch.randn(4))  # lazy
        my_module.z = PyroParam(torch.ones(4),           # eager
                                constraint=constraints.positive,
                                event_dim=1)

    or EXPERIMENTALLY as a decorator on lazy initialization properties::

        class MyModule(PyroModule):
            @PyroParam
            def x(self):
                return torch.zeros(4)

            @PyroParam
            def y(self):
                return torch.randn(4)

            @PyroParam(constraint=constraints.real, event_dim=1)
            def z(self):
                return torch.ones(4)

            def forward(self):
                return self.x + self.y + self.z  # accessed like a @property

    :param init_value: Either a tensor for eager initialization, a callable for
        lazy initialization, or None for use as a decorator.
    :type init_value: torch.Tensor or callable returning a torch.Tensor or None
    :param constraint: torch constraint, defaults to ``constraints.real``.
    :type constraint: ~torch.distributions.constraints.Constraint
    :param int event_dim: (optional) number of rightmost dimensions unrelated
        to batching. Dimensions to the left of this will be considered batch
        dimensions; if the param statement is inside a subsampled plate, then
        corresponding batch dimensions of the parameter will be correspondingly
        subsampled. If unspecified, all dimensions will be considered event
        dims and no subsampling will be performed.
    """

    init_value: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None
    constraint: constraints.Constraint = constraints.real
    event_dim: Optional[int] = None

    # Support use as a decorator.
    def __get__(
        self, obj: Optional["PyroModule"], obj_type: Type["PyroModule"]
    ) -> "PyroParam":
        """
        Descriptor access for the decorator form: on first access, bind the
        decorated method to ``obj`` as a lazy init_value and register a fresh
        PyroParam on the instance; thereafter delegate to
        ``PyroModule.__getattr__``, which triggers the pyro.param statement.
        """
        assert issubclass(obj_type, PyroModule)
        if obj is None:
            # Accessed on the class rather than an instance.
            return self
        name = self.init_value.__name__  # type: ignore[union-attr]
        if name not in obj.__dict__["_pyro_params"]:
            init_value, constraint, event_dim = self
            # bind method's self arg
            init_value = functools.partial(init_value, obj)  # type: ignore[arg-type,call-arg,misc,operator]
            setattr(obj, name, PyroParam(init_value, constraint, event_dim))
        value: PyroParam = obj.__getattr__(name)
        return value

    # Support decoration with optional kwargs, e.g. @PyroParam(event_dim=0).
    def __call__(
        self, init_value: Union[torch.Tensor, Callable[[], torch.Tensor]]
    ) -> "PyroParam":
        """Complete a kwargs-only decorator by filling in the decorated fn."""
        assert self.init_value is None
        return PyroParam(init_value, self.constraint, self.event_dim)
@dataclass(frozen=True)
class PyroSample:
    """
    Declares a Pyro-managed random attribute of a :class:`PyroModule`, similar
    to :func:`pyro.sample <pyro.primitives.sample>`.

    This can be used either to set attributes of :class:`PyroModule`
    instances::

        assert isinstance(my_module, PyroModule)
        my_module.x = PyroSample(Normal(0, 1))                    # independent
        my_module.y = PyroSample(lambda self: Normal(self.x, 1))  # dependent
        my_module.z = PyroSample(lambda self: self.y ** 2)        # deterministic dependent

    or EXPERIMENTALLY as a decorator on lazy initialization methods::

        class MyModule(PyroModule):
            @PyroSample
            def x(self):
                return Normal(0, 1)       # independent

            @PyroSample
            def y(self):
                return Normal(self.x, 1)  # dependent

            @PyroSample
            def z(self):
                return self.y ** 2        # deterministic dependent

            def forward(self):
                return self.z             # accessed like a @property

    :param prior: distribution object or function that inputs the
        :class:`PyroModule` instance ``self`` and returns a distribution
        object or a deterministic value.
    """

    # Either a distribution (has .sample) or a single-argument callable of the
    # owning PyroModule instance.
    prior: Union[
        "TorchDistributionMixin",
        Callable[["PyroModule"], "TorchDistributionMixin"],
        Callable[["PyroModule"], torch.Tensor],
    ]

    def __post_init__(self) -> None:
        if not hasattr(self.prior, "sample"):  # if not a distribution
            # Callable priors must accept exactly one required argument (self).
            assert 1 == sum(
                1
                for p in inspect.signature(self.prior).parameters.values()
                if p.default is inspect.Parameter.empty
            ), "prior should take the single argument 'self'"
        # The dataclass is frozen, so bypass __setattr__ to record the name.
        # Distributions have no __name__, so .name is None in that case.
        object.__setattr__(self, "name", getattr(self.prior, "__name__", None))
        self.name: Optional[str]
        if self.name is not None:
            # Ensure decorated function is accessible for pickling.
            self.prior.__name__ = "_pyro_prior_" + self.prior.__name__
            qualname = self.prior.__qualname__.rsplit(".", 1)
            qualname[-1] = self.prior.__name__
            self.prior.__qualname__ = ".".join(qualname)

    # Support use as a decorator.
    def __get__(
        self, obj: Optional["PyroModule"], obj_type: Type["PyroModule"]
    ) -> "PyroSample":
        """
        Descriptor access: registers the prior in the instance's
        ``_pyro_samples`` dict, then delegates to ``PyroModule.__getattr__``,
        which triggers the pyro.sample statement.
        """
        assert issubclass(obj_type, PyroModule)
        if obj is None:
            # Accessed on the class rather than an instance.
            return self
        if self.name is None:
            # Discover the attribute name by scanning the owning class.
            for name in dir(obj_type):
                if getattr(obj_type, name) is self:
                    # NOTE(review): plain attribute assignment on a frozen
                    # dataclass normally raises FrozenInstanceError
                    # (``__post_init__`` uses ``object.__setattr__`` instead)
                    # — confirm this branch is actually exercised.
                    self.name = name
                    break
        else:
            # Decorator form: expose the renamed prior on the class so the
            # function remains importable/picklable by reference.
            setattr(obj_type, self.prior.__name__, self.prior)  # for pickling
        obj.__dict__["_pyro_samples"].setdefault(self.name, self.prior)
        assert self.name is not None
        value: PyroSample = obj.__getattr__(self.name)
        return value
def _make_name(prefix: str, name: str) -> str:
return "{}.{}".format(prefix, name) if prefix else name
def _unconstrain(
constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]],
constraint: constraints.Constraint,
) -> torch.nn.Parameter:
with torch.no_grad():
if callable(constrained_value):
constrained_value = constrained_value()
unconstrained_value = transform_to(constraint).inv(constrained_value.detach())
return torch.nn.Parameter(unconstrained_value)
class _Context:
    """
    Sometimes-active cache for ``PyroModule.__call__()`` contexts.

    One instance is shared by a whole PyroModule tree; ``active`` counts
    reentrant ``__enter__`` calls and the sample cache is flushed when the
    outermost frame exits.
    """

    def __init__(self) -> None:
        self.active = 0  # reentrancy depth
        self.cache: Dict[str, torch.Tensor] = {}
        self.used = False  # becomes True once any frame has entered
        if _is_module_local_param_enabled():
            # Per-module param store state for module-local-params mode.
            self.param_state: "StateDict" = {"params": {}, "constraints": {}}

    def __enter__(self) -> None:
        if not self.active and _is_module_local_param_enabled():
            # Outermost frame: scope the global param store to local state.
            self._param_ctx = pyro.get_param_store().scope(state=self.param_state)
            self.param_state = self._param_ctx.__enter__()
        self.active += 1
        self.used = True

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.active -= 1
        if self.active:
            return  # still inside a nested frame; keep the cache
        self.cache.clear()
        if _is_module_local_param_enabled():
            self._param_ctx.__exit__(type, value, traceback)
            del self._param_ctx

    def get(self, name: str) -> Optional[torch.Tensor]:
        # Cached sample values are visible only while a frame is active.
        return self.cache.get(name) if self.active else None

    def set(self, name: str, value: torch.Tensor) -> None:
        if self.active:
            self.cache[name] = value
def _get_pyro_params(
module: torch.nn.Module,
) -> Iterator[Tuple[str, Optional[torch.nn.Parameter]]]:
for name in module._parameters:
if name.endswith("_unconstrained"):
constrained_name = name[: -len("_unconstrained")]
if (
isinstance(module, PyroModule)
and constrained_name in module._pyro_params
):
yield constrained_name, getattr(module, constrained_name)
continue
yield name, module._parameters[name]
class _PyroModuleMeta(type):
    # Cache mapping an nn.Module subclass to its dynamically created
    # PyroModule mixin, so that PyroModule[M] always returns the same class
    # (required for isinstance checks and pickling).
    _pyro_mixin_cache: Dict[Type[torch.nn.Module], Type["PyroModule"]] = {}

    # Unpickling helper to create an empty object of type PyroModule[Module].
    class _New:
        def __init__(self, Module):
            self.__class__ = PyroModule[Module]

    def __getitem__(cls, Module: Type[torch.nn.Module]) -> Type["PyroModule"]:
        # Implements the PyroModule[Module] bracket syntax: dynamically create
        # (and cache) a subclass mixing PyroModule into Module.
        assert isinstance(Module, type)
        assert issubclass(Module, torch.nn.Module)
        if issubclass(Module, PyroModule):
            # Already Pyro-aware; nothing to mix in.
            return Module
        if Module is torch.nn.Module:
            return PyroModule
        if Module in _PyroModuleMeta._pyro_mixin_cache:
            return _PyroModuleMeta._pyro_mixin_cache[Module]
        # Recursively mix PyroModule into every nn.Module base of Module.
        bases = [
            PyroModule[b] for b in Module.__bases__ if issubclass(b, torch.nn.Module)
        ]

        class result(Module, *bases):  # type: ignore[valid-type, misc]
            # Unpickling helper to load an object of type PyroModule[Module].
            def __reduce__(self):
                state = getattr(self, "__getstate__", self.__dict__.copy)()
                return _PyroModuleMeta._New, (Module,), state

        result.__name__ = "Pyro" + Module.__name__
        _PyroModuleMeta._pyro_mixin_cache[Module] = result
        return result
class PyroModule(torch.nn.Module, metaclass=_PyroModuleMeta):
    """
    Subclass of :class:`torch.nn.Module` whose attributes can be modified by
    Pyro effects. Attributes can be set using helpers :class:`PyroParam` and
    :class:`PyroSample` , and methods can be decorated by :func:`pyro_method` .

    **Parameters**

    To create a Pyro-managed parameter attribute, set that attribute using
    either :class:`torch.nn.Parameter` (for unconstrained parameters) or
    :class:`PyroParam` (for constrained parameters). Reading that attribute
    will then trigger a :func:`pyro.param <pyro.primitives.param>` statement.
    For example::

        # Create Pyro-managed parameter attributes.
        my_module = PyroModule()
        my_module.loc = nn.Parameter(torch.tensor(0.))
        my_module.scale = PyroParam(torch.tensor(1.),
                                    constraint=constraints.positive)
        # Read the attributes.
        loc = my_module.loc  # Triggers a pyro.param statement.
        scale = my_module.scale  # Triggers another pyro.param statement.

    Note that, unlike normal :class:`torch.nn.Module` s, :class:`PyroModule` s
    should not be registered with :func:`pyro.module <pyro.primitives.module>`
    statements. :class:`PyroModule` s can contain other :class:`PyroModule` s
    and normal :class:`torch.nn.Module` s. Accessing a normal
    :class:`torch.nn.Module` attribute of a :class:`PyroModule` triggers a
    :func:`pyro.module <pyro.primitives.module>` statement. If multiple
    :class:`PyroModule` s appear in a single Pyro model or guide, they should
    be included in a single root :class:`PyroModule` for that model.

    :class:`PyroModule` s synchronize data with the param store at each
    ``setattr``, ``getattr``, and ``delattr`` event, based on the nested name
    of an attribute:

    -   Setting ``mod.x = x_init`` tries to read ``x`` from the param store. If a
        value is found in the param store, that value is copied into ``mod``
        and ``x_init`` is ignored; otherwise ``x_init`` is copied into both
        ``mod`` and the param store.
    -   Reading ``mod.x`` tries to read ``x`` from the param store. If a
        value is found in the param store, that value is copied into ``mod``;
        otherwise ``mod``'s value is copied into the param store. Finally
        ``mod`` and the param store agree on a single value to return.
    -   Deleting ``del mod.x`` removes a value from both ``mod`` and the param
        store.

    Note two :class:`PyroModule` of the same name will both synchronize with
    the global param store and thus contain the same data. When creating a
    :class:`PyroModule`, then deleting it, then creating another with the same
    name, the latter will be populated with the former's data from the param
    store. To avoid this persistence, either ``pyro.clear_param_store()`` or
    call :func:`clear` before deleting a :class:`PyroModule` .

    :class:`PyroModule` s can be saved and loaded either directly using
    :func:`torch.save` / :func:`torch.load` or indirectly using the param
    store's :meth:`~pyro.params.param_store.ParamStoreDict.save` /
    :meth:`~pyro.params.param_store.ParamStoreDict.load` . Note that
    :func:`torch.load` will be overridden by any values in the param store, so
    it is safest to ``pyro.clear_param_store()`` before loading.

    **Samples**

    To create a Pyro-managed random attribute, set that attribute using the
    :class:`PyroSample` helper, specifying a prior distribution. Reading that
    attribute will then trigger a :func:`pyro.sample <pyro.primitives.sample>`
    statement. For example::

        # Create Pyro-managed random attributes.
        my_module.x = PyroSample(dist.Normal(0, 1))
        my_module.y = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

        # Sample the attributes.
        x = my_module.x  # Triggers a pyro.sample statement.
        y = my_module.y  # Triggers one pyro.sample + two pyro.param statements.

    Sampling is cached within each invocation of ``.__call__()`` or method
    decorated by :func:`pyro_method` . Because sample statements can appear
    only once in a Pyro trace, you should ensure that traced access to sample
    attributes is wrapped in a single invocation of ``.__call__()`` or method
    decorated by :func:`pyro_method` .

    To make an existing module probabilistic, you can create a subclass and
    overwrite some parameters with :class:`PyroSample` s::

        class RandomLinear(nn.Linear, PyroModule):  # used as a mixin
            def __init__(self, in_features, out_features):
                super().__init__(in_features, out_features)
                self.weight = PyroSample(
                    lambda self: dist.Normal(0, 1)
                                     .expand([self.out_features,
                                              self.in_features])
                                     .to_event(2))

    **Mixin classes**

    :class:`PyroModule` can be used as a mixin class, and supports simple
    syntax for dynamically creating mixins, for example the following are
    equivalent::

        # Version 1. create a named mixin class
        class PyroLinear(nn.Linear, PyroModule):
            pass

        m.linear = PyroLinear(m, n)

        # Version 2. create a dynamic mixin class
        m.linear = PyroModule[nn.Linear](m, n)

    This notation can be used recursively to create Bayesian modules, e.g.::

        model = PyroModule[nn.Sequential](
            PyroModule[nn.Linear](28 * 28, 100),
            PyroModule[nn.Sigmoid](),
            PyroModule[nn.Linear](100, 100),
            PyroModule[nn.Sigmoid](),
            PyroModule[nn.Linear](100, 10),
        )
        assert isinstance(model, nn.Sequential)
        assert isinstance(model, PyroModule)

        # Now we can be Bayesian about weights in the first layer.
        model[0].weight = PyroSample(
            prior=dist.Normal(0, 1).expand([28 * 28, 100]).to_event(2))
        guide = AutoDiagonalNormal(model)

    Note that ``PyroModule[...]`` does not recursively mix in
    :class:`PyroModule` to submodules of the input ``Module``; hence we needed
    to wrap each submodule of the ``nn.Sequential`` above.

    :param str name: Optional name for a root PyroModule. This is ignored in
        sub-PyroModules of another PyroModule.
    """

    def __init__(self, name: str = "") -> None:
        # Set Pyro bookkeeping attributes before nn.Module.__init__ so the
        # __setattr__/__getattr__ overrides below can rely on them existing.
        self._pyro_name = name
        self._pyro_context = _Context()  # shared among sub-PyroModules
        self._pyro_params: OrderedDict[
            str, Tuple[constraints.Constraint, Optional[int]]
        ] = OrderedDict()
        self._pyro_samples: OrderedDict[str, PyroSample] = OrderedDict()
        super().__init__()

    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
        """
        Adds a child module to the current module.
        """
        if isinstance(module, PyroModule):
            # Propagate this module's name prefix and shared context into the
            # child so its params/samples get correctly prefixed full names.
            module._pyro_set_supermodule(
                _make_name(self._pyro_name, name), self._pyro_context
            )
        super().add_module(name, module)

    def named_pyro_params(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over PyroModule parameters, yielding both the
        name of the parameter as well as the parameter itself.

        :param str prefix: prefix to prepend to all parameter names.
        :param bool recurse: if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.
        :returns: a generator which yields tuples containing the name and parameter
        """
        gen = self._named_members(_get_pyro_params, prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem

    def _pyro_set_supermodule(self, name: str, context: _Context) -> None:
        # Adopt the supermodule's name prefix and context, recursively, for
        # this module and all sub-PyroModules.
        if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"):
            self._check_module_local_param_usage()
        self._pyro_name = name
        self._pyro_context = context
        for key, value in self._modules.items():
            if isinstance(value, PyroModule):
                # A submodule that has already run would have registered
                # params under its old (wrong) names.
                assert (
                    not value._pyro_context.used
                ), "submodule {} has executed outside of supermodule".format(name)
                value._pyro_set_supermodule(_make_name(name, key), context)

    def _pyro_get_fullname(self, name: str) -> str:
        # Full dotted name is only meaningful after the context has been used
        # (i.e. after __call__ has run at least once or naming is settled).
        assert self.__dict__["_pyro_context"].used, "fullname is not yet defined"
        return _make_name(self.__dict__["_pyro_name"], name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Run forward under the shared context so sample caching and param
        # synchronization are active for the whole invocation.
        with self._pyro_context:
            result = super().__call__(*args, **kwargs)
        if (
            pyro.settings.get("validate_poutine")
            and not self._pyro_context.active
            and _is_module_local_param_enabled()
        ):
            self._check_module_local_param_usage()
        return result

    def _check_module_local_param_usage(self) -> None:
        # In module-local-params mode, every param in the local store must be
        # owned by this module tree; global pyro.param statements leak params
        # that are not among self.parameters().
        self_nn_params = set(id(p) for p in self.parameters())
        self_pyro_params = set(
            id(p if not hasattr(p, "unconstrained") else p.unconstrained())
            for p in self._pyro_context.param_state["params"].values()
        )
        if not self_pyro_params <= self_nn_params:
            raise NotImplementedError(
                "Support for global pyro.param statements in PyroModules "
                "with local param mode enabled is not yet implemented."
            )

    def __getattr__(self, name: str) -> Any:
        # Note: nn.Module.__getattr__ is only invoked after normal attribute
        # lookup fails, so this only sees _parameters/_buffers/_modules
        # entries and the Pyro-managed attributes handled below.

        # PyroParams trigger pyro.param statements.
        if "_pyro_params" in self.__dict__:
            _pyro_params = self.__dict__["_pyro_params"]
            if name in _pyro_params:
                constraint, event_dim = _pyro_params[name]
                unconstrained_value = getattr(self, name + "_unconstrained")
                if self._pyro_context.active and not _is_module_local_param_enabled():
                    fullname = self._pyro_get_fullname(name)
                    if fullname in _PYRO_PARAM_STORE:
                        if (
                            _PYRO_PARAM_STORE._params[fullname]
                            is not unconstrained_value
                        ):
                            # Update PyroModule <--- ParamStore.
                            unconstrained_value = _PYRO_PARAM_STORE._params[fullname]
                            if not isinstance(unconstrained_value, torch.nn.Parameter):
                                # Update PyroModule ---> ParamStore (type only; data is preserved).
                                unconstrained_value = torch.nn.Parameter(
                                    unconstrained_value
                                )
                                _PYRO_PARAM_STORE._params[fullname] = (
                                    unconstrained_value
                                )
                                _PYRO_PARAM_STORE._param_to_name[
                                    unconstrained_value
                                ] = fullname
                            # Bypass our own __setattr__ to avoid re-syncing.
                            super().__setattr__(
                                name + "_unconstrained", unconstrained_value
                            )
                    else:
                        # Update PyroModule ---> ParamStore.
                        _PYRO_PARAM_STORE._constraints[fullname] = constraint
                        _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                        _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
                    return pyro.param(fullname, event_dim=event_dim)
                elif self._pyro_context.active and _is_module_local_param_enabled():
                    # fake param statement to ensure any handlers of pyro.param are applied,
                    # even though we don't use the contents of the local parameter store
                    fullname = self._pyro_get_fullname(name)
                    constrained_value = transform_to(constraint)(unconstrained_value)
                    constrained_value.unconstrained = weakref.ref(unconstrained_value)
                    return pyro.poutine.runtime.effectful(type="param")(
                        lambda *_, **__: constrained_value
                    )(
                        fullname,
                        constraint=constraint,
                        event_dim=event_dim,
                        name=fullname,
                    )
                else:  # Cannot determine supermodule and hence cannot compute fullname.
                    constrained_value = transform_to(constraint)(unconstrained_value)
                    constrained_value.unconstrained = weakref.ref(unconstrained_value)
                    return constrained_value

        # PyroSample trigger pyro.sample statements.
        if "_pyro_samples" in self.__dict__:
            _pyro_samples = self.__dict__["_pyro_samples"]
            if name in _pyro_samples:
                prior = _pyro_samples[name]
                context = self._pyro_context
                if context.active:
                    fullname = self._pyro_get_fullname(name)
                    # Cache within one __call__ so each sample site appears
                    # at most once per trace.
                    value = context.get(fullname)
                    if value is None:
                        if not hasattr(prior, "sample"):  # if not a distribution
                            prior = prior(self)
                        value = (
                            pyro.deterministic(fullname, prior)
                            if isinstance(prior, torch.Tensor)
                            else pyro.sample(fullname, prior)
                        )
                        context.set(fullname, value)
                    return value
                else:  # Cannot determine supermodule and hence cannot compute fullname.
                    if not hasattr(prior, "sample"):  # if not a distribution
                        prior = prior(self)
                    return prior if isinstance(prior, torch.Tensor) else prior()

        result = super().__getattr__(name)

        # Regular nn.Parameters trigger pyro.param statements.
        if isinstance(result, torch.nn.Parameter) and not name.endswith(
            "_unconstrained"
        ):
            if self._pyro_context.active and not _is_module_local_param_enabled():
                pyro.param(self._pyro_get_fullname(name), result)
            elif self._pyro_context.active and _is_module_local_param_enabled():
                # fake param statement to ensure any handlers of pyro.param are applied,
                # even though we don't use the contents of the local parameter store
                fullname = self._pyro_get_fullname(name)
                pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
                    fullname, result, constraint=constraints.real, name=fullname
                )

        if isinstance(result, torch.nn.Module):
            if isinstance(result, PyroModule):
                if not result._pyro_name:
                    # Update sub-PyroModules that were converted from nn.Modules in-place.
                    result._pyro_set_supermodule(
                        _make_name(self._pyro_name, name), self._pyro_context
                    )
            else:
                # Regular nn.Modules trigger pyro.module statements.
                if self._pyro_context.active and not _is_module_local_param_enabled():
                    pyro.module(self._pyro_get_fullname(name), result)
                elif self._pyro_context.active and _is_module_local_param_enabled():
                    # fake module statement to ensure any handlers of pyro.module are applied,
                    # even though we don't use the contents of the local parameter store
                    fullname_module = self._pyro_get_fullname(name)
                    for param_name, param_value in result.named_parameters():
                        fullname_param = pyro.params.param_store.param_with_module_name(
                            fullname_module, param_name
                        )
                        pyro.poutine.runtime.effectful(type="param")(
                            lambda *_, **__: param_value
                        )(
                            fullname_param,
                            param_value,
                            constraint=constraints.real,
                            name=fullname_param,
                        )

        return result

    def __setattr__(
        self,
        name: str,
        value: Any,
    ) -> None:
        if isinstance(value, PyroModule):
            # Create a new sub PyroModule, overwriting any old value.
            try:
                delattr(self, name)
            except AttributeError:
                pass
            self.add_module(name, value)
            return

        if isinstance(value, PyroParam):
            # Create a new PyroParam, overwriting any old value.
            try:
                delattr(self, name)
            except AttributeError:
                pass
            constrained_value, constraint, event_dim = value
            assert constrained_value is not None
            self._pyro_params[name] = constraint, event_dim
            if self._pyro_context.active and not _is_module_local_param_enabled():
                fullname = self._pyro_get_fullname(name)
                # pyro.param returns the store's value if one already exists,
                # so an existing store entry wins over constrained_value.
                pyro.param(
                    fullname,
                    constrained_value,
                    constraint=constraint,
                    event_dim=event_dim,
                )
                constrained_value = detach_provenance(pyro.param(fullname))
                unconstrained_value: torch.Tensor = constrained_value.unconstrained()  # type: ignore[attr-defined]
                if not isinstance(unconstrained_value, torch.nn.Parameter):
                    # Update PyroModule ---> ParamStore (type only; data is preserved).
                    unconstrained_value = torch.nn.Parameter(unconstrained_value)
                    _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                    _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
            elif self._pyro_context.active and _is_module_local_param_enabled():
                # fake param statement to ensure any handlers of pyro.param are applied,
                # even though we don't use the contents of the local parameter store
                fullname = self._pyro_get_fullname(name)
                constrained_value = detach_provenance(
                    pyro.poutine.runtime.effectful(type="param")(
                        lambda *_, **__: (
                            constrained_value()
                            if callable(constrained_value)
                            else constrained_value
                        )
                    )(
                        fullname,
                        constraint=constraint,
                        event_dim=event_dim,
                        name=fullname,
                    )
                )
                unconstrained_value = _unconstrain(constrained_value, constraint)
            else:  # Cannot determine supermodule and hence cannot compute fullname.
                unconstrained_value = _unconstrain(constrained_value, constraint)
            # Store the unconstrained tensor; reading the constrained name
            # goes through __getattr__ and re-applies the transform.
            super().__setattr__(name + "_unconstrained", unconstrained_value)
            return

        if isinstance(value, torch.nn.Parameter):
            # Create a new nn.Parameter, overwriting any old value.
            try:
                delattr(self, name)
            except AttributeError:
                pass
            if self._pyro_context.active and not _is_module_local_param_enabled():
                fullname = self._pyro_get_fullname(name)
                value = pyro.param(fullname, value)
                if not isinstance(value, torch.nn.Parameter):
                    # Update PyroModule ---> ParamStore (type only; data is preserved).
                    value = torch.nn.Parameter(detach_provenance(value))
                    _PYRO_PARAM_STORE._params[fullname] = value
                    _PYRO_PARAM_STORE._param_to_name[value] = fullname
            elif self._pyro_context.active and _is_module_local_param_enabled():
                # fake param statement to ensure any handlers of pyro.param are applied,
                # even though we don't use the contents of the local parameter store
                fullname = self._pyro_get_fullname(name)
                value = detach_provenance(
                    pyro.poutine.runtime.effectful(type="param")(
                        lambda *_, **__: value
                    )(fullname, value, constraint=constraints.real, name=fullname)
                )
            super().__setattr__(name, value)
            return

        if isinstance(value, torch.Tensor):
            if name in self._pyro_params:
                # Update value of an existing PyroParam.
                constraint, event_dim = self._pyro_params[name]
                unconstrained_value = getattr(self, name + "_unconstrained")
                with torch.no_grad():
                    # In-place .data update preserves the Parameter object
                    # (and hence any optimizer references to it).
                    unconstrained_value.data = transform_to(constraint).inv(
                        value.detach()
                    )
                return
            # Other tensors fall through to the default behavior below.

        if isinstance(value, PyroSample):
            # Create a new PyroSample, overwriting any old value.
            try:
                delattr(self, name)
            except AttributeError:
                pass
            _pyro_samples = self.__dict__["_pyro_samples"]
            _pyro_samples[name] = value.prior
            return

        super().__setattr__(name, value)

    def __delattr__(self, name: str) -> None:
        if name in self._parameters:
            del self._parameters[name]
            if self._pyro_context.used:
                fullname = self._pyro_get_fullname(name)
                if fullname in _PYRO_PARAM_STORE:
                    # Update PyroModule ---> ParamStore.
                    del _PYRO_PARAM_STORE[fullname]
            return

        if name in self._pyro_params:
            # Deleting a PyroParam also deletes its *_unconstrained Parameter.
            delattr(self, name + "_unconstrained")
            del self._pyro_params[name]
            if self._pyro_context.used:
                fullname = self._pyro_get_fullname(name)
                if fullname in _PYRO_PARAM_STORE:
                    # Update PyroModule ---> ParamStore.
                    del _PYRO_PARAM_STORE[fullname]
            return

        if name in self._pyro_samples:
            del self._pyro_samples[name]
            return

        if name in self._modules:
            del self._modules[name]
            if self._pyro_context.used:
                # Remove all store entries prefixed by the submodule's name.
                fullname = self._pyro_get_fullname(name)
                for p in list(_PYRO_PARAM_STORE.keys()):
                    if p.startswith(fullname):
                        del _PYRO_PARAM_STORE[p]
            return

        super().__delattr__(name)

    def __getstate__(self) -> Dict[str, Any]:
        # Remove weakrefs in preparation for pickling.
        for param in self.parameters(recurse=True):
            param.__dict__.pop("unconstrained", None)
        return getattr(super(), "__getstate__", self.__dict__.copy)()
def pyro_method(
    fn: Callable[Concatenate[_PyroModule, _P], _T]
) -> Callable[Concatenate[_PyroModule, _P], _T]:
    """
    Decorator for top-level methods of a :class:`PyroModule` to enable pyro
    effects and cache ``pyro.sample`` statements.

    This should be applied to all public methods that read Pyro-managed
    attributes, but is not needed for ``.forward()``.
    """

    @functools.wraps(fn)
    def _wrapped(self: _PyroModule, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        # Entering the shared context activates sample caching and param
        # synchronization for the duration of the call.
        with self._pyro_context:
            return fn(self, *args, **kwargs)

    return _wrapped
def clear(mod: PyroModule) -> None:
    """
    Removes data from both a :class:`PyroModule` and the param store.

    :param PyroModule mod: A module to clear.
    """
    assert isinstance(mod, PyroModule)
    # Order matters: PyroParams first (each also removes its paired
    # *_unconstrained nn.Parameter), then remaining parameters, then
    # submodules. delattr on a PyroModule also purges the param store.
    for attr_group in (mod._pyro_params, mod._parameters, mod._modules):
        for name in list(attr_group):
            delattr(mod, name)
def to_pyro_module_(m: torch.nn.Module, recurse: bool = True) -> None:
    """
    Converts an ordinary :class:`torch.nn.Module` instance to a
    :class:`PyroModule` **in-place**.

    This is useful for adding Pyro effects to third-party modules: no
    third-party code needs to be modified. For example::

        model = nn.Sequential(
            nn.Linear(28 * 28, 100),
            nn.Sigmoid(),
            nn.Linear(100, 100),
            nn.Sigmoid(),
            nn.Linear(100, 10),
        )
        to_pyro_module_(model)
        assert isinstance(model, PyroModule[nn.Sequential])
        assert isinstance(model[0], PyroModule[nn.Linear])

        # Now we can attempt to be fully Bayesian:
        for m in model.modules():
            for name, value in list(m.named_parameters(recurse=False)):
                setattr(m, name, PyroSample(prior=dist.Normal(0, 1)
                                                .expand(value.shape)
                                                .to_event(value.dim())))
        guide = AutoDiagonalNormal(model)

    :param torch.nn.Module m: A module instance.
    :param bool recurse: Whether to convert submodules to :class:`PyroModules` .
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Expected an nn.Module instance but got a {}".format(type(m)))

    if isinstance(m, PyroModule):
        # Already Pyro-aware: optionally convert children, then re-assign
        # them so they adopt this module's name prefix and context.
        if recurse:
            for child_name, child in list(m._modules.items()):
                if TYPE_CHECKING:
                    assert child is not None
                to_pyro_module_(child)
                setattr(m, child_name, child)
        return

    # Change m's type in-place.
    m.__class__ = PyroModule[m.__class__]
    assert isinstance(m, PyroModule)
    m._pyro_name = ""
    m._pyro_context = _Context()
    m._pyro_params = OrderedDict()
    m._pyro_samples = OrderedDict()

    # Reregister parameters and submodules so they pass through
    # PyroModule.__setattr__ and become Pyro-managed.
    for param_name, param in list(m._parameters.items()):
        setattr(m, param_name, param)
    for child_name, child in list(m._modules.items()):
        if recurse:
            if TYPE_CHECKING:
                assert child is not None
            to_pyro_module_(child)
        setattr(m, child_name, child)
# The following descriptor disables the ._flat_weights cache of
# torch.nn.RNNBase, forcing recomputation on each access of the ._flat_weights
# attribute. This is required if any attribute is set to a PyroParam or
# PyroSample. For motivation, see https://github.com/pyro-ppl/pyro/issues/2390
class _FlatWeightsDescriptor:
    # A data descriptor (defines both __get__ and __set__), so it takes
    # precedence over any instance-level _flat_weights attribute.
    def __get__(
        self,
        obj: Optional[torch.nn.RNNBase],
        obj_type: Optional[Type[torch.nn.RNNBase]] = None,
    ) -> Union["_FlatWeightsDescriptor", List]:
        if obj is None:
            return self
        # Recompute the flat weights list on every access instead of caching.
        return [getattr(obj, name) for name in obj._flat_weights_names]

    def __set__(self, obj: object, value: Any) -> None:
        pass  # Ignore value.


# Install the descriptor on the dynamically created PyroModule[RNNBase] mixin.
PyroModule[torch.nn.RNNBase]._flat_weights = _FlatWeightsDescriptor()  # type: ignore[attr-defined]
# pyro module list
# Using pyro.nn.PyroModule[torch.nn.ModuleList] can cause issues when
# slice-indexing nested PyroModuleLists, so we define a separate PyroModuleList
# class whose __getitem__ returns a plain torch.nn.ModuleList for slices. It
# deliberately avoids self.__class__ in __getitem__, since that would call
# PyroModule.__init__ without the parent module context, losing the parent
# module's _pyro_name and eventually causing errors during sampling because
# parameter names may no longer be unique.
# The scenario is rare but has happened.
# The fix could not be applied in torch directly, which is why we have to deal
# with it here; see https://github.com/pytorch/pytorch/issues/121008
class PyroModuleList(torch.nn.ModuleList, PyroModule):
    """
    Like :class:`torch.nn.ModuleList`, but slice indexing returns a plain
    :class:`torch.nn.ModuleList` rather than an instance of
    ``self.__class__``, to avoid re-running ``PyroModule.__init__`` without
    the parent module context (which would lose the parent's ``_pyro_name``).
    """

    def __init__(self, modules):
        super().__init__(modules)

    @_copy_to_script_wrapper
    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[torch.nn.Module, "PyroModuleList"]:
        if isinstance(idx, slice):
            # return self.__class__(list(self._modules.values())[idx])
            # Deliberately NOT self.__class__ here; see the comment block
            # preceding this class.
            return torch.nn.ModuleList(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]


# Pre-register so PyroModule[torch.nn.ModuleList] resolves to PyroModuleList.
_PyroModuleMeta._pyro_mixin_cache[torch.nn.ModuleList] = PyroModuleList