Source code for pyro.contrib.epidemiology.distributions

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from contextlib import contextmanager

import torch
from torch.distributions.utils import broadcast_all

import pyro.distributions as dist
from pyro.distributions.util import is_validation_enabled

_RELAX = False
_RELAX_MIN_VARIANCE = 0.1


def _all(x):
    return x.all() if isinstance(x, torch.Tensor) else x


def _is_zero(x):
    return _all(x == 0)


[docs]@contextmanager def set_approx_sample_thresh(thresh): """ EXPERIMENTAL Context manager / decorator to temporarily set the global default value of ``Binomial.approx_sample_thresh``, thereby decreasing the computational complexity of sampling from :class:`~pyro.distributions.Binomial`, :class:`~pyro.distributions.BetaBinomial`, :class:`~pyro.distributions.ExtendedBinomial`, :class:`~pyro.distributions.ExtendedBetaBinomial`, and distributions returned by :func:`infection_dist`. This is useful for sampling from very large ``total_count``. This is used internally by :class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`. :param thresh: New temporary threshold. :type thresh: int or float. """ assert isinstance(thresh, (float, int)) assert thresh > 0 old = dist.Binomial.approx_sample_thresh try: dist.Binomial.approx_sample_thresh = thresh yield finally: dist.Binomial.approx_sample_thresh = old
[docs]@contextmanager def set_approx_log_prob_tol(tol): """ EXPERIMENTAL Context manager / decorator to temporarily set the global default value of ``Binomial.approx_log_prob_tol`` and ``BetaBinomial.approx_log_prob_tol``, thereby decreasing the computational complexity of scoring :class:`~pyro.distributions.Binomial` and :class:`~pyro.distributions.BetaBinomial` distributions. This is used internally by :class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`. :param tol: New temporary tolold. :type tol: int or float. """ assert isinstance(tol, (float, int)) assert tol >= 0 old1 = dist.Binomial.approx_log_prob_tol old2 = dist.BetaBinomial.approx_log_prob_tol try: dist.Binomial.approx_log_prob_tol = tol dist.BetaBinomial.approx_log_prob_tol = tol yield finally: dist.Binomial.approx_log_prob_tol = old1 dist.BetaBinomial.approx_log_prob_tol = old2
@contextmanager def set_relaxed_distributions(relaxed=True): global _RELAX old = _RELAX try: _RELAX = relaxed yield finally: _RELAX = old def _validate_overdispersion(overdispersion): if is_validation_enabled(): if not _all(0 <= overdispersion): raise ValueError("Expected overdispersion >= 0") if not _all(overdispersion < 2): raise ValueError("Expected overdispersion < 2") def _relaxed_binomial(total_count, probs): """ Returns a moment-matched :class:`~pyro.distributions.Normal` approximating a :class:`~pyro.distributions.Binomial` but allowing arbitrary real ``total_count`` and lower-bounding variance. """ total_count, probs = broadcast_all(total_count, probs) mean = probs * total_count variance = total_count * probs * (1 - probs) scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt() return dist.Normal(mean, scale) def _relaxed_beta_binomial(concentration1, concentration0, total_count): """ Returns a moment-matched :class:`~pyro.distributions.Normal` approximating a :class:`~pyro.distributions.BetaBinomial` but allowing arbitrary real ``total_count`` and lower-bounding variance. """ concentration1, concentration0, total_count = broadcast_all( concentration1, concentration0, total_count ) c = concentration1 + concentration0 beta_mean = concentration1 / c beta_variance = concentration1 * concentration0 / (c * c * (c + 1)) mean = beta_mean * total_count variance = beta_variance * total_count * (c + total_count) scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt() return dist.Normal(mean, scale)
[docs]def binomial_dist(total_count, probs, *, overdispersion=0.0): """ Returns a Beta-Binomial distribution that is an overdispersed version of a Binomial distribution, according to a parameter ``overdispersion``, typically set in the range 0.1 to 0.5. This is useful for (1) fitting real data that is overdispersed relative to a Binomial distribution, and (2) relaxing models of large populations to improve inference. In particular the ``overdispersion`` parameter lower bounds the relative uncertainty in stochastic models such that increasing population leads to a limiting scale-free dynamical system with bounded stochasticity, in contrast to Binomial-based SDEs that converge to deterministic ODEs in the large population limit. This parameterization satisfies the following properties: 1. Variance increases monotonically in ``overdispersion``. 2. ``overdispersion = 0`` results in a Binomial distribution. 3. ``overdispersion`` lower bounds the relative uncertainty ``std_dev / (total_count * p * q)``, where ``probs = p = 1 - q``, and serves as an asymptote for relative uncertainty as ``total_count → ∞``. This contrasts the Binomial whose relative uncertainty tends to zero. 4. If ``X ~ binomial_dist(n, p, overdispersion=σ)`` then in the large population limit ``n → ∞``, the scaled random variable ``X / n`` converges in distribution to ``LogitNormal(log(p/(1-p)), σ)``. To achieve these properties we set ``p = probs``, ``q = 1 - p``, and:: concentration = 1 / (p * q * overdispersion**2) - 1 :param total_count: Number of Bernoulli trials. :type total_count: int or torch.Tensor :param probs: Event probabilities. :type probs: float or torch.Tensor :param overdispersion: Amount of overdispersion, in the half open interval [0,2). Defaults to zero. :type overdispersion: float or torch.tensor """ _validate_overdispersion(overdispersion) if _is_zero(overdispersion): if _RELAX: return _relaxed_binomial(total_count, probs) return dist.ExtendedBinomial(total_count, probs) p = probs q = 1 - p od2 = (overdispersion + 1e-8) ** 2 concentration1 = 1 / (q * od2 + 1e-8) - p concentration0 = 1 / (p * od2 + 1e-8) - q # At this point we have # concentration1 + concentration0 == 1 / (p + q + od2 + 1e-8) - 1 if _RELAX: return _relaxed_beta_binomial(concentration1, concentration0, total_count) return dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)
[docs]def beta_binomial_dist( concentration1, concentration0, total_count, *, overdispersion=0.0 ): """ Returns a Beta-Binomial distribution that is an overdispersed version of a the usual Beta-Binomial distribution, according to an extra parameter ``overdispersion``, typically set in the range 0.1 to 0.5. :param concentration1: 1st concentration parameter (alpha) for the Beta distribution. :type concentration1: float or torch.Tensor :param concentration0: 2nd concentration parameter (beta) for the Beta distribution. :type concentration0: float or torch.Tensor :param total_count: Number of Bernoulli trials. :type total_count: float or torch.Tensor :param overdispersion: Amount of overdispersion, in the half open interval [0,2). Defaults to zero. :type overdispersion: float or torch.tensor """ _validate_overdispersion(overdispersion) if not _is_zero(overdispersion): # Compute harmonic sum of two sources of concentration resulting in # final concentration c = 1 / (1 / c_1 + 1 / c_2) od2 = (overdispersion + 1e-8) ** 2 c_1 = concentration1 + concentration0 c_2 = c_1**2 / (concentration1 * concentration0 * od2 + 1e-8) - 1 factor = 1 + c_1 / c_2 concentration1 = concentration1 / factor concentration0 = concentration0 / factor if _RELAX: return _relaxed_beta_binomial(concentration1, concentration0, total_count) return dist.ExtendedBetaBinomial(concentration1, concentration0, total_count)
def poisson_dist(rate, *, overdispersion=0.0): _validate_overdispersion(overdispersion) if _is_zero(overdispersion): return dist.Poisson(rate) raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson") def negative_binomial_dist( concentration, probs=None, *, logits=None, overdispersion=0.0 ): _validate_overdispersion(overdispersion) if _is_zero(overdispersion): return dist.NegativeBinomial(concentration, probs=probs, logits=logits) raise NotImplementedError("TODO return a NegativeBinomial or GammaPoisson")
[docs]def infection_dist( *, individual_rate, num_infectious, num_susceptible=math.inf, population=math.inf, concentration=math.inf, overdispersion=0.0 ): """ Create a :class:`~pyro.distributions.Distribution` over the number of new infections at a discrete time step. This returns a Poisson, Negative-Binomial, Binomial, or Beta-Binomial distribution depending on whether ``population`` and ``concentration`` are finite. In Pyro models, the population is usually finite. In the limit ``population → ∞`` and ``num_susceptible/population → 1``, the Binomial converges to Poisson and the Beta-Binomial converges to Negative-Binomial. In the limit ``concentration → ∞``, the Negative-Binomial converges to Poisson and the Beta-Binomial converges to Binomial. The overdispersed distributions (Negative-Binomial and Beta-Binomial returned when ``concentration < ∞``) are useful for modeling superspreader individuals [1,2]. The finitely supported distributions Binomial and Negative-Binomial are useful in small populations and in probabilistic programming systems where truncation or censoring are expensive [3]. **References** [1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005) "Superspreading and the effect of individual variation on disease emergence" https://www.nature.com/articles/nature04153.pdf [2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017) "Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series" https://academic.oup.com/mbe/article/34/11/2982/3952784 [3] Lawrence Murray et al. (2018) "Delayed Sampling and Automatic Rao-Blackwellization of Probabilistic Programs" https://arxiv.org/pdf/1708.07787.pdf :param individual_rate: The mean number of infections per infectious individual per time step in the limit of large population, equal to ``R0 / tau`` where ``R0`` is the basic reproductive number and ``tau`` is the mean duration of infectiousness. :param num_infectious: The number of infectious individuals at this time step, sometimes ``I``, sometimes ``E+I``. :param num_susceptible: The number ``S`` of susceptible individuals at this time step. This defaults to an infinite population. :param population: The total number of individuals in a population. This defaults to an infinite population. :param concentration: The concentration or dispersion parameter ``k`` in overdispersed models of superspreaders [1,2]. This defaults to minimum variance ``concentration = ∞``. :param overdispersion: Amount of overdispersion, in the half open interval [0,2). Defaults to zero. :type overdispersion: float or torch.tensor """ # Convert to colloquial variable names. R = individual_rate I = num_infectious S = num_susceptible N = population k = concentration if isinstance(N, float) and N == math.inf: if isinstance(k, float) and k == math.inf: # Return a Poisson distribution. return poisson_dist(R * I, overdispersion=overdispersion) else: # Return an overdispersed Negative-Binomial distribution. combined_k = k * I logits = torch.as_tensor(R / k).log() return negative_binomial_dist( combined_k, logits=logits, overdispersion=overdispersion ) else: # Compute the probability that any given (susceptible, infectious) # pair of individuals results in an infection at this time step. p = torch.as_tensor(R / N).clamp(max=1 - 1e-6) # Combine infections from all individuals. combined_p = p.neg().log1p().mul(I).expm1().neg() # = 1 - (1 - p)**I combined_p = combined_p.clamp(min=1e-6) if isinstance(k, float) and k == math.inf: # Return a pure Binomial model, combining the independent Binomial # models of each infectious individual. return binomial_dist(S, combined_p, overdispersion=overdispersion) else: # Return an overdispersed Beta-Binomial model, combining # independent BetaBinomial(c1,c0,S) models for each infectious # individual. c1 = (k * I).clamp(min=1e-6) c0 = c1 * (combined_p.reciprocal() - 1).clamp(min=1e-6) return beta_binomial_dist(c1, c0, S, overdispersion=overdispersion)