Source code for pyro.distributions.constraints

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from torch.distributions.constraints import *  # noqa F403

# isort: split

import torch
from torch.distributions.constraints import Constraint
from torch.distributions.constraints import __all__ as torch_constraints
from torch.distributions.constraints import (
    independent,
    lower_cholesky,
    positive,
    positive_definite,
)


# TODO move this upstream to torch.distributions
class _Integer(Constraint):
    """
    Constrain to integers.

    A value passes when its remainder modulo one equals zero.
    """

    is_discrete = True

    def check(self, value):
        """Return the comparison ``value % 1 == 0`` (elementwise for tensors)."""
        fractional_part = value % 1
        return fractional_part == 0

    def __repr__(self):
        # Report the class name without its leading underscore.
        return type(self).__name__[1:]
[docs]class _Sphere(Constraint): """ Constrain to the Euclidean sphere of any dimension. """ event_dim = 1 reltol = 10. # Relative to finfo.eps. def check(self, value): eps = torch.finfo(value.dtype).eps norm = torch.linalg.norm(value, dim=-1) error = (norm - 1).abs() return error < self.reltol * eps * value.size(-1) ** 0.5 def __repr__(self): return self.__class__.__name__[1:]
class _CorrMatrix(Constraint):
    """
    Constrains to a correlation matrix.

    A matrix (taken over the last two axes) passes when every diagonal entry
    is within ``1e-6`` of one and the matrix satisfies the
    ``positive_definite`` constraint.
    """

    event_dim = 2

    def check(self, value):
        """Return the conjunction of the unit-diagonal and positive-definite checks."""
        # check for diagonal equal to 1
        diag = torch.diagonal(value, dim1=-2, dim2=-1)
        diag_error = torch.abs(diag - 1)
        unit_variance = torch.all(diag_error < 1e-6, dim=-1)
        # TODO: fix upstream - positive_definite has an extra dimension in front of output shape
        is_positive_definite = positive_definite.check(value)
        return is_positive_definite & unit_variance
class _OrderedVector(Constraint):
    """
    Constrains to a real-valued tensor where the elements are monotonically
    increasing along the `event_shape` dimension.

    The comparison is strict: each element must be greater than the one
    before it along the last axis.
    """

    event_dim = 1

    def check(self, value):
        """Return a mask over the leading dimensions marking ordered vectors."""
        # A zero-dimensional tensor has no axis to be ordered along.
        if value.ndim == 0:
            return torch.tensor(False, device=value.device)
        # A single-element vector is reported as ordered.
        if value.shape[-1] == 1:
            return torch.ones_like(value[..., 0], dtype=bool)
        earlier, later = value[..., :-1], value[..., 1:]
        return torch.all(later > earlier, dim=-1)
class _PositiveOrderedVector(Constraint):
    """
    Constrains to a positive real-valued tensor where the elements are monotonically
    increasing along the `event_shape` dimension.
    """

    # NOTE(review): unlike _OrderedVector, no ``event_dim`` is declared on this
    # class - confirm whether that is intentional.

    def check(self, value):
        """Return the conjunction of the ordered-vector and positivity checks."""
        is_ordered = ordered_vector.check(value)
        # Positivity is reduced over the last axis so both masks share a shape.
        is_positive = independent(positive, 1).check(value)
        return is_ordered & is_positive
class _SoftplusPositive(type(positive)):
    """
    Subclass of the ``positive`` constraint's type with a lower bound of zero.

    NOTE(review): presumably this exists as a distinct type so that a
    softplus-based transform can be registered for it; the registration is
    not in this module - confirm.
    """

    def __init__(self):
        super(_SoftplusPositive, self).__init__(lower_bound=0.0)
class _SoftplusLowerCholesky(type(lower_cholesky)):
    """
    Subclass of the ``lower_cholesky`` constraint's type that adds no behavior.

    NOTE(review): presumably this exists as a distinct type so that a
    softplus-based transform can be registered for it; the registration is
    not in this module - confirm.
    """
# Singleton constraint instances exported by this module.
corr_matrix = _CorrMatrix()
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
sphere = _Sphere()
softplus_positive = _SoftplusPositive()
softplus_lower_cholesky = _SoftplusLowerCholesky()
corr_cholesky_constraint = corr_cholesky  # noqa: F405 DEPRECATED

__all__ = [
    'corr_cholesky_constraint',
    'corr_matrix',
    'integer',
    'ordered_vector',
    'positive_ordered_vector',
    'softplus_lower_cholesky',
    'softplus_positive',
    'sphere',
]

# Re-export everything torch's constraints module exports, de-duplicated
# and in sorted order.
__all__.extend(torch_constraints)
__all__ = sorted(set(__all__))
del torch_constraints


# Create sphinx documentation.
# The templates below are reStructuredText: a section title and its dash
# underline must sit on separate lines, and a blank line must separate the
# introductory paragraph from the first section, so the line breaks inside
# these string literals are significant.
__doc__ = """
Pyro's constraints library extends :mod:`torch.distributions.constraints`.
"""
__doc__ += "\n".join([
    """
{}
----------------------------------------------------------------
{}
""".format(
        _name,
        # Names that come from torch are documented as aliases; names defined
        # here get an autoclass directive for the class itself, or for the
        # class of the singleton instance.
        "alias of :class:`torch.distributions.constraints.{}`".format(_name)
        if globals()[_name].__module__.startswith("torch") else
        ".. autoclass:: {}".format(_name if type(globals()[_name]) is type else
                                   type(globals()[_name]).__name__)
    )
    for _name in sorted(__all__)
])