Source code for pyro.distributions.torch_transform
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
class TransformModule(torch.distributions.Transform, torch.nn.Module):
    """
    Base class for transforms that carry learnable parameters, such as
    normalizing flows.

    Inherit from this class rather than from plain `Transform` so that the
    transform is also a subclass of `nn.Module` and inherits all the useful
    methods of that class.
    """

    def __init__(self, *args, **kwargs):
        """Forward every constructor argument, unchanged, up the MRO."""
        super(TransformModule, self).__init__(*args, **kwargs)

    def __hash__(self):
        """
        Return the hash provided by whatever follows ``torch.nn.Module`` in
        the MRO, bypassing the ``__hash__`` entries of the classes before it.
        """
        # NOTE(review): presumably required because ``Transform`` defines
        # ``__eq__`` and therefore has no usable ``__hash__`` of its own --
        # confirm against the installed torch version.
        hash_past_module = super(torch.nn.Module, self).__hash__
        return hash_past_module()
class ComposeTransformModule(torch.distributions.ComposeTransform, torch.nn.ModuleList):
    """
    This allows us to use a list of `TransformModule` in the same way as
    :class:`~torch.distributions.transforms.ComposeTransform`. This is needed
    so that transform parameters are automatically registered by Pyro's param
    store when used in :class:`~pyro.nn.module.PyroModule` instances.

    :param parts: the transforms to compose. Every entry that is a
        :class:`torch.nn.Module` is additionally appended to this module list;
        other entries are only kept in ``self.parts``.
    :param int cache_size: forwarded to the parent constructor. Defaults to 0.
    """

    def __init__(self, parts, cache_size=0):
        super().__init__(parts, cache_size=cache_size)
        # Register from ``self.parts`` rather than from the ``parts`` argument:
        # the parent constructor may store different objects than it was given
        # (torch's ``ComposeTransform`` swaps in ``part.with_cache(cache_size)``
        # when caching is requested), and the registered submodules must be
        # the transforms that are actually applied.
        for part in self.parts:
            if isinstance(part, torch.nn.Module):
                self.append(part)

    def __hash__(self):
        """
        Return the hash provided by whatever follows ``torch.nn.Module`` in
        the MRO, bypassing the ``__hash__`` entries of the classes before it.
        """
        # NOTE(review): presumably required because ``Transform`` defines
        # ``__eq__`` and therefore has no usable ``__hash__`` of its own --
        # confirm against the installed torch version.
        return super(torch.nn.Module, self).__hash__()

    def with_cache(self, cache_size=1):
        """
        Return a composition of the same parts with the requested cache size.

        :param int cache_size: desired cache size. Defaults to 1.
        :returns: ``self`` when ``cache_size`` already equals the current
            cache size, otherwise a new :class:`ComposeTransformModule` built
            from ``self.parts``.
        """
        if cache_size == self._cache_size:
            return self
        return ComposeTransformModule(self.parts, cache_size=cache_size)