# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from pyro.distributions.constraints import IndependentConstraint
from pyro.distributions.torch_distribution import TorchDistributionMixin
# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
# and merely reshape the self.logits tensor. This is especially important for
# Pyro models that use enumeration.
class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
    # Tensors produced by ``enumerate_support(expand=False)`` are tagged with
    # ``id(self)``; ``log_prob`` recognises that tag and answers by
    # rearranging ``self.logits`` instead of calling the parent implementation.

    def log_prob(self, value):
        """
        Evaluate the log-probability of ``value``.

        When ``value`` carries a ``_pyro_categorical_support`` attribute equal
        to ``id(self)``, the result is a transposed/squeezed view of
        ``self.logits``. Any other ``value`` is forwarded to
        ``torch.distributions.Categorical.log_prob``.
        """
        if getattr(value, '_pyro_categorical_support', None) != id(self):
            return super(Categorical, self).log_prob(value)

        # Fast path: value is assumed to be a reshaped
        # torch.arange(event_shape[0]), so a reshape of the logits is enough
        # and torch.gather() is not needed.
        # Validation and sanity asserts are skipped while tracing.
        eager = not torch._C._get_tracing_state()
        if eager:
            if self._validate_args:
                self._validate_sample(value)
            assert value.size(0) == self.logits.size(-1)

        scores = self.logits
        value_ndim = value.dim()
        # Number of leading singleton dims to prepend so that scores has
        # exactly one more dim than value; positive iff
        # scores.dim() <= value.dim().
        n_pad = 1 + value_ndim - scores.dim()
        if n_pad > 0:
            scores = scores.reshape((1,) * n_pad + scores.shape)

        enum_dim = -1 - value_ndim
        if eager:
            assert scores.size(enum_dim) == 1
        # Swap the category axis (last) with the enumeration axis, then drop
        # the axis that ended up last.
        return scores.transpose(enum_dim, -1).squeeze(-1)

    def enumerate_support(self, expand=True):
        """
        Return the parent's enumerated support.

        When ``expand`` is false the returned tensor is additionally tagged
        with ``_pyro_categorical_support = id(self)`` so that ``log_prob`` can
        recognise it.
        """
        support = super(Categorical, self).enumerate_support(expand=expand)
        if expand:
            return support
        support._pyro_categorical_support = id(self)
        return support
class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin):
    # Real-valued constraint wrapped with one reinterpreted dimension.
    support = IndependentConstraint(constraints.real, 1)  # TODO move upstream

    # TODO: remove this in the PyTorch release > 1.4.0
    @lazy_property
    def precision_matrix(self):
        """
        Precision matrix, computed by a Cholesky solve against the identity.

        The solve uses ``self._unbroadcasted_scale_tril`` and the result is
        expanded to ``batch_shape + event_shape + event_shape``.
        """
        loc = self.loc
        eye = torch.eye(loc.size(-1), device=loc.device, dtype=loc.dtype)
        solved = torch.cholesky_solve(eye, self._unbroadcasted_scale_tril)
        target_shape = self._batch_shape + self._event_shape + self._event_shape
        return solved.expand(target_shape)
class Independent(torch.distributions.Independent, TorchDistributionMixin):
    @constraints.dependent_property
    def support(self):
        """Support of the base distribution wrapped in an ``IndependentConstraint``."""
        base_support = self.base_dist.support
        return IndependentConstraint(base_support, self.reinterpreted_batch_ndims)

    # The validation flag is not stored on this wrapper: reads and writes are
    # both forwarded to the wrapped base distribution.
    @property
    def _validate_args(self):
        """Return ``base_dist._validate_args``."""
        base = self.base_dist
        return base._validate_args

    @_validate_args.setter
    def _validate_args(self, value):
        base = self.base_dist
        base._validate_args = value
# Programmatically load all distributions from PyTorch.
#
# Every class in ``torch.distributions`` that is a strict subclass of
# ``Distribution`` gets a counterpart in this module mixing in
# ``TorchDistributionMixin``. A class already defined in this module under the
# same name is reused instead of being generated.
__all__ = []
for _name, _Dist in torch.distributions.__dict__.items():
    # Skip non-classes, non-distributions, and the abstract base itself.
    if not (isinstance(_Dist, type)
            and issubclass(_Dist, torch.distributions.Distribution)
            and _Dist is not torch.distributions.Distribution):
        continue

    if _name in locals():
        # Hand-written wrapper defined earlier in this module.
        _PyroDist = locals()[_name]
    else:
        # Generate a wrapper on the fly and publish it in this module.
        _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {})
        _PyroDist.__module__ = __name__
        locals()[_name] = _PyroDist

    # The docstring is (re)assigned for every wrapper, hand-written or not.
    # The template text is kept flush-left so the runtime string is unchanged.
    _PyroDist.__doc__ = '''
Wraps :class:`{}.{}` with
:class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.
'''.format(_Dist.__module__, _Dist.__name__)

    __all__.append(_name)
# Create sphinx documentation.
#
# The module docstring is one titled ``autoclass`` section per exported name,
# in sorted order, with sections separated by a blank line.
__doc__ = '\n\n'.join(
    '''
{0}
----------------------------------------------------------------
.. autoclass:: pyro.distributions.{0}
'''.format(_name)
    for _name in sorted(__all__)
)