# Source code for pyro.distributions.log_normal_negative_binomial

# Copyright Contributors to the Pyro project.

import numpy as np
import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all, lazy_property

from pyro.distributions.torch import NegativeBinomial
from pyro.distributions.torch_distribution import TorchDistribution

def _gauss_hermite_rule(num_points, prototype):
    """
    Return a Gauss-Hermite quadrature rule for expectations under a standard Normal.

    :func:`numpy.polynomial.hermite.hermgauss` targets the weight ``exp(-x**2)``.
    Substituting ``x = eps / sqrt(2)`` rescales the points by ``sqrt(2)`` and the
    weights by ``1 / sqrt(pi)``. The result satisfies
    ``E[f(eps)] ~= sum(exp(log_weights) * f(points))`` for ``eps ~ N(0, 1)``.

    :param int num_points: number of quadrature points (positive).
    :param torch.Tensor prototype: the returned tensors adopt its dtype and device.
    :returns: tuple ``(points, log_weights)`` of 1-D tensors of length ``num_points``.
    """
    points, weights = np.polynomial.hermite.hermgauss(num_points)
    points = points * np.sqrt(2.0)
    log_weights = np.log(weights) - 0.5 * np.log(np.pi)
    return (
        torch.as_tensor(points, dtype=prototype.dtype, device=prototype.device),
        torch.as_tensor(log_weights, dtype=prototype.dtype, device=prototype.device),
    )


class LogNormalNegativeBinomial(TorchDistribution):
    r"""
    A three-parameter generalization of the Negative Binomial distribution [1].
    It can be understood as a continuous mixture of Negative Binomial distributions
    in which we inject Normally-distributed noise into the logits of the Negative
    Binomial distribution:

    .. math::

        \begin{eqnarray}
        &\rm{LNNB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell, \rm{multiplicative\_noise\_scale}=\sigma) = \\
        &\int d\epsilon \mathcal{N}(\epsilon | 0, \sigma)
        \rm{NB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell + \epsilon)
        \end{eqnarray}

    where :math:`y \ge 0` is a non-negative integer. Thus while a Negative Binomial distribution
    can be formulated as a Poisson distribution with a Gamma-distributed rate, this distribution
    adds an additional level of variability by also modulating the rate by Log Normally-distributed
    multiplicative noise.

    This distribution has a mean given by

    .. math::
        \mathbb{E}[y] = \nu e^{\ell + \tfrac{1}{2}\sigma^2} = e^{\ell + \log \nu + \tfrac{1}{2}\sigma^2}

    and a variance given by

    .. math::
        \rm{Var}[y] = \mathbb{E}[y] + \left( e^{\sigma^2} (1 + 1/\nu) - 1 \right) \left( \mathbb{E}[y] \right)^2

    Thus while a given mean and variance together uniquely characterize a Negative Binomial distribution, there is a
    one-dimensional family of Log Normal Negative Binomial distributions with a given mean and variance.

    Note that in some applications it may be useful to parameterize the logits as

    .. math::
        \ell = \ell^\prime - \log \nu - \tfrac{1}{2}\sigma^2

    so that the mean is given by :math:`\mathbb{E}[y] = e^{\ell^\prime}` and does not depend on :math:`\nu`
    and :math:`\sigma`, which serve to determine the higher moments.

    :meth:`log_prob` approximates the integral over :math:`\epsilon` with a
    Gauss-Hermite quadrature rule. :meth:`sample` is not implemented.

    References:

    [1] "Lognormal and Gamma Mixed Negative Binomial Regression,"
    Mingyuan Zhou, Lingbo Li, David Dunson, and Lawrence Carin.

    :param total_count: non-negative number of negative Bernoulli trials. For a fixed mean,
        the variance decreases as total_count increases.
    :type total_count: float or torch.Tensor
    :param torch.Tensor logits: Event log-odds for probabilities of success for underlying
        Negative Binomial distribution.
    :param torch.Tensor multiplicative_noise_scale: Controls the level of the injected Normal logit noise.
    :param int num_quad_points: Number of quadrature points used to compute the (approximate) log_prob.
        Defaults to 8.
    :raises ValueError: if ``num_quad_points`` is less than 1.
    """

    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "logits": constraints.real,
        "multiplicative_noise_scale": constraints.positive,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count,
        logits,
        multiplicative_noise_scale,
        *,
        num_quad_points=8,
        validate_args=None,
    ):
        if num_quad_points < 1:
            raise ValueError("num_quad_points must be positive.")

        # Convert all parameters to tensors sharing one common (batch) shape.
        total_count, logits, multiplicative_noise_scale = broadcast_all(
            total_count, logits, multiplicative_noise_scale
        )

        # The quadrature nodes eps_k ~ N(0, 1) live on a new trailing dimension,
        # so the noisy logits are ell + sigma * eps_k with shape batch + (Q,).
        self.quad_points, self.log_weights = _gauss_hermite_rule(
            num_quad_points, logits
        )
        quad_logits = (
            logits.unsqueeze(-1)
            + multiplicative_noise_scale.unsqueeze(-1) * self.quad_points
        )
        # One Negative Binomial per quadrature point.
        self.nb_dist = NegativeBinomial(
            total_count=total_count.unsqueeze(-1), logits=quad_logits
        )

        self.multiplicative_noise_scale = multiplicative_noise_scale
        self.total_count = total_count
        self.logits = logits
        self.num_quad_points = num_quad_points

        # After broadcast_all the parameters share a shape. The quadrature
        # dimension of nb_dist is internal and not part of the batch shape.
        batch_shape = logits.shape
        event_shape = torch.Size()

        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        """
        Approximate log-density via Gauss-Hermite quadrature:
        ``log p(y) ~= logsumexp_k [log w_k + log NB(y | nu, ell + sigma * eps_k)]``.
        """
        if self._validate_args:
            self._validate_sample(value)
        # Broadcast value against the trailing quadrature dimension.
        nb_log_prob = self.nb_dist.log_prob(value.unsqueeze(-1))
        return torch.logsumexp(self.log_weights + nb_log_prob, dim=-1)

    def sample(self, sample_shape=torch.Size()):
        raise NotImplementedError

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        batch_shape = torch.Size(batch_shape)
        total_count = self.total_count.expand(batch_shape)
        logits = self.logits.expand(batch_shape)
        multiplicative_noise_scale = self.multiplicative_noise_scale.expand(batch_shape)
        LogNormalNegativeBinomial.__init__(
            new,
            total_count,
            logits,
            multiplicative_noise_scale,
            num_quad_points=self.num_quad_points,
            validate_args=False,
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def mean(self):
        # E[y] = nu * exp(ell + sigma^2 / 2)
        return torch.exp(
            self.logits
            + self.total_count.log()
            + 0.5 * self.multiplicative_noise_scale.pow(2.0)
        )

    @lazy_property
    def variance(self):
        # Var[y] = E[y] + (exp(sigma^2) * (1 + 1/nu) - 1) * E[y]^2
        kappa = (
            torch.exp(self.multiplicative_noise_scale.pow(2.0))
            * (1 + 1 / self.total_count)
            - 1
        )
        return self.mean + kappa * self.mean.pow(2.0)