# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import re
import textwrap
import torch
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.ops.special import log_binomial
from .. import settings
from . import constraints
def _clamp_by_zero(x):
    """
    Elementwise ``max(x, 0)`` whose gradient at exactly zero is 0.5.

    Both ``x.clamp(min=0)`` and ``x - x.clamp(max=0)`` equal ``max(x, 0)``;
    averaging the two keeps the value unchanged while splitting the gradient
    at the kink.
    """
    upper = x.clamp(min=0)
    lower = x.clamp(max=0)
    return (upper + x - lower) / 2
class Beta(torch.distributions.Beta, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL.

        Combine this Beta with ``other`` (another :class:`Beta`).

        Returns ``(updated, log_normalizer)``. ``updated`` has concentrations
        ``self + other - 1``. ``log_normalizer`` is
        ``logZ(self) + logZ(other) - logZ(updated)``, where
        ``logZ = lgamma(a + b) - lgamma(a) - lgamma(b)``.
        """
        assert isinstance(other, Beta)

        def log_norm(dist):
            a = dist.concentration1
            b = dist.concentration0
            return (a + b).lgamma() - a.lgamma() - b.lgamma()

        posterior = Beta(
            self.concentration1 + other.concentration1 - 1,
            self.concentration0 + other.concentration0 - 1,
        )
        total = log_norm(self) + log_norm(other) - log_norm(posterior)
        return posterior, total
class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
    # EXPERIMENTAL threshold on total_count above which sampling will use a
    # clamped Poisson approximation for Binomial samples. This is useful for
    # sampling very large populations. The default math.inf disables the
    # approximation (sample() then defers to torch).
    approx_sample_thresh = math.inf
    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
    # a shifted Stirling's approximation to the Beta function, reducing
    # computational cost from 3 lgamma() evaluations to 4 log() evaluations
    # plus arithmetic. Recommended values are between 0.1 and 0.01.
    # The value is forwarded as the ``tol`` argument of log_binomial().
    approx_log_prob_tol = 0.0

    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples, approximating entries with very large ``total_count``.

        When ``approx_sample_thresh`` is finite and some ``total_count`` entry
        exceeds it, those entries are drawn under ``torch.no_grad()``. The draw
        is a Poisson with variance ``p * q * total_count``, shifted by an
        integer so its mean is roughly ``min(p, q) * total_count``, and
        clamped above at ``total_count``. Entries at or below the threshold
        are drawn exactly. Otherwise this defers to torch's sampler.
        """
        if self.approx_sample_thresh < math.inf:
            exact = self.total_count <= self.approx_sample_thresh
            if not exact.all():
                # Approximate large counts with a moment-matched clamped Poisson.
                with torch.no_grad():
                    shape = self._extended_shape(sample_shape)
                    p = self.probs
                    q = 1 - self.probs
                    # The Poisson models the count of the less likely outcome.
                    mean = torch.min(p, q) * self.total_count
                    variance = p * q * self.total_count
                    # Integer shift so shift + Poisson(variance) has about the
                    # target mean while keeping the target variance.
                    shift = (mean - variance).round()
                    result = torch.poisson(variance.expand(shape))
                    result = torch.min(result + shift, self.total_count)
                    # If success is the more likely outcome, the draw above
                    # counted failures; convert back to successes.
                    sample = torch.where(p < q, result, self.total_count - result)
                # Draw exact samples for remaining items.
                if exact.any():
                    # Approximated entries get total_count 0 here (presumably to
                    # keep the exact sampler cheap); their exact draws are
                    # discarded by the torch.where below.
                    total_count = torch.where(
                        exact, self.total_count, torch.zeros_like(self.total_count)
                    )
                    exact_sample = torch.distributions.Binomial(
                        total_count, self.probs, validate_args=False
                    ).sample(sample_shape)
                    sample = torch.where(exact, exact_sample, sample)
                return sample
        return super().sample(sample_shape)

    def log_prob(self, value):
        """
        Log-probability of ``value`` successes, computed from ``self.logits``.

        Uses the numerically stable rearrangement derived in the comments
        below. The log binomial coefficient comes from ``log_binomial`` with
        ``tol=self.approx_log_prob_tol``.
        """
        if self._validate_args:
            self._validate_sample(value)
        n = self.total_count
        k = value
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        # _clamp_by_zero supplies max(logit, 0), with gradient 0.5 at logit == 0.
        normalize_term = n * (
            _clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p()
        )
        return (
            k * self.logits
            - normalize_term
            + log_binomial(n, k, tol=self.approx_log_prob_tol)
        )
@settings.register(
    "binomial_approx_sample_thresh", __name__, "Binomial.approx_sample_thresh"
)
def _validate_thresh(thresh):
    # Validator for Binomial.approx_sample_thresh: a strictly positive float
    # (math.inf, the default, passes). Raises AssertionError otherwise.
    assert isinstance(thresh, float) and thresh > 0
@settings.register(
    "binomial_approx_log_prob_tol", __name__, "Binomial.approx_log_prob_tol"
)
def _validate_tol(tol):
    # Validator for Binomial.approx_log_prob_tol: a non-negative float
    # (0.0, the default, passes). Raises AssertionError otherwise.
    assert isinstance(tol, float) and tol >= 0
# Categorical overrides .log_prob() and .enumerate_support() so that evaluating
# log_prob on this variable's own (unexpanded) support avoids torch.gather():
# the result is obtained by reshaping/transposing self.logits instead. This
# matters for Pyro models that use enumeration.
class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}

    def log_prob(self, value):
        # Fast path only for values tagged by this instance's enumerate_support().
        if getattr(value, "_pyro_categorical_support", None) != id(self):
            return super().log_prob(value)

        # The tagged value is assumed to be a reshaped torch.arange(event_shape[0]).
        # Python-level asserts are skipped while the JIT is tracing.
        tracing = torch._C._get_tracing_state()
        if not tracing:
            if self._validate_args:
                self._validate_sample(value)
            assert value.size(0) == self.logits.size(-1)

        ndim = value.dim()
        logits = self.logits
        # Left-pad logits with singleton dims so that it has ndim + 1 dims.
        pad = 1 + ndim - logits.dim()
        if pad >= 1:
            logits = logits.reshape((1,) * pad + logits.shape)
        if not tracing:
            assert logits.size(-1 - ndim) == 1
        # Move the category dim to the front and drop the trailing singleton.
        return logits.transpose(-1 - ndim, -1).squeeze(-1)

    def enumerate_support(self, expand=True):
        support = super().enumerate_support(expand=expand)
        if expand:
            return support
        # Tag the unexpanded support so log_prob() can recognise it.
        support._pyro_categorical_support = id(self)
        return support
class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(concentration):
        # The rightmost dim of ``concentration`` is the event dim; the rest is batch.
        return concentration[:-1], concentration[-1:]

    def conjugate_update(self, other):
        """
        EXPERIMENTAL.

        Combine this Dirichlet with ``other`` (another :class:`Dirichlet`).

        Returns ``(updated, log_normalizer)``. ``updated`` has concentration
        ``self + other - 1``. ``log_normalizer`` is
        ``logZ(self) + logZ(other) - logZ(updated)``, where
        ``logZ = lgamma(sum(c)) - sum(lgamma(c))`` over the last dim.
        """
        assert isinstance(other, Dirichlet)

        def log_norm(dist):
            alpha = dist.concentration
            return alpha.sum(-1).lgamma() - alpha.lgamma().sum(-1)

        posterior = Dirichlet(self.concentration + other.concentration - 1)
        total = log_norm(self) + log_norm(other) - log_norm(posterior)
        return posterior, total
class Gamma(torch.distributions.Gamma, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL.

        Combine this Gamma with ``other`` (another :class:`Gamma`).

        Returns ``(updated, log_normalizer)``. ``updated`` has concentration
        ``self + other - 1`` and rate ``self.rate + other.rate``.
        ``log_normalizer`` is ``logZ(self) + logZ(other) - logZ(updated)``,
        where ``logZ = concentration * log(rate) - lgamma(concentration)``.
        """
        assert isinstance(other, Gamma)

        def log_norm(dist):
            conc = dist.concentration
            return dist.rate.log() * conc - conc.lgamma()

        posterior = Gamma(
            self.concentration + other.concentration - 1, self.rate + other.rate
        )
        total = log_norm(self) + log_norm(other) - log_norm(posterior)
        return posterior, total
class Geometric(torch.distributions.Geometric, TorchDistributionMixin):
    # TODO: move upstream
    def log_prob(self, value):
        """
        Log-probability computed directly from ``self.logits`` as
        ``logits - (value + 1) * softplus(logits)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits = self.logits
        return logits - (value + 1) * torch.nn.functional.softplus(logits)
class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):
    def __init__(self, loc, scale, validate_args=None):
        # Same construction as torch.distributions.LogNormal, except that the
        # base distribution is this module's (Pyro) Normal. torch's own
        # LogNormal.__init__ is skipped by starting the super() lookup after it.
        normal = Normal(loc, scale)
        exp = torch.distributions.transforms.ExpTransform()
        super(torch.distributions.LogNormal, self).__init__(
            normal, exp, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        # Expand into a Pyro LogNormal instance, again starting the super()
        # lookup after torch.distributions.LogNormal.
        instance = self._get_checked_instance(LogNormal, _instance)
        parent = super(torch.distributions.LogNormal, self)
        return parent.expand(batch_shape, _instance=instance)
class LowRankMultivariateNormal(
    torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(loc, cov_factor, cov_diag):
        # loc and cov_diag each end in one event dim; cov_factor ends in two
        # non-batch dims. The batch shape is the broadcast of what remains.
        batch_shape = broadcast_shape(loc[:-1], cov_factor[:-2], cov_diag[:-1])
        return batch_shape, loc[-1:]
class MultivariateNormal(
    torch.distributions.MultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(
        loc, covariance_matrix=None, precision_matrix=None, scale_tril=None
    ):
        # The event shape comes from loc. Batch dims are broadcast across loc
        # and each supplied matrix (ignoring the matrix's two trailing dims).
        event_shape = loc[-1:]
        batch_shape = loc[:-1]
        for mat in (covariance_matrix, precision_matrix, scale_tril):
            if mat is None:
                continue
            batch_shape = broadcast_shape(batch_shape, mat[:-2])
        return batch_shape, event_shape
class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin):
    # infer_shapes must be a staticmethod like every sibling's. Without it,
    # calling through an instance binds the instance to ``total_count``.
    @staticmethod
    def infer_shapes(total_count=None, probs=None, logits=None):
        """
        Infer ``(batch_shape, event_shape)`` from parameter shapes.

        :param total_count: shape of ``total_count``; only broadcast into the
            batch shape when it is a tuple, any other value is ignored.
        :param probs: shape of ``probs``; used when ``logits`` is None.
        :param logits: shape of ``logits``; takes precedence over ``probs``.
        :returns: ``(batch_shape, event_shape)``.
        """
        tensor = probs if logits is None else logits
        batch_shape, event_shape = tensor[:-1], tensor[-1:]
        if isinstance(total_count, tuple):
            batch_shape = broadcast_shape(batch_shape, total_count)
        return batch_shape, event_shape
class Normal(torch.distributions.Normal, TorchDistributionMixin):
    # Thin wrapper: all behavior comes from torch.distributions.Normal plus
    # TorchDistributionMixin. LogNormal in this module uses it as its base.
    ...
class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(probs=None, logits=None):
        # Shapes come from whichever parameter is given; logits wins if both are.
        shape = logits if logits is not None else probs
        return shape[:-1], shape[-1:]
class Poisson(torch.distributions.Poisson, TorchDistributionMixin):
    def __init__(self, rate, *, is_sparse=False, validate_args=None):
        # is_sparse=True makes log_prob() evaluate log(rate) and lgamma only
        # at entries where value > 0.
        self.is_sparse = is_sparse
        super().__init__(rate, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        expanded = super().expand(
            batch_shape, _instance=self._get_checked_instance(Poisson, _instance)
        )
        # is_sparse is Pyro-specific, so copy it onto the expanded instance.
        expanded.is_sparse = self.is_sparse
        return expanded

    def log_prob(self, value):
        if not self.is_sparse:
            return super().log_prob(value)
        if self._validate_args:
            self._validate_sample(value)
        # log p(k) = k * log(rate) - lgamma(k + 1) - rate. Where k == 0 the
        # first two terms vanish, so only positive entries are computed and
        # scattered into a zero tensor before subtracting rate everywhere.
        rate, value, mask = torch.broadcast_tensors(self.rate, value, value > 0)
        counts = value[mask]
        partial = (rate[mask].log() * counts) - (counts + 1).lgamma()
        return torch.zeros_like(rate).masked_scatter(mask, partial) - rate
class Independent(torch.distributions.Independent, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(**kwargs):
        # Shape inference is not supported for Independent.
        raise NotImplementedError

    # _validate_args is forwarded to the base distribution, so reading or
    # setting it on the wrapper reads or sets base_dist._validate_args.
    @property
    def _validate_args(self):
        return self.base_dist._validate_args

    @_validate_args.setter
    def _validate_args(self, value):
        self.base_dist._validate_args = value

    def conjugate_update(self, other):
        """
        EXPERIMENTAL.

        Delegate to ``self.base_dist.conjugate_update``. ``other`` has its
        ``n = reinterpreted_batch_ndims`` rightmost event dims turned back into
        batch dims via ``to_event(-n)``. The updated distribution is re-wrapped
        with ``to_event(n)``. The log normalizer is summed over its ``n``
        rightmost dims.

        :returns: ``(updated, log_normalizer)``
        """
        # torch.distributions.Independent stores this as
        # ``reinterpreted_batch_ndims``. The previous misspelling
        # ``reintepreted_batch_ndims`` raised AttributeError on every call.
        n = self.reinterpreted_batch_ndims
        updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
        updated = updated.to_event(n)
        log_normalizer = sum_rightmost(log_normalizer, n)
        return updated, log_normalizer
def _cat_docstrings(*docstrings):
    """
    Concatenate docstrings into one docstring.

    Each docstring has its leading newlines stripped and is dedented. The
    pieces are joined with newlines and runs of blank lines are collapsed to a
    single blank line. Lines containing ``xdoctest`` (torch-specific doctest
    directives) are dropped.

    ``None`` entries are skipped. A class's ``__doc__`` is ``None`` when it has
    no docstring or when Python runs with ``-OO``; calling ``.lstrip`` on it
    would raise ``AttributeError`` at import time.
    """
    result = "\n".join(
        textwrap.dedent(s.lstrip("\n")) for s in docstrings if s is not None
    )
    result = re.sub("\n\n+", "\n\n", result)
    # Drop torch-specific lines.
    result = "".join(
        line for line in result.splitlines(keepends=True) if "xdoctest" not in line
    )
    return result
# Statically list the exported names so mypy and other static tools can see
# them. Most of these classes are created dynamically by the loop that follows,
# hence the noqa for F822 (undefined name in __all__).
__all__ = [  # noqa: F822
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "ExponentialFamily",
    "Exponential",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "Kumaraswamy",
    "Laplace",
    "LKJCholesky",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "Poisson",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "TransformedDistribution",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
]
# Programmatically load all distributions from PyTorch,
# updating __all__ to include any new distributions.
#
# For every concrete torch Distribution subclass exposed by torch.distributions:
# - reuse the hand-written wrapper above if this module already defines one,
#   otherwise synthesize ``class <name>(<torch class>, TorchDistributionMixin)``;
# - replace its docstring with a "Wraps ..." preamble plus torch's docstring;
# - add its name to __all__.
# Module globals are written via globals(), which is the module namespace.
# At module scope locals() returns the same dict, but globals() makes the
# intent explicit.
for _name, _Dist in torch.distributions.__dict__.items():
    if not isinstance(_Dist, type):
        continue
    if not issubclass(_Dist, torch.distributions.Distribution):
        continue
    if _Dist is torch.distributions.Distribution:
        continue

    try:
        _PyroDist = globals()[_name]
    except KeyError:
        _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {})
        _PyroDist.__module__ = __name__
        globals()[_name] = _PyroDist

    _PyroDist.__doc__ = """
Wraps :class:`{}.{}` with
:class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.
""".format(
        _Dist.__module__, _Dist.__name__
    )
    _PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__)
    __all__.append(_name)

# Deduplicate (static names overlap with the loop's names) and sort.
__all__ = sorted(set(__all__))
# Create sphinx documentation: one titled autoclass section per exported name,
# separated by blank lines.
__doc__ = "\n\n".join(
    """
{0}
----------------------------------------------------------------
.. autoclass:: pyro.distributions.{0}
""".format(
        _doc_name
    )
    for _doc_name in sorted(__all__)
    # Work around sphinx autodoc error in case two InverseGamma's are defined:
    # "duplicate object description of pyro.distributions.InverseGamma"
    if _doc_name != "InverseGamma"
)