# Source code for pyro.distributions.transforms.normalize

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numbers

from torch.distributions.transforms import Transform

from pyro.ops.tensor_utils import safe_normalize

from .. import constraints


class Normalize(Transform):
    """
    Safely project a vector onto the sphere wrt the ``p`` norm. This avoids
    the singularity at zero by mapping to the vector ``[1, 0, 0, ..., 0]``.

    The transform maps :attr:`~constraints.real_vector` to
    :attr:`~constraints.sphere` and is declared non-bijective.

    :param p: Order of the norm, forwarded to
        :func:`~pyro.ops.tensor_utils.safe_normalize`. Must be a
        non-negative number. Defaults to 2.
    :param int cache_size: Cache size forwarded to the
        :class:`~torch.distributions.transforms.Transform` constructor.
        Defaults to 0.
    """

    domain = constraints.real_vector
    codomain = constraints.sphere
    bijective = False

    def __init__(self, p=2, cache_size=0):
        # Kept as asserts so the raised exception type (AssertionError) is
        # unchanged for existing callers; messages added for diagnosability.
        assert isinstance(p, numbers.Number), "p must be a number"
        assert p >= 0, "p must be non-negative"
        self.p = p
        super().__init__(cache_size=cache_size)

    def __eq__(self, other):
        """
        Two transforms are equal when they have exactly the same type and
        the same ``p``.
        """
        # Exact type match (not isinstance): a subclass instance is never
        # considered equal to a plain ``Normalize``.
        return type(self) is type(other) and self.p == other.p

    def _call(self, x):
        """Return ``x`` normalized via ``safe_normalize`` with ``p=self.p``."""
        return safe_normalize(x, p=self.p)

    def _inverse(self, y):
        """
        Return ``y`` unchanged.

        The transform is not bijective, so this is only a pseudo-inverse.
        """
        # NOTE(review): presumably relies on points already on the sphere
        # being fixed by ``_call`` -- confirm against ``safe_normalize``.
        return y

    def with_cache(self, cache_size=1):
        """
        Return a transform with the requested cache size.

        :param int cache_size: Desired cache size. Defaults to 1.
        :return: ``self`` if its cache size already equals ``cache_size``,
            otherwise a new :class:`Normalize` with the same ``p``.
        """
        if self._cache_size == cache_size:
            return self
        return Normalize(self.p, cache_size=cache_size)