# Source code for pyro.distributions.constraints

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions.constraints import *  # noqa F403
from torch.distributions.constraints import Constraint
from torch.distributions.constraints import __all__ as torch_constraints
from torch.distributions.constraints import lower_cholesky


# TODO move this upstream to torch.distributions
class IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.

    :param torch.distributions.constraints.Constraint base_constraint: A base
        constraint whose entries are incidentally independent.
    :param int reinterpreted_batch_ndims: The number of extra event dimensions
        that will be considered dependent.
    """
    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def __repr__(self):
        # torch's default ``Constraint.__repr__`` drops the first character of
        # the class name (it assumes private ``_Name`` classes), which would
        # render this public class as "ndependentConstraint()".
        return "{}({!r}, {})".format(type(self).__name__,
                                     self.base_constraint,
                                     self.reinterpreted_batch_ndims)

    def check(self, value):
        """
        Returns ``base_constraint.check(value)`` reduced by a logical AND over
        its rightmost ``reinterpreted_batch_ndims`` dims.

        :param torch.Tensor value: The value to check.
        :returns: A boolean tensor with ``reinterpreted_batch_ndims`` fewer
            dims than the base constraint's check result.
        :raises ValueError: If the base check result has fewer than
            ``reinterpreted_batch_ndims`` dims, since there would not be
            enough dims to aggregate over.
        """
        result = self.base_constraint.check(value)
        ndims = self.reinterpreted_batch_ndims
        if ndims <= 0:
            # Nothing to aggregate over.
            return result
        if result.dim() < ndims:
            raise ValueError(
                "IndependentConstraint expected the base constraint check to "
                "have at least {} dims, but got {}".format(ndims, result.dim()))
        # flatten() (unlike reshape(..., -1)) is well defined for tensors with
        # zero elements, and all() is vacuously True over an empty event
        # (whereas min() raises on an empty reduction dim).
        return result.flatten(result.dim() - ndims).all(-1)
# TODO move this upstream to torch.distributions
class _Integer(Constraint):
    """
    Constrain to integers.
    """
    def check(self, value):
        return value % 1 == 0

    def __repr__(self):
        return self.__class__.__name__[1:]


class _CorrCholesky(Constraint):
    """
    Constrains to lower-triangular square matrices with positive diagonals and
    Euclidean norm of each row is 1, such that `torch.mm(omega, omega.t())` will
    have unit diagonal.
    """
    def check(self, value):
        # Every row's Euclidean norm must lie within 1e-4 of 1 on *both* sides;
        # without abs(), any row whose norm is below 1 would be accepted.
        row_norm = value.norm(dim=-1)
        unit_norm_row = ((row_norm - 1).abs() < 1e-4).all(-1)
        return lower_cholesky.check(value) & unit_norm_row


class _OrderedVector(Constraint):
    """
    Constrains to a real-valued tensor where the elements are monotonically
    increasing along the `event_shape` dimension.
    """
    def check(self, value):
        if value.ndim == 0:
            # A scalar has no event dimension to be ordered along.
            return torch.tensor(False, device=value.device)
        elif value.shape[-1] == 1:
            # A length-1 vector is trivially ordered.
            return torch.ones_like(value[..., 0], dtype=bool)
        else:
            return torch.all(value[..., 1:] > value[..., :-1], dim=-1)


corr_cholesky_constraint = _CorrCholesky()
integer = _Integer()
ordered_vector = _OrderedVector()

__all__ = [
    'IndependentConstraint',
    'corr_cholesky_constraint',
    'integer',
    'ordered_vector',
]

# Re-export everything torch's constraints module exports as well.
__all__.extend(torch_constraints)
__all__ = sorted(set(__all__))
del torch_constraints


# Create sphinx documentation.
__doc__ = """
    Pyro's constraints library extends
    :mod:`torch.distributions.constraints`.
"""
# Names whose object comes from a torch module are documented as aliases;
# everything else gets a full autoclass entry.
__doc__ += "\n".join([
    """
{}
----------------------------------------------------------------
{}
""".format(
        _name,
        "alias of :class:`torch.distributions.constraints.{}`".format(_name)
        if globals()[_name].__module__.startswith("torch") else
        ".. autoclass:: {}".format(_name)
    )
    for _name in sorted(__all__)
])