Source code for pyro.distributions.extended

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

from pyro.distributions import constraints

from .conjugate import BetaBinomial
from .torch import Binomial


class ExtendedBinomial(Binomial):
    """
    EXPERIMENTAL variant of :class:`~pyro.distributions.Binomial` whose
    declared (logical) support is all integers and whose ``total_count``
    may be any integer.

    Probability mass is still carried only by the integer interval
    ``[0, total_count]``; every other integer gets log-probability ``-inf``.
    """

    # ``total_count`` is constrained only to be an integer here.
    arg_constraints = {
        "total_count": constraints.integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    # Any integer is accepted as a sample value.
    support = constraints.integer

    def log_prob(self, value):
        """
        Return the Binomial log-probability of ``value``, with ``-inf`` at
        every position where ``value`` lies outside ``[0, total_count]``.

        :param value: Tensor of candidate sample values.
        :return: Tensor of log-probabilities from the parent class, with
            out-of-range positions overwritten by ``-inf``.
        """
        # Delegate first; whatever the parent returns for out-of-range
        # inputs is overwritten by the mask below.
        base_log_prob = super().log_prob(value)
        out_of_range = (value > self.total_count) | (value < 0)
        return base_log_prob.masked_fill(out_of_range, -math.inf)
class ExtendedBetaBinomial(BetaBinomial):
    """
    EXPERIMENTAL variant of :class:`~pyro.distributions.BetaBinomial` whose
    declared (logical) support is all integers and whose ``total_count``
    may be any integer.

    Probability mass is still carried only by the integer interval
    ``[0, total_count]``; every other integer gets log-probability ``-inf``.
    """

    # ``total_count`` is constrained only to be an integer here.
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "total_count": constraints.integer,
    }
    # Any integer is accepted as a sample value.
    support = constraints.integer

    def log_prob(self, value):
        """
        Return the BetaBinomial log-probability of ``value``, with ``-inf``
        at every position where ``value`` lies outside ``[0, total_count]``.

        The parent implementation is evaluated on sanitized inputs:
        out-of-range values are replaced by ``0`` and ``total_count`` is
        clamped to be non-negative. ``self.total_count`` is swapped
        temporarily for the duration of the parent call and always restored.

        :param value: Tensor of candidate sample values.
        :return: Tensor of log-probabilities with out-of-range positions
            set to ``-inf``.
        """
        if self._validate_args:
            self._validate_sample(value)

        original_count = self.total_count
        out_of_range = (value > original_count) | (value < 0)

        # Sanitized stand-ins so the parent only sees in-range inputs
        # (presumably to keep its intermediate terms finite).
        clamped_count = original_count.clamp(min=0)
        safe_value = value.masked_fill(out_of_range, 0)

        try:
            # The parent reads ``self.total_count``, so substitute the
            # clamped tensor only while it runs.
            self.total_count = clamped_count
            log_p = super().log_prob(safe_value)
        finally:
            # Restore the caller-visible attribute even if the parent raises.
            self.total_count = original_count

        return log_p.masked_fill(out_of_range, -math.inf)