# Source code for pyro.distributions.torch

# Copyright (c) 2017-2019 Uber Technologies, Inc.

import math

import torch

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.util import sum_rightmost
from pyro.ops.special import log_binomial

from . import constraints

def _clamp_by_zero(x):
    """
    Clamp ``x`` from below at zero.

    The forward value matches ``x.clamp(min=0)``, but the expression is
    arranged so that the gradient at exactly ``x == 0`` is 0.5.
    """
    nonneg_part = x.clamp(min=0)
    nonpos_part = x.clamp(max=0)
    return (nonneg_part + x - nonpos_part) / 2

class Beta(torch.distributions.Beta, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL. Combine this Beta with another Beta.

        The concentrations of ``self`` and ``other`` are added and reduced by
        one, which is the parameter update obtained when the two Beta
        densities are multiplied.

        :param Beta other: the distribution to fuse with ``self``.
        :returns: a pair ``(updated, log_normalizer)``: the fused
            :class:`Beta` and the sum of the log normalizing constants of
            ``self`` and ``other`` minus that of ``updated``.
        :rtype: tuple
        """
        assert isinstance(other, Beta)

        def log_norm(dist):
            # Log of the reciprocal Beta function of the two concentrations.
            a, b = dist.concentration1, dist.concentration0
            return (a + b).lgamma() - a.lgamma() - b.lgamma()

        fused = Beta(
            self.concentration1 + other.concentration1 - 1,
            self.concentration0 + other.concentration0 - 1,
        )
        return fused, log_norm(self) + log_norm(other) - log_norm(fused)

class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
    # EXPERIMENTAL threshold on total_count above which sampling will use a
    # clamped Poisson approximation for Binomial samples. This is useful for
    # sampling very large populations.
    approx_sample_thresh = math.inf

    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
    # a shifted Stirling's approximation to the Beta function, reducing
    # computational cost from 3 lgamma() evaluations to 4 log() evaluations
    # plus arithmetic. Recommended values are between 0.1 and 0.01.
    approx_log_prob_tol = 0.0

    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples, optionally approximating entries with large counts.

        When ``approx_sample_thresh`` is finite, entries whose ``total_count``
        exceeds it are drawn from a moment-matched, clamped Poisson
        approximation; the remaining entries are drawn exactly. Otherwise the
        parent class sampler is used unchanged.

        :param torch.Size sample_shape: leading shape of the requested draws.
        :returns: a tensor of samples.
        """
        if self.approx_sample_thresh < math.inf:
            exact = self.total_count <= self.approx_sample_thresh
            if not exact.all():
                # Approximate large counts with a moment-matched clamped Poisson.
                # The draw is not reparameterized, so keep it out of the
                # autograd graph; otherwise the result would stay attached to
                # probs.
                with torch.no_grad():
                    shape = self._extended_shape(sample_shape)
                    p = self.probs
                    q = 1 - self.probs
                    # Work with the rarer outcome and flip back at the end.
                    mean = torch.min(p, q) * self.total_count
                    variance = p * q * self.total_count
                    # A Poisson with rate=variance has the right variance;
                    # shifting by (mean - variance) then restores the mean.
                    shift = (mean - variance).round()
                    result = torch.poisson(variance.expand(shape))
                    result = torch.min(result + shift, self.total_count)
                    sample = torch.where(p < q, result, self.total_count - result)
                # Draw exact samples for remaining items.
                if exact.any():
                    # Zero out the large counts so the exact sampler only does
                    # real work on the entries it is responsible for.
                    total_count = torch.where(
                        exact, self.total_count, torch.zeros_like(self.total_count)
                    )
                    exact_sample = torch.distributions.Binomial(
                        total_count, self.probs, validate_args=False
                    ).sample(sample_shape)
                    sample = torch.where(exact, exact_sample, sample)
                return sample
        return super().sample(sample_shape)

    def log_prob(self, value):
        """
        Evaluate the log probability mass of ``value``.

        The binomial coefficient comes from ``log_binomial`` with tolerance
        ``approx_log_prob_tol``; the remaining terms are computed from
        ``self.logits`` in a numerically stable form (see the derivation
        below).

        :param torch.Tensor value: number of successes.
        :returns: a tensor of log probabilities.
        """
        if self._validate_args:
            self._validate_sample(value)

        n = self.total_count
        k = value
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = n * (
            _clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p()
        )
        return (
            k * self.logits
            - normalize_term
            + log_binomial(n, k, tol=self.approx_log_prob_tol)
        )

# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
# and merely reshape the self.logits tensor. This is especially important for
# Pyro models that use enumeration.
class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}

    def log_prob(self, value):
        """
        Evaluate the log probability of ``value``.

        If ``value`` was produced by ``self.enumerate_support(expand=False)``
        (recognized by the ``_pyro_categorical_support`` tag), the result is
        obtained by reshaping and transposing ``self.logits``. Any other
        ``value`` is handled by the parent class.
        """
        # Only values tagged by this very instance take the fast path.
        if getattr(value, "_pyro_categorical_support", None) != id(self):
            return super().log_prob(value)

        # Assume value is a reshaped torch.arange(event_shape[0]).
        # In this case we can call .reshape() rather than torch.gather().
        checks_enabled = not torch._C._get_tracing_state()
        if checks_enabled:
            if self._validate_args:
                self._validate_sample(value)
            assert value.size(0) == self.logits.size(-1)

        logits = self.logits
        value_dim = value.dim()
        # Left-pad logits with singleton dims until it has more dims than value.
        padding = 1 + value_dim - logits.dim()
        if padding > 0:
            logits = logits.reshape((1,) * padding + logits.shape)

        enum_dim = -1 - value_dim
        if checks_enabled:
            assert logits.size(enum_dim) == 1
        # Move the category axis to where value enumerates it and drop the
        # singleton that takes its place.
        return logits.transpose(enum_dim, -1).squeeze(-1)

    def enumerate_support(self, expand=True):
        """
        Enumerate the support, tagging unexpanded results for ``log_prob``.

        With ``expand=False`` the returned tensor gets a
        ``_pyro_categorical_support`` attribute holding ``id(self)``.
        """
        support = super().enumerate_support(expand=expand)
        if expand:
            return support
        support._pyro_categorical_support = id(self)
        return support

class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(concentration):
        """
        Split ``concentration`` into batch and event parts.

        :param concentration: a sliceable sequence; presumably the shape of
            the concentration parameter rather than the tensor itself.
        :returns: ``(batch_shape, event_shape)``: everything but the last
            entry, and the last entry.
        :rtype: tuple
        """
        return concentration[:-1], concentration[-1:]

    def conjugate_update(self, other):
        """
        EXPERIMENTAL. Combine this Dirichlet with another Dirichlet.

        :param Dirichlet other: the distribution to fuse with ``self``.
        :returns: a pair ``(updated, log_normalizer)``: the fused
            :class:`Dirichlet`, whose concentration is the sum of both
            concentrations minus one, and the sum of the log normalizing
            constants of ``self`` and ``other`` minus that of ``updated``.
        :rtype: tuple
        """
        assert isinstance(other, Dirichlet)

        def log_norm(dist):
            alpha = dist.concentration
            return alpha.sum(-1).lgamma() - alpha.lgamma().sum(-1)

        fused = Dirichlet(self.concentration + other.concentration - 1)
        return fused, log_norm(self) + log_norm(other) - log_norm(fused)

class Gamma(torch.distributions.Gamma, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL. Combine this Gamma with another Gamma.

        :param Gamma other: the distribution to fuse with ``self``.
        :returns: a pair ``(updated, log_normalizer)``: the fused
            :class:`Gamma`, with concentration equal to the sum of both
            concentrations minus one and rate equal to the sum of both rates,
            and the sum of the log normalizing constants of ``self`` and
            ``other`` minus that of ``updated``.
        :rtype: tuple
        """
        assert isinstance(other, Gamma)

        def log_norm(dist):
            shape_param = dist.concentration
            return dist.rate.log() * shape_param - shape_param.lgamma()

        fused = Gamma(
            self.concentration + other.concentration - 1,
            self.rate + other.rate,
        )
        return fused, log_norm(self) + log_norm(other) - log_norm(fused)

class Geometric(torch.distributions.Geometric, TorchDistributionMixin):
    # TODO: move upstream
    def log_prob(self, value):
        """
        Evaluate the log probability of ``value`` directly from the logits.

        Computes ``(-value - 1) * softplus(logits) + logits``.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits = self.logits
        neg_trials = -value - 1
        return neg_trials * torch.nn.functional.softplus(logits) + logits

class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):
    def __init__(self, loc, scale, validate_args=None):
        """
        Build a log-normal as an exp-transformed Normal.

        :param loc: location of the underlying normal.
        :param scale: scale of the underlying normal.
        :param validate_args: forwarded to the parent constructor.
        """
        # This differs from torch.distributions.LogNormal only in that the base
        # distribution is a pyro.distributions.Normal rather than a
        # torch.distributions.Normal. To achieve that, the constructor of
        # torch.distributions.LogNormal itself is skipped and the next
        # __init__ in the MRO is called with our own base distribution.
        super(torch.distributions.LogNormal, self).__init__(
            Normal(loc, scale),
            torch.distributions.transforms.ExpTransform(),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy expanded to ``batch_shape``.

        As in ``__init__``, the implementation of torch.distributions.LogNormal
        is bypassed in favor of the next one in the MRO.
        """
        target = self._get_checked_instance(LogNormal, _instance)
        parent = super(torch.distributions.LogNormal, self)
        return parent.expand(batch_shape, _instance=target)

class LowRankMultivariateNormal(
    torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(loc, cov_factor, cov_diag):
        """
        Infer batch and event shapes from the shapes of the parameters.

        :param loc: shape of ``loc``; its last entry is the event size.
        :param cov_factor: shape of ``cov_factor``; its last two entries are
            treated as matrix dims.
        :param cov_diag: shape of ``cov_diag``; its last entry is treated as
            an event dim.
        :returns: ``(batch_shape, event_shape)``.
        :rtype: tuple
        """
        event_shape = loc[-1:]
        # The batch shape is whatever remains of each parameter once its
        # event/matrix dims are stripped, broadcast together.
        batch_shape = torch.broadcast_shapes(loc[:-1], cov_factor[:-2], cov_diag[:-1])
        return batch_shape, event_shape

class MultivariateNormal(
    torch.distributions.MultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(
        loc, covariance_matrix=None, precision_matrix=None, scale_tril=None
    ):
        """
        Infer batch and event shapes from the shapes of the parameters.

        :param loc: shape of ``loc``; its last entry is the event size.
        :param covariance_matrix: optional shape of the covariance matrix.
        :param precision_matrix: optional shape of the precision matrix.
        :param scale_tril: optional shape of the scale factor.
        :returns: ``(batch_shape, event_shape)``; always a pair, also when no
            matrix shape is supplied.
        :rtype: tuple
        """
        batch_shape, event_shape = loc[:-1], loc[-1:]
        for matrix in (covariance_matrix, precision_matrix, scale_tril):
            if matrix is not None:
                # Drop the two trailing matrix dims; the rest are batch dims
                # that must broadcast against those of loc.
                batch_shape = torch.broadcast_shapes(batch_shape, matrix[:-2])
        return batch_shape, event_shape

class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(total_count=None, probs=None, logits=None):
        """
        Infer batch and event shapes from the shapes of the parameters.

        :param total_count: if a tuple, it is treated as a shape and broadcast
            into the batch shape; any other value is ignored.
        :param probs: shape of ``probs``; used when ``logits`` is ``None``.
        :param logits: shape of ``logits``; takes precedence over ``probs``.
        :returns: ``(batch_shape, event_shape)``; always a pair, regardless of
            the type of ``total_count``.
        :rtype: tuple
        """
        tensor = probs if logits is None else logits
        batch_shape, event_shape = tensor[:-1], tensor[-1:]
        if isinstance(total_count, tuple):
            batch_shape = torch.broadcast_shapes(batch_shape, total_count)
        return batch_shape, event_shape

# Explicit subclass that adds TorchDistributionMixin and nothing else; the
# LogNormal wrapper in this module uses it as its base distribution.
class Normal(torch.distributions.Normal, TorchDistributionMixin):
    pass

class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(probs=None, logits=None):
        """
        Split the given parameter into batch and event parts.

        ``logits`` is used when it is not ``None``, otherwise ``probs``.

        :param probs: a sliceable sequence; presumably the shape of ``probs``.
        :param logits: a sliceable sequence; presumably the shape of ``logits``.
        :returns: ``(batch_shape, event_shape)``: everything but the last
            entry, and the last entry.
        :rtype: tuple
        """
        shape = logits if logits is not None else probs
        return shape[:-1], shape[-1:]

class Poisson(torch.distributions.Poisson, TorchDistributionMixin):
    def __init__(self, rate, *, is_sparse=False, validate_args=None):
        """
        :param rate: rate parameter, forwarded to the parent constructor.
        :param bool is_sparse: keyword-only flag; when true, ``log_prob``
            evaluates the log and lgamma terms only where ``value > 0``.
        :param validate_args: forwarded to the parent constructor.
        """
        self.is_sparse = is_sparse
        super().__init__(rate, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy expanded to ``batch_shape``, keeping ``is_sparse``."""
        new = self._get_checked_instance(Poisson, _instance)
        new = super().expand(batch_shape, _instance=new)
        new.is_sparse = self.is_sparse
        return new

    def log_prob(self, value):
        """
        Evaluate the log probability of ``value``.

        With ``is_sparse`` unset this defers to the parent class. Otherwise
        the ``value * log(rate) - lgamma(value + 1)`` term is computed only at
        the positions where ``value > 0`` and scattered into a zero tensor.
        For ``value == 0`` that term is zero and only ``-rate`` remains.

        :param torch.Tensor value: observed counts.
        :returns: a tensor of log probabilities with the broadcast shape of
            ``rate`` and ``value``.
        """
        if not self.is_sparse:
            return super().log_prob(value)
        if self._validate_args:
            self._validate_sample(value)
        rate, value, nonzero = torch.broadcast_tensors(self.rate, value, value > 0)
        sparse_rate = rate[nonzero]
        sparse_value = value[nonzero]
        sparse_term = (sparse_rate.log() * sparse_value) - (sparse_value + 1).lgamma()
        return torch.zeros_like(rate).masked_scatter(nonzero, sparse_term) - rate

class Independent(torch.distributions.Independent, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(**kwargs):
        """Not supported for this wrapper; always raises NotImplementedError."""
        raise NotImplementedError

    # Validation is delegated to the wrapped distribution: reading or writing
    # the flag on this object reads or writes it on base_dist.
    @property
    def _validate_args(self):
        return self.base_dist._validate_args

    @_validate_args.setter
    def _validate_args(self, value):
        self.base_dist._validate_args = value

    def conjugate_update(self, other):
        """
        EXPERIMENTAL. Delegate the conjugate update to the base distribution.

        :param other: the distribution to fuse with ``self``; it must provide
            ``to_event``.
        :returns: a pair ``(updated, log_normalizer)`` in which ``updated`` has
            the same number of reinterpreted dims as ``self`` and
            ``log_normalizer`` has been summed over those dims.
        :rtype: tuple
        """
        n = self.reinterpreted_batch_ndims
        # Presumably to_event(-n) strips n event dims from other, so that it
        # lines up with base_dist before the update.
        updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
        updated = updated.to_event(n)
        log_normalizer = sum_rightmost(log_normalizer, n)
        return updated, log_normalizer

class Uniform(torch.distributions.Uniform, TorchDistributionMixin):
    def __init__(self, low, high, validate_args=None):
        """
        :param low: lower bound, forwarded to the parent constructor.
        :param high: upper bound, forwarded to the parent constructor.
        :param validate_args: forwarded to the parent constructor.
        """
        # Keep the bounds exactly as supplied by the caller, so that `support`
        # can be expressed in terms of them rather than whatever the parent
        # constructor stores in self.low / self.high.
        self._unbroadcasted_low = low
        self._unbroadcasted_high = high
        super().__init__(low, high, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy expanded to ``batch_shape``, keeping the stored bounds."""
        new = self._get_checked_instance(Uniform, _instance)
        new = super().expand(batch_shape, _instance=new)
        # The new instance does not go through __init__ here, so carry the
        # caller-supplied bounds over by hand.
        new._unbroadcasted_low = self._unbroadcasted_low
        new._unbroadcasted_high = self._unbroadcasted_high
        return new

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        """Interval constraint spanning the caller-supplied bounds."""
        return constraints.interval(self._unbroadcasted_low, self._unbroadcasted_high)

# Programmatically load all distributions from PyTorch.
#
# Every Distribution subclass found in torch.distributions is exposed from this
# module. Classes already defined above are reused as they are; all others are
# created on the fly as empty subclasses of (torch class, TorchDistributionMixin).
# Either way the class docstring is replaced by a short cross-reference.
__all__ = []
for _name, _Dist in torch.distributions.__dict__.items():
    if not isinstance(_Dist, type):
        continue
    if not issubclass(_Dist, torch.distributions.Distribution):
        continue
    if _Dist is torch.distributions.Distribution:
        continue

    try:
        # This runs at module level, where globals() is the module namespace.
        _PyroDist = globals()[_name]
    except KeyError:
        _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {})
        _PyroDist.__module__ = __name__
        globals()[_name] = _PyroDist

    # Sphinx roles need their target wrapped in backticks to be rendered as
    # cross-references.
    _PyroDist.__doc__ = (
        "\n"
        "Wraps :class:`{}.{}` with\n"
        ":class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.\n"
    ).format(_Dist.__module__, _Dist.__name__)

    __all__.append(_name)

# Create sphinx documentation: one titled section with an autoclass directive
# per exported distribution, in alphabetical order.
__doc__ = "\n\n".join(
    (
        "\n"
        "{0}\n"
        "----------------------------------------------------------------\n"
        ".. autoclass:: pyro.distributions.{0}\n"
    ).format(_dist_name)
    for _dist_name in sorted(__all__)
)