```# Copyright (c) 2017-2019 Uber Technologies, Inc.

from torch.distributions.constraints import *  # noqa F403

# isort: split

import torch
from torch.distributions.constraints import Constraint
from torch.distributions.constraints import __all__ as torch_constraints
from torch.distributions.constraints import (
independent,
lower_cholesky,
positive,
positive_definite,
)

# TODO move this upstream to torch.distributions
class _Integer(Constraint):
    """
    Constrain to integers.

    ``check`` is elementwise: an element is valid exactly when its remainder
    modulo 1 is zero, i.e. it has no fractional part.
    """

    is_discrete = True

    def check(self, value):
        fractional_part = value % 1
        return fractional_part == 0

    def __repr__(self):
        # Drop the leading underscore of the private class name.
        return type(self).__name__[1:]

class _Sphere(Constraint):
    """
    Constrain to the Euclidean sphere of any dimension.

    A vector (the last dimension of ``value``) is accepted when its Euclidean
    norm differs from 1 by less than ``reltol * eps * sqrt(n)``, where ``eps``
    is the machine epsilon of ``value.dtype`` and ``n`` is ``value.size(-1)``.
    """

    event_dim = 1
    reltol = 10.0  # Tolerance expressed in units of torch.finfo(dtype).eps.

    def check(self, value):
        size = value.size(-1)
        tolerance = self.reltol * torch.finfo(value.dtype).eps * size**0.5
        distance_from_unit = (torch.linalg.norm(value, dim=-1) - 1).abs()
        return distance_from_unit < tolerance

    def __repr__(self):
        # Drop the leading underscore of the private class name.
        return type(self).__name__[1:]

class _CorrMatrix(Constraint):
    """
    Constrains to a correlation matrix.

    A matrix is accepted when ``positive_definite`` accepts it and every
    diagonal entry is within an absolute tolerance of ``1e-6`` of one.
    """

    event_dim = 2

    def check(self, value):
        diagonal = torch.diagonal(value, dim1=-2, dim2=-1)
        unit_variance = ((diagonal - 1).abs() < 1e-6).all(dim=-1)
        # NOTE(review): an earlier TODO reported that positive_definite.check
        # returned an extra leading dimension (fix upstream); confirm against
        # the installed torch version.
        return positive_definite.check(value) & unit_variance

class _OrderedVector(Constraint):
    """
    Constrains to a real-valued tensor where the elements are strictly
    increasing along the `event_shape` (last) dimension.

    ``check`` reduces over the last dimension and returns a boolean tensor of
    shape ``value.shape[:-1]``:

    * a 0-dimensional ``value`` has no event dimension and is rejected;
    * a length-1 vector is trivially ordered and is accepted;
    * otherwise every consecutive pair must satisfy ``value[..., i] < value[..., i + 1]``.
    """

    event_dim = 1

    def check(self, value):
        if value.ndim == 0:
            # Nothing to order along: a scalar is not a vector.
            return torch.tensor(False, device=value.device)
        if value.shape[-1] == 1:
            # A single element is vacuously ordered.
            return torch.ones_like(value[..., 0], dtype=torch.bool)
        return (value[..., 1:] > value[..., :-1]).all(dim=-1)

class _PositiveOrderedVector(Constraint):
    """
    Constrains to a positive real-valued tensor where the elements are
    increasing along the `event_shape` (last) dimension.

    Validity is the conjunction of ``ordered_vector`` and elementwise
    positivity, both reduced over the last dimension, so ``check`` returns a
    boolean tensor of shape ``value.shape[:-1]``.
    """

    # check() removes the rightmost dimension, so it must be declared as the
    # event dimension (like _OrderedVector); Constraint's default is 0.
    event_dim = 1

    def check(self, value):
        return ordered_vector.check(value) & independent(positive, 1).check(value)

class _SoftplusPositive(type(positive)):
    """
    Constraint to values strictly greater than ``0.0``.

    Shares its ``check`` with ``positive`` but is a distinct type, so it can
    carry its own registrations (presumably a softplus-based transform —
    confirm at the registration site).
    """

    def __init__(self):
        super().__init__(0.0)

class _SoftplusLowerCholesky(type(lower_cholesky)):
    """
    Same validity check as ``lower_cholesky``, exposed as a distinct type.

    It adds no behavior of its own; the separate type allows separate
    registrations (presumably a softplus-based transform — confirm at the
    registration site).
    """

class _UnitLowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with all ones diagonals.

    ``check`` reduces over the last two dimensions and returns a boolean
    tensor of shape ``value.shape[:-2]``. Squareness itself is not checked.
    """

    event_dim = 2

    def check(self, value):
        # Every entry strictly above the diagonal must be exactly zero.
        # flatten/all (rather than view(..., -1)/min) also handles 0x0 events.
        lower_triangular = (value.tril() == value).flatten(-2).all(dim=-1)
        # Every diagonal entry must be exactly one.
        ones_diagonal = (value.diagonal(dim1=-2, dim2=-1) == 1).all(dim=-1)
        return lower_triangular & ones_diagonal

# Module-level singleton instances of the constraint classes above.
corr_matrix = _CorrMatrix()
integer = _Integer()
ordered_vector = _OrderedVector()
positive_ordered_vector = _PositiveOrderedVector()
softplus_lower_cholesky = _SoftplusLowerCholesky()
softplus_positive = _SoftplusPositive()
sphere = _Sphere()
unit_lower_cholesky = _UnitLowerCholesky()

# Deprecated alias; corr_cholesky arrives via the star import from
# torch.distributions.constraints.
corr_cholesky_constraint = corr_cholesky  # noqa: F405 DEPRECATED

# Public API: Pyro's own constraints together with every name re-exported
# from torch.distributions.constraints, de-duplicated and sorted.
__all__ = sorted(
    {
        "corr_cholesky_constraint",
        "corr_matrix",
        "integer",
        "ordered_vector",
        "positive_ordered_vector",
        "softplus_lower_cholesky",
        "softplus_positive",
        "sphere",
        "unit_lower_cholesky",
    }.union(torch_constraints)
)
# torch's name list was only needed to build __all__; keep it out of the namespace.
del torch_constraints

# Create sphinx documentation.
__doc__ = """
Pyro's constraints library extends
:mod:`torch.distributions.constraints`.
"""


def _sphinx_section(name):
    # Build one reST section for an exported name: names whose object comes
    # from a torch module become an alias line, everything else an autoclass
    # directive (for instances, the directive targets the instance's class).
    obj = globals()[name]
    if obj.__module__.startswith("torch"):
        body = "alias of :class:`torch.distributions.constraints.{}`".format(name)
    else:
        target = name if type(obj) is type else type(obj).__name__
        body = ".. autoclass:: {}".format(target)
    return """
{}
----------------------------------------------------------------
{}
""".format(
        name, body
    )


__doc__ += "\n".join(_sphinx_section(_name) for _name in sorted(__all__))
# The helper is only needed while building __doc__.
del _sphinx_section
```