# Source code for pyro.distributions.extended

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

from pyro.distributions import constraints

from .conjugate import BetaBinomial
from .torch import Binomial

class ExtendedBinomial(Binomial):
    """
    EXPERIMENTAL :class:`~pyro.distributions.Binomial` distribution extended to
    have logical support the entire integers and to allow arbitrary integer
    ``total_count``. Numerical support is still the integer interval
    ``[0, total_count]``.
    """

    arg_constraints = {
        "total_count": constraints.integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    support = constraints.integer

    def log_prob(self, value):
        """
        Evaluate the log probability mass at ``value``.

        Any ``value`` outside ``[0, total_count]`` gets log probability
        ``-inf``. This includes every value when ``total_count`` is negative.

        :param torch.Tensor value: integer-valued tensor of counts.
        :returns: log probabilities, broadcast over ``value`` and the batch
            shape.
        :rtype: torch.Tensor
        """
        # Validate the caller's value *before* masking, so the masking below
        # cannot hide an invalid (e.g. non-integer) input.
        if self._validate_args:
            self._validate_sample(value)

        total_count = self.total_count
        invalid = (value < 0) | (value > total_count)
        # The base Binomial.log_prob is only meaningful for
        # 0 <= value <= total_count. Torch's implementation evaluates lgamma
        # at total_count + 1, value + 1 and total_count - value + 1, so
        # out-of-range arguments produce inf/nan intermediates. Those are
        # hidden by masked_fill in the forward pass but can surface as NaN
        # gradients. Evaluate the base density only at safe, in-range
        # arguments; this matches ExtendedBetaBinomial.log_prob.
        n = total_count.clamp(min=0)
        k = value.masked_fill(invalid, 0)

        try:
            # Temporarily substitute the clamped count; always restore it.
            self.total_count = n
            result = super().log_prob(k)
        finally:
            self.total_count = total_count

        # Wherever value was valid, n == total_count and k == value, so only
        # the invalid entries need overriding.
        return result.masked_fill(invalid, -math.inf)

class ExtendedBetaBinomial(BetaBinomial):
    """
    EXPERIMENTAL :class:`~pyro.distributions.BetaBinomial` distribution
    extended to have logical support the entire integers and to allow
    arbitrary integer ``total_count``. Numerical support is still the integer
    interval ``[0, total_count]``.
    """

    arg_constraints = dict(
        concentration1=constraints.positive,
        concentration0=constraints.positive,
        total_count=constraints.integer,
    )
    support = constraints.integer

    def log_prob(self, value):
        """
        Compute the log probability mass at ``value``.

        Entries of ``value`` that lie outside ``[0, total_count]`` get
        ``-inf``. A negative ``total_count`` therefore yields ``-inf``
        everywhere.

        :param torch.Tensor value: integer-valued tensor of counts.
        :returns: log probabilities broadcast over ``value`` and the batch
            shape.
        :rtype: torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)

        saved_count = self.total_count
        out_of_range = (value < 0) | (value > saved_count)

        # Replace every out-of-range argument with a harmless in-range one, so
        # the parent density is only ever evaluated where it is well defined.
        safe_value = value.masked_fill(out_of_range, 0)
        safe_count = saved_count.clamp(min=0)

        try:
            # The parent reads self.total_count, so swap it temporarily and
            # restore it no matter what happens.
            self.total_count = safe_count
            log_mass = super().log_prob(safe_value)
        finally:
            self.total_count = saved_count

        return log_mass.masked_fill(out_of_range, -math.inf)