Epidemiology¶
Warning
Code in pyro.contrib.epidemiology is under development. This code makes no guarantee about maintaining backwards compatibility.
pyro.contrib.epidemiology provides a modeling language for a class of stochastic discrete-time discrete-count compartmental models. This module implements black-box inference (both Stochastic Variational Inference and Hamiltonian Monte Carlo), prediction of latent variables, and forecasting of future trajectories.
For example usage see the following tutorials:
Base Compartmental Model¶
- class CompartmentalModel(compartments, duration, population, *, approximate=())[source]¶
Bases:
abc.ABC
Abstract base class for discrete-time discrete-value stochastic compartmental models.
Derived classes must implement methods initialize() and transition(). Derived classes may optionally implement global_model(), compute_flows(), and heuristic().
Example usage:
    # First implement a concrete derived class.
    class MyModel(CompartmentalModel):
        def __init__(self, ...): ...
        def global_model(self): ...
        def initialize(self, params): ...
        def transition(self, params, state, t): ...

    # Run inference to fit the model to data.
    model = MyModel(...)
    model.fit_svi(num_samples=100)  # or .fit_mcmc(...)
    R0 = model.samples["R0"]  # An example parameter.
    print("R0 = {:0.3g} ± {:0.3g}".format(R0.mean(), R0.std()))

    # Predict latent variables.
    samples = model.predict()

    # Forecast forward.
    samples = model.predict(forecast=30)

    # You can assess future interventions (applied after ``duration``) by
    # storing them as attributes that are read by your derived methods.
    model.my_intervention = False
    samples1 = model.predict(forecast=30)
    model.my_intervention = True
    samples2 = model.predict(forecast=30)
    effect = samples2["my_result"].mean() - samples1["my_result"].mean()
    print("average effect = {:0.3g}".format(effect))
An example workflow is to use cheaper approximate inference while finding good model structure and priors, then move to more accurate but more expensive inference once the model is plausible.
1. Start with .fit_svi(guide_rank=0, num_steps=2000) for cheap inference while you search for a good model.
2. Additionally infer long-range correlations by moving to a low-rank multivariate normal guide via .fit_svi(guide_rank=None, num_steps=5000).
3. Optionally, also infer a non-Gaussian posterior by moving to the more expensive (but still approximate via moment matching) .fit_mcmc(num_quant_bins=1, num_samples=10000, num_chains=2).
4. Optionally improve the fit around small counts by moving to the more expensive enumeration-based algorithm .fit_mcmc(num_quant_bins=4, num_samples=10000, num_chains=2) (GPU recommended).
- Variables
samples (dict) – Dictionary of posterior samples.
- Parameters
compartments (list) – A list of strings of compartment names.
duration (int) – The number of discrete time steps in this model.
population (int or torch.Tensor) – Either the total population of a single-region model or a tensor of each region’s population in a regional model.
approximate (tuple) – Names of compartments for which pointwise approximations should be provided in transition(), e.g. if you specify approximate=("I",) then state["I_approx"] will be a continuous-valued non-enumerated point estimate of state["I"]. Approximations are useful to reduce computational cost. Approximations are continuous-valued with support (-0.5, population + 0.5).
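Because approximate is documented as a tuple, a single name needs a trailing comma. The sketch below is plain Python and does not call Pyro; the helper approx_keys is hypothetical and only illustrates how the "_approx" keys mentioned above relate to the names passed in.

```python
# Parentheses alone do not make a tuple; the trailing comma does.
not_a_tuple = ("I")       # this is just the string "I"
one_name = ("I",)         # a one-element tuple
two_names = ("E", "I")    # a two-element tuple


def approx_keys(approximate):
    """Hypothetical helper: list the state keys implied by ``approximate``."""
    if isinstance(approximate, str):
        raise TypeError("approximate must be a tuple of names, e.g. ('I',)")
    return ["{}_approx".format(name) for name in approximate]


print(type(not_a_tuple).__name__, type(one_name).__name__)
print(approx_keys(one_name))
print(approx_keys(two_names))
```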
- property time_plate¶
A pyro.plate for the time dimension.
- property region_plate¶
Either a pyro.plate or a trivial ExitStack depending on whether this model .is_regional.
- property full_mass¶
A list of a single tuple of the names of global random variables.
- property series¶
A frozenset of names of sample sites that are sampled each time step.
- global_model()[source]¶
Samples and returns any global parameters.
- Returns
An arbitrary object of parameters (e.g. None or a tuple).
- abstract initialize(params)[source]¶
Returns initial counts in each compartment.
- Parameters
params – The global params returned by global_model().
- Returns
A dict mapping compartment name to initial value.
- Return type
dict
- abstract transition(params, state, t)[source]¶
Forward generative process for dynamics.
This inputs a current state and stochastically updates that state in-place.
Note that this method is called under multiple different interpretations, including batched and vectorized interpretations. During generate() this is called to generate a single sample. During heuristic() this is called to generate a batch of samples for SMC. During fit_mcmc() this is called both in vectorized form (vectorizing over time) and in sequential form (for a single time step); both forms enumerate over discrete latent variables. During predict() this is called to forecast a batch of samples, conditioned on posterior samples for the time interval [0:duration].
- Parameters
params – The global params returned by global_model().
state (dict) – A dictionary mapping compartment name to current tensor value. This should be updated in-place.
t (int or slice) – A time-like index. During inference t may be either a slice (for vectorized inference) or an integer time index. During prediction t will be an integer time index.
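Since t may be a slice or an integer, a transition() implementation typically has to decide whether observed data exists for that index. The sketch below is plain Python, not library code: the helper observation_at is hypothetical, and it assumes that a slice only ever covers observed time steps while an integer at or beyond duration is a forecast step.

```python
def observation_at(data, duration, t):
    """Hypothetical helper: select the observations matching a time-like index."""
    # Assumption: a slice is used only for vectorized inference over observed
    # steps; an integer t >= duration is a forecast step with no data.
    t_is_observed = isinstance(t, slice) or t < duration
    return data[t] if t_is_observed else None


data = [3, 5, 8, 13]
duration = len(data)
print(observation_at(data, duration, 2))                   # one observed step
print(observation_at(data, duration, slice(0, duration)))  # all observed steps
print(observation_at(data, duration, duration + 1))        # forecast step
```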
- finalize(params, prev, curr)[source]¶
Optional method for likelihoods that depend on entire time series.
This should be used only for non-factorizable likelihoods that couple states across time. Factorizable likelihoods should instead be added to the transition() method, thereby enabling their use in heuristic() initialization. Since this method is called only after the last time step, it is not used in heuristic() initialization.
Warning
This currently does not support latent variables.
- Parameters
params – The global params returned by global_model().
prev (dict) –
curr (dict) – Dictionaries mapping compartment name to tensor of entire time series. These two parameters are offset by 1 step, thereby making it easy to compute time series of fluxes. For quantized inference, this uses the approximate point estimates, so users must request any needed time series in __init__(), e.g. by calling super().__init__(..., approximate=("I", "E")) if the likelihood depends on the I and E time series.
- compute_flows(prev, curr, t)[source]¶
Computes flows between compartments, given compartment populations before and after time step t.
The default implementation assumes sequential flows terminating in an implicit compartment named “R”. For example if:
    compartment_names = ("S", "E", "I")
the default implementation computes at time step t = 9:
    flows["S2E_9"] = prev["S"] - curr["S"]
    flows["E2I_9"] = prev["E"] - curr["E"] + flows["S2E_9"]
    flows["I2R_9"] = prev["I"] - curr["I"] + flows["E2I_9"]
For more complex flows (non-sequential, branching, looping, duplicating, etc.), users may override this method.
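The sequential rule described above can be sketched in plain Python. This is an illustration of the bookkeeping, not the library's implementation: the function name sequential_flows is hypothetical, and plain integers stand in for tensors.

```python
def sequential_flows(compartments, prev, curr, t):
    """Sketch of sequential flows ending in an implicit "R" compartment."""
    compartments = tuple(compartments)
    destinations = compartments[1:] + ("R",)
    flows = {}
    inflow = 0  # the first compartment has no upstream source
    for source, destin in zip(compartments, destinations):
        # Conservation of mass: outflow = decrease in count + inflow.
        outflow = prev[source] - curr[source] + inflow
        flows["{}2{}_{}".format(source, destin, t)] = outflow
        inflow = outflow
    return flows


prev = {"S": 100, "E": 5, "I": 3}
curr = {"S": 97, "E": 6, "I": 4}
flows = sequential_flows(("S", "E", "I"), prev, curr, t=9)
print(flows)
```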
- Parameters
- Returns
A dict mapping flow name to tensor value.
- Return type
dict
- generate(fixed={})[source]¶
Generate data from the prior.
- Parameters
fixed (dict) – A dictionary of parameters on which to condition. These must be top-level parentless nodes, i.e. have no upstream stochastic dependencies.
- Returns
A dictionary mapping sample site name to sampled value.
- Return type
dict
- fit_svi(*, num_samples=100, num_steps=2000, num_particles=32, learning_rate=0.1, learning_rate_decay=0.01, betas=(0.8, 0.99), haar=True, init_scale=0.01, guide_rank=0, jit=False, log_every=200, **options)[source]¶
Runs stochastic variational inference to generate posterior samples.
This runs SVI, setting the .samples attribute on completion.
This approximate inference method is useful for quickly iterating on probabilistic models.
- Parameters
num_samples (int) – Number of posterior samples to draw from the trained guide. Defaults to 100.
learning_rate (float) – Learning rate for the ClippedAdam optimizer.
learning_rate_decay (float) – Learning rate decay for the ClippedAdam optimizer. Note this is decay over the entire schedule, not per-step decay.
betas (tuple) – Momentum parameters for the ClippedAdam optimizer.
haar (bool) – Whether to use a Haar wavelet reparameterizer.
guide_rank (int) – Rank of the auto normal guide. If zero (default) use an AutoNormal guide. If a positive integer or None, use an AutoLowRankMultivariateNormal guide. If the string “full”, use an AutoMultivariateNormal guide. These latter two require more num_steps to fit.
init_scale (float) – Initial scale of the AutoLowRankMultivariateNormal guide.
jit (bool) – Whether to use a jit compiled ELBO.
log_every (int) – How often to log SVI losses.
heuristic_num_particles (int) – Passed to heuristic() as num_particles. Defaults to 1024.
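To make the note on learning_rate_decay concrete, the sketch below converts a whole-schedule decay into a per-step multiplier. It assumes the total decay is spread geometrically over num_steps; the helper names are hypothetical and nothing here calls Pyro.

```python
def per_step_decay(learning_rate_decay, num_steps):
    """Per-step multiplier whose product over num_steps equals the total decay."""
    return learning_rate_decay ** (1.0 / num_steps)


def final_learning_rate(learning_rate, learning_rate_decay, num_steps):
    """Apply the per-step multiplier num_steps times (assumed geometric schedule)."""
    factor = per_step_decay(learning_rate_decay, num_steps)
    rate = learning_rate
    for _ in range(num_steps):
        rate *= factor
    return rate


# With the documented defaults, the rate falls from 0.1 to about 0.001.
print(per_step_decay(0.01, 2000))
print(final_learning_rate(0.1, 0.01, 2000))
```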
- Returns
Time series of SVI losses (useful to diagnose convergence).
- Return type
- fit_mcmc(**options)[source]¶
Runs NUTS inference to generate posterior samples.
This uses the NUTS kernel to run MCMC, setting the .samples attribute on completion.
This uses an asymptotically exact enumeration-based model when num_quant_bins > 1, and a cheaper moment-matched approximate model when num_quant_bins == 1.
- Parameters
**options – Options passed to MCMC. The remaining options are pulled out and have special meaning.
num_samples (int) – Number of posterior samples to draw via mcmc. Defaults to 100.
max_tree_depth (int) – (Default 5). Max tree depth of the NUTS kernel.
full_mass – Specification of mass matrix of the NUTS kernel. Defaults to full mass over global random variables.
arrowhead_mass (bool) – Whether to treat full_mass as the head of an arrowhead matrix versus simply as a block. Defaults to False.
num_quant_bins (int) – If greater than 1, use asymptotically exact inference via local enumeration over this many quantization bins. If equal to 1, use continuous-valued relaxed approximate inference. Note that computational cost is exponential in num_quant_bins. Defaults to 1 for relaxed inference.
haar (bool) – Whether to use a Haar wavelet reparameterizer. Defaults to True.
haar_full_mass (int) – Number of low frequency Haar components to include in the full mass matrix. If haar=False then this is ignored. Defaults to 10.
heuristic_num_particles (int) – Passed to heuristic() as num_particles. Defaults to 1024.
- Returns
An MCMC object for diagnostics, e.g. MCMC.summary().
- Return type
MCMC
- predict(forecast=0)[source]¶
Predict latent variables and optionally forecast forward.
This may be run only after fit_mcmc() and draws the same num_samples as passed to fit_mcmc().
- heuristic(num_particles=1024, ess_threshold=0.5, retries=10)[source]¶
Finds an initial feasible guess of all latent variables, consistent with observed data. This is needed because not all hypotheses are feasible and HMC needs to start at a feasible solution to progress.
The default implementation attempts to find a feasible state using SMCFilter with proposals from the prior. However this method may be overridden in cases where SMC performs poorly, e.g. in high-dimensional models.
Example Models¶
Simple SIR¶
- class SimpleSIRModel(population, recovery_time, data)[source]¶
Susceptible-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three compartments: “S” for susceptible, “I” for infected, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - I) with transitions S -> I -> R.
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
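To illustrate the kind of process described above, here is a minimal forward simulation in plain Python. It is a sketch under stated assumptions rather than the model's actual code: R0 and the response rate rho are fixed inputs instead of sampled parameters, the per-susceptible infection probability R0 * I / (recovery_time * population) is an illustrative choice, and simulate_sir is a hypothetical name.

```python
import random


def simulate_sir(population, recovery_time, R0, rho, duration, seed=0):
    """Simulate one trajectory of a discrete-time, discrete-count SIR model."""
    rng = random.Random(seed)

    def binomial(n, p):
        # Naive Binomial(n, p) sampler: count successes in n Bernoulli trials.
        return sum(rng.random() < p for _ in range(n))

    S, I = population - 1, 1  # assumed start: a single infected individual
    trajectory = []
    for _ in range(duration):
        # Assumed infection probability per susceptible individual per step.
        prob_infect = min(1.0, R0 * I / (recovery_time * population))
        S2I = binomial(S, prob_infect)           # new infections
        I2R = binomial(I, 1.0 / recovery_time)   # new recoveries
        S -= S2I
        I += S2I - I2R
        obs = binomial(S2I, rho)  # observed infections never exceed S2I
        trajectory.append({"S": S, "I": I, "R": population - S - I,
                           "S2I": S2I, "obs": obs})
    return trajectory


trajectory = simulate_sir(population=1000, recovery_time=7.0, R0=2.0,
                          rho=0.5, duration=30)
print([step["obs"] for step in trajectory])
```

The observation step shows why false negatives are possible but false positives are not: obs is drawn from a Binomial bounded above by the number of S -> I transitions.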
Simple SEIR¶
- class SimpleSEIRModel(population, incubation_time, recovery_time, data)[source]¶
Susceptible-Exposed-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four compartments: “S” for susceptible, “E” for exposed, “I” for infected, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - E - I) with transitions S -> E -> I -> R.
- Parameters
population (int) – Total population = S + E + I + R.
incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Simple SEIRD¶
- class SimpleSEIRDModel(population, incubation_time, recovery_time, mortality_rate, data)[source]¶
Susceptible-Exposed-Infected-Recovered-Dead model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with five compartments: “S” for susceptible, “E” for exposed, “I” for infected, “D” for deceased individuals, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - E - I - D) with transitions S -> E -> I -> R and I -> D.
Because the transitions are not simple linear succession, this model implements a custom compute_flows() method.
- Parameters
population (int) – Total population = S + E + I + R + D.
incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
mortality_rate (float) – Portion of infections resulting in death. Must be in the open interval (0, 1).
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
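A branching flow computation for the S -> E -> I -> R and I -> D structure can be derived from conservation of mass. The sketch below is plain Python with integers in place of tensors; the function name seird_flows and the flow key format are assumptions modeled on the naming used by compute_flows(), not a copy of the class's method.

```python
def seird_flows(prev, curr, t):
    """Sketch: deduce flows for S -> E -> I -> R with a branch I -> D."""
    S2E = prev["S"] - curr["S"]  # S can only lose individuals to E
    I2D = curr["D"] - prev["D"]  # D can only gain individuals from I
    # Remaining flows follow from: curr - prev = inflows - outflows.
    E2I = prev["E"] - curr["E"] + S2E
    I2R = prev["I"] - curr["I"] + E2I - I2D
    return {
        "S2E_{}".format(t): S2E,
        "E2I_{}".format(t): E2I,
        "I2D_{}".format(t): I2D,
        "I2R_{}".format(t): I2R,
    }


prev = {"S": 100, "E": 10, "I": 5, "D": 1}
curr = {"S": 96, "E": 11, "I": 6, "D": 2}
flows = seird_flows(prev, curr, t=0)
print(flows)
```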
Overdispersed SIR¶
- class OverdispersedSIRModel(population, recovery_time, data)[source]¶
Generalizes SimpleSIRModel with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling overdispersion of the transition and observation distributions. See binomial_dist() and beta_binomial_dist() for distributional details. For prior work incorporating overdispersed distributions see [1,2,3,4].
References:
- [1] D. Champredon, M. Li, B. Bolker, J. Dushoff (2018)
“Two approaches to forecast Ebola synthetic epidemics” https://www.sciencedirect.com/science/article/pii/S1755436517300233
- [2] Carrie Reed et al. (2015)
“Estimating Influenza Disease Burden from Population-Based Surveillance Data in the United States” https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
- [3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017)
“Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing Data, with an Application to Human Influenza A Virus” https://jvi.asm.org/content/jvi/91/14/e00171-17.full.pdf
- [4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020)
“Mobility trends provide a leading indicator of changes in SARS-CoV-2 transmission” https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Overdispersed SEIR¶
- class OverdispersedSEIRModel(population, incubation_time, recovery_time, data)[source]¶
Generalizes SimpleSEIRModel with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling overdispersion of the transition and observation distributions. See binomial_dist() and beta_binomial_dist() for distributional details. For prior work incorporating overdispersed distributions see [1,2,3,4].
References:
- [1] D. Champredon, M. Li, B. Bolker, J. Dushoff (2018)
“Two approaches to forecast Ebola synthetic epidemics” https://www.sciencedirect.com/science/article/pii/S1755436517300233
- [2] Carrie Reed et al. (2015)
“Estimating Influenza Disease Burden from Population-Based Surveillance Data in the United States” https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
- [3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017)
“Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing Data, with an Application to Human Influenza A Virus” https://jvi.asm.org/content/jvi/91/14/e00171-17.full.pdf
- [4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020)
“Mobility trends provide a leading indicator of changes in SARS-CoV-2 transmission” https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
- Parameters
population (int) – Total population = S + E + I + R.
incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Superspreading SIR¶
- class SuperspreadingSIRModel(population, recovery_time, data)[source]¶
Generalizes SimpleSIRModel by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual reproductive number) by assuming each infected individual infects BetaBinomial-many susceptible individuals, where the BetaBinomial distribution acts as an overdispersed Binomial distribution, adapting the more standard NegativeBinomial distribution that acts as an overdispersed Poisson distribution [1,2] to the setting of finite populations. To preserve Markov structure, we follow [2] and assume all infections by a single individual occur on the single time step where that individual makes an I -> R transition. That is, whereas the SimpleSIRModel assumes infected individuals infect Binomial(S,R/tau)-many susceptible individuals during each infected time step (over tau-many steps on average), this model assumes they infect BetaBinomial(k,…,S)-many susceptible individuals but only on the final time step before recovering.
References
- [1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
“Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
- [2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
“Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Superspreading SEIR¶
- class SuperspreadingSEIRModel(population, incubation_time, recovery_time, data, *, leaf_times=None, coal_times=None)[source]¶
Generalizes SimpleSEIRModel by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual reproductive number) by assuming each infected individual infects BetaBinomial-many susceptible individuals, where the BetaBinomial distribution acts as an overdispersed Binomial distribution, adapting the more standard NegativeBinomial distribution that acts as an overdispersed Poisson distribution [1,2] to the setting of finite populations. To preserve Markov structure, we follow [2] and assume all infections by a single individual occur on the single time step where that individual makes an I -> R transition. That is, whereas the SimpleSEIRModel assumes infected individuals infect Binomial(S,R/tau)-many susceptible individuals during each infected time step (over tau-many steps on average), this model assumes they infect BetaBinomial(k,…,S)-many susceptible individuals but only on the final time step before recovering.
This model also adds an optional likelihood for observed phylogenetic data in the form of coalescent times. These are provided as a pair (leaf_times, coal_times) of times at which genomes are sequenced and lineages coalesce, respectively. We incorporate this data using the CoalescentRateLikelihood with base coalescence rate computed from the S and I populations. This likelihood is independent across time and preserves the Markov property needed for inference.
References
- [1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
“Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
- [2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
“Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
- Parameters
population (int) – Total population = S + E + I + R.
incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Heterogeneous SIR¶
- class HeterogeneousSIRModel(population, recovery_time, data)[source]¶
Generalizes SimpleSIRModel by allowing Rt and rho to vary in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate rho is piecewise constant with unknown value over three pieces. The reproductive number Rt is a product of a constant R0 with a factor beta that drifts via Brownian motion in log space. Both rho and Rt are available as time series.
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Sparse SIR¶
- class SparseSIRModel(population, recovery_time, data, mask)[source]¶
Generalizes SimpleSIRModel to allow sparsely observed infections.
To customize this model we recommend forking and editing this class.
This model allows observations of cumulative infections at uneven time intervals. To preserve Markov structure (and hence tractable inference) this model adds an auxiliary compartment O denoting the fully-observed cumulative number of observations at each time point. At observed times (when mask[t] == True) O must exactly match the provided data; between observed times O stochastically imputes the provided data.
This model demonstrates how to implement a custom compute_flows() method. A custom method is needed in this model because inhabitants of the S compartment can transition to both the I and O compartments, allowing duplication.
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time series of cumulative observed infections. Whenever mask[t] == True, data[t] corresponds to an observation; otherwise data[t] can be arbitrary, e.g. NAN.
mask (iterable) – Boolean time series denoting whether an observation is made at each time step. Should satisfy len(mask) == len(data).
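The data and mask arguments described above can be built from a handful of irregularly timed cumulative counts. The following plain-Python sketch is illustrative only: the helper build_sparse_series and the example counts are hypothetical, and unobserved steps are filled with NaN as one possible arbitrary value.

```python
import math


def build_sparse_series(observations, duration):
    """Hypothetical helper: expand {time: cumulative_count} into (data, mask)."""
    data = [math.nan] * duration   # arbitrary filler where nothing was observed
    mask = [False] * duration
    for t, count in observations.items():
        data[t] = float(count)
        mask[t] = True
    return data, mask


# Cumulative infections observed on days 0, 7, and 21 of a 28-day window.
data, mask = build_sparse_series({0: 1, 7: 12, 21: 40}, duration=28)
print(len(data) == len(mask), sum(mask))
```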
Unknown Start SIR¶
- class UnknownStartSIRModel(population, recovery_time, pre_obs_window, data)[source]¶
Generalizes SimpleSIRModel by allowing unknown date of first infection.
To customize this model we recommend forking and editing this class.
This model demonstrates:
How to incorporate spontaneous infections from external sources.
How to incorporate time-varying piecewise rho by supporting forecasting in transition().
How to override the predict() method to compute extra statistics.
- Parameters
population (int) – Total population = S + I + R.
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
pre_obs_window (int) – Number of time steps before beginning data where the initial infection may have occurred. Must be positive.
data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Regional SIR¶
- class RegionalSIRModel(population, coupling, recovery_time, data)[source]¶
Generalizes SimpleSIRModel to simultaneously model multiple regions with weak coupling across regions.
To customize this model we recommend forking and editing this class.
Regions are coupled by a coupling matrix with entries in [0,1]. The all ones matrix is equivalent to a single region. The identity matrix is equivalent to a set of independent regions. This need not be symmetric, but symmetric matrices are probably more physically plausible. The number of new infections each time step S2I is Binomial distributed with mean:
    E[S2I] = S (1 - (1 - R0 / (population @ coupling)) ** (I @ coupling))
           ≈ R0 S (I @ coupling) / (population @ coupling)  # for small I
Thus in a nearly entirely susceptible population, a single infected individual infects approximately R0 new individuals on average, independent of coupling.
This model demonstrates:
How to create a regional model with a population vector.
How to model both homogeneous parameters (here R0) and heterogeneous parameters with hierarchical structure (here rho) using self.region_plate.
How to approximately couple regions in transition() using state["I_approx"].
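The mean of S2I given above can be checked numerically. The sketch below evaluates the formula exactly as written in this section, using nested lists instead of tensors; the helper names vec_mat and expected_new_infections and the example numbers are hypothetical, and no Pyro code is called.

```python
def vec_mat(vector, matrix):
    """Row vector times matrix, i.e. ``vector @ matrix`` for nested lists."""
    return [sum(v * row[j] for v, row in zip(vector, matrix))
            for j in range(len(matrix[0]))]


def expected_new_infections(S, I, population, coupling, R0):
    """Return (exact, approximate) per-region means of S2I."""
    I_coupled = vec_mat(I, coupling)
    pop_coupled = vec_mat(population, coupling)
    exact = [s * (1 - (1 - R0 / n) ** i)
             for s, i, n in zip(S, I_coupled, pop_coupled)]
    approx = [R0 * s * i / n for s, i, n in zip(S, I_coupled, pop_coupled)]
    return exact, approx


population = [1000.0, 2000.0]
S, I = [990.0, 2000.0], [10.0, 0.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
all_ones = [[1.0, 1.0], [1.0, 1.0]]

# Identity coupling: the second region has no infections, so its mean is zero.
exact, approx = expected_new_infections(S, I, population, identity, R0=2.5)
# All-ones coupling: both regions see the pooled infections and population.
exact_ones, _ = expected_new_infections(S, I, population, all_ones, R0=2.5)
print(exact, approx, exact_ones)
```

With the all-ones matrix, the per-region means sum to the value a single pooled region with S = 2990 and population 3000 would give, which matches the single-region equivalence stated above.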
- Parameters
population (torch.Tensor) – Tensor of per-region populations, defining population = S + I + R.
coupling (torch.Tensor) – Pairwise coupling matrix. Entries should be in [0,1].
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time x Region sized tensor of new observed infections. Each time step is a vector of Binomials distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Heterogeneous Regional SIR¶
- class HeterogeneousRegionalSIRModel(population, coupling, recovery_time, data)[source]¶
Generalizes RegionalSIRModel by allowing Rt and rho to vary in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate rho varies across time and region, whereas the reproductive number Rt varies in time but is shared among regions. Both parameters drift according to transformed Brownian motion with learned drift rate.
This model demonstrates how to model hierarchical latent time series, other than compartmental variables.
- Parameters
population (torch.Tensor) – Tensor of per-region populations, defining population = S + I + R.
coupling (torch.Tensor) – Pairwise coupling matrix. Entries should be in [0,1].
recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
data (iterable) – Time x Region sized tensor of new observed infections. Each time step is a vector of Binomials distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Distributions¶
- set_approx_sample_thresh(thresh)[source]¶
EXPERIMENTAL Context manager / decorator to temporarily set the global default value of Binomial.approx_sample_thresh, thereby decreasing the computational complexity of sampling from Binomial, BetaBinomial, ExtendedBinomial, ExtendedBetaBinomial, and distributions returned by infection_dist().
This is useful for sampling from very large total_count.
This is used internally by CompartmentalModel.
- Parameters
thresh (int or float) – New temporary threshold.
- set_approx_log_prob_tol(tol)[source]¶
EXPERIMENTAL Context manager / decorator to temporarily set the global default value of Binomial.approx_log_prob_tol and BetaBinomial.approx_log_prob_tol, thereby decreasing the computational complexity of scoring Binomial and BetaBinomial distributions.
This is used internally by CompartmentalModel.
- Parameters
tol (int or float) – New temporary tolerance.
- binomial_dist(total_count, probs, *, overdispersion=0.0)[source]¶
Returns a Beta-Binomial distribution that is an overdispersed version of a Binomial distribution, according to a parameter overdispersion, typically set in the range 0.1 to 0.5.
This is useful for (1) fitting real data that is overdispersed relative to a Binomial distribution, and (2) relaxing models of large populations to improve inference. In particular the overdispersion parameter lower bounds the relative uncertainty in stochastic models such that increasing population leads to a limiting scale-free dynamical system with bounded stochasticity, in contrast to Binomial-based SDEs that converge to deterministic ODEs in the large population limit.
This parameterization satisfies the following properties:
Variance increases monotonically in overdispersion.
overdispersion = 0 results in a Binomial distribution.
overdispersion lower bounds the relative uncertainty std_dev / (total_count * p * q), where probs = p = 1 - q, and serves as an asymptote for relative uncertainty as total_count → ∞. This contrasts with the Binomial, whose relative uncertainty tends to zero.
If X ~ binomial_dist(n, p, overdispersion=σ) then in the large population limit n → ∞, the scaled random variable X / n converges in distribution to LogitNormal(log(p/(1-p)), σ).
To achieve these properties we set p = probs, q = 1 - p, and:
    concentration = 1 / (p * q * overdispersion**2) - 1
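The lower bound and asymptote on relative uncertainty can be verified numerically from the concentration formula. The sketch below is plain Python and assumes the Beta parameters are p * concentration and q * concentration, so that the standard Beta-Binomial variance n p q (c + n) / (c + 1) applies; the function name relative_uncertainty is hypothetical and the sketch requires overdispersion > 0.

```python
import math


def relative_uncertainty(total_count, probs, overdispersion):
    """std_dev / (total_count * p * q) under the assumed Beta-Binomial form."""
    p, q = probs, 1.0 - probs
    concentration = 1.0 / (p * q * overdispersion ** 2) - 1.0
    # Beta-Binomial variance with alpha = p * c and beta = q * c.
    variance = (total_count * p * q * (concentration + total_count)
                / (concentration + 1.0))
    return math.sqrt(variance) / (total_count * p * q)


# Relative uncertainty decreases toward overdispersion as total_count grows.
values = [relative_uncertainty(n, probs=0.3, overdispersion=0.2)
          for n in (10, 1000, 1000000)]
print(values)
```

Algebraically the ratio equals overdispersion * sqrt(1 + concentration / total_count), which makes both the bound and the limit visible.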
- Parameters
total_count (int or torch.Tensor) – Number of Bernoulli trials.
probs (float or torch.Tensor) – Event probabilities.
overdispersion (float or torch.tensor) – Amount of overdispersion, in the half open interval [0,2). Defaults to zero.
- beta_binomial_dist(concentration1, concentration0, total_count, *, overdispersion=0.0)[source]¶
Returns a Beta-Binomial distribution that is an overdispersed version of the usual Beta-Binomial distribution, according to an extra parameter overdispersion, typically set in the range 0.1 to 0.5.
- Parameters
concentration1 (float or torch.Tensor) – 1st concentration parameter (alpha) for the Beta distribution.
concentration0 (float or torch.Tensor) – 2nd concentration parameter (beta) for the Beta distribution.
total_count (float or torch.Tensor) – Number of Bernoulli trials.
overdispersion (float or torch.tensor) – Amount of overdispersion, in the half open interval [0,2). Defaults to zero.
- infection_dist(*, individual_rate, num_infectious, num_susceptible=inf, population=inf, concentration=inf, overdispersion=0.0)[source]¶
Create a Distribution over the number of new infections at a discrete time step.
This returns a Poisson, Negative-Binomial, Binomial, or Beta-Binomial distribution depending on whether population and concentration are finite. In Pyro models, the population is usually finite. In the limit population → ∞ and num_susceptible/population → 1, the Binomial converges to Poisson and the Beta-Binomial converges to Negative-Binomial. In the limit concentration → ∞, the Negative-Binomial converges to Poisson and the Beta-Binomial converges to Binomial.
The overdispersed distributions (Negative-Binomial and Beta-Binomial returned when concentration < ∞) are useful for modeling superspreader individuals [1,2]. The finitely supported distributions Binomial and Beta-Binomial are useful in small populations and in probabilistic programming systems where truncation or censoring are expensive [3].
References
- [1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
“Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
- [2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
“Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
- [3] Lawrence Murray et al. (2018)
“Delayed Sampling and Automatic Rao-Blackwellization of Probabilistic Programs” https://arxiv.org/pdf/1708.07787.pdf
- Parameters
individual_rate – The mean number of infections per infectious individual per time step in the limit of large population, equal to R0 / tau where R0 is the basic reproductive number and tau is the mean duration of infectiousness.
num_infectious – The number of infectious individuals at this time step, sometimes I, sometimes E+I.
num_susceptible – The number S of susceptible individuals at this time step. This defaults to an infinite population.
population – The total number of individuals in a population. This defaults to an infinite population.
concentration – The concentration or dispersion parameter k in overdispersed models of superspreaders [1,2]. This defaults to the minimum-variance value concentration = ∞.
overdispersion (float or torch.Tensor) – Amount of overdispersion, in the half-open interval [0, 2). Defaults to zero.
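The choice among the four distribution families can be summarized as a small decision table. The sketch below is plain Python and only mirrors the rules stated in the description. The helper names `infection_family` and `large_population_mean` are hypothetical, and nothing here reproduces how Pyro actually parameterizes the returned distribution.

```python
import math


def infection_family(population=math.inf, concentration=math.inf):
    """Name the family described in the text (hypothetical helper)."""
    finite = math.isfinite(population)        # finite population -> bounded support
    dispersed = math.isfinite(concentration)  # finite concentration -> overdispersed
    if finite:
        return "Beta-Binomial" if dispersed else "Binomial"
    return "Negative-Binomial" if dispersed else "Poisson"


def large_population_mean(individual_rate, num_infectious):
    """Expected new infections per step in the large-population limit."""
    return individual_rate * num_infectious


# Illustrative numbers only: R0 = 2.5 and tau = 5 time steps.
individual_rate = 2.5 / 5
print(infection_family())                                     # defaults
print(infection_family(population=1000, concentration=0.5))
print(large_population_mean(individual_rate, num_infectious=10))
```

With both arguments left at their infinite defaults the sketch reports the Poisson case, matching the limits described above.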
- class CoalescentRateLikelihood(leaf_times, coal_times, duration, *, validate_args=None)[source]¶
Bases:
object
EXPERIMENTAL This is not a Distribution, but acts as a transposed version of CoalescentTimesWithRate, making the elements of rate_grid independent and thus compatible with plate and poutine.markov. For non-batched inputs the following are all equivalent likelihoods:

    # Version 1.
    pyro.sample("coalescent",
                CoalescentTimesWithRate(leaf_times, rate_grid),
                obs=coal_times)

    # Version 2. using pyro.plate
    likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
    with pyro.plate("time", len(rate_grid)):
        pyro.factor("coalescent", likelihood(rate_grid))

    # Version 3. using pyro.markov
    likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
    for t in pyro.markov(range(len(rate_grid))):
        pyro.factor("coalescent_{}".format(t), likelihood(rate_grid[t], t))

The third version is useful for e.g. SMCFilter where rate_grid might be computed sequentially.
- Parameters
leaf_times (torch.Tensor) – Tensor of times of sampling events, i.e. leaf nodes in the phylogeny. These can be arbitrary real numbers with arbitrary order and duplicates.
coal_times (torch.Tensor) – A tensor of coalescent times. These denote sets of size leaf_times.size(-1) - 1 along the trailing dimension and should be sorted along that dimension.
duration (int) – Size of the rate grid, rate_grid.size(-1).
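As a rough illustration of the kind of quantity such a likelihood evaluates, the sketch below computes the log-density of coalescent times under a single constant pairwise rate, in plain Python. It is not Pyro's implementation: the function name `constant_rate_coalescent_log_lik`, the constant-rate simplification (in place of a piecewise-constant rate_grid), and the tie-breaking rule are all assumptions made for illustration.

```python
import math


def constant_rate_coalescent_log_lik(leaf_times, coal_times, rate):
    """Standard Kingman coalescent log-density with one constant pairwise rate.

    Hypothetical helper; times increase forward, so coalescences precede leaves.
    """
    if len(coal_times) != len(leaf_times) - 1:
        raise ValueError("expected len(coal_times) == len(leaf_times) - 1")
    # Walk backward in time: a leaf adds a lineage, a coalescence removes one.
    events = [(t, +1) for t in leaf_times] + [(t, -1) for t in coal_times]
    # Latest time first; at ties, process leaves before coalescences (assumption).
    events.sort(key=lambda e: (-e[0], -e[1]))
    log_lik = 0.0
    lineages = 0
    prev_time = events[0][0]
    for time, delta in events:
        pairs = lineages * (lineages - 1) / 2
        # Survival term: no coalescence among `pairs` pairs over this interval.
        log_lik -= pairs * rate * (prev_time - time)
        if delta < 0:
            if lineages < 2:
                raise ValueError("coalescence with fewer than two lineages")
            # Density term for the coalescence event itself.
            log_lik += math.log(pairs * rate)
        lineages += delta
        prev_time = time
    return log_lik


# Three leaves sampled at time 0, coalescing at times -1 and -3.
print(constant_rate_coalescent_log_lik([0.0, 0.0, 0.0], [-1.0, -3.0], rate=2.0))
```

The length check mirrors the documented requirement that coal_times has one fewer entry than leaf_times along the trailing dimension.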
- __call__(rate_grid, t=slice(None, None, None))[source]¶
Computes the likelihood of [1] equations 7-9 for one or all time points.
References
- [1] A. Popinga, T. Vaughan, T. Statler, A.J. Drummond (2014)
“Inferring epidemiological dynamics with Bayesian coalescent inference: The merits of deterministic and stochastic models” https://arxiv.org/pdf/1407.1792.pdf
- Parameters
rate_grid (torch.Tensor) – Tensor of base coalescent rates (pairwise rate of coalescence). For example in a simple SIR model this might be beta S / I. The rightmost dimension is time, and this tensor represents a (batch of) rates that are piecewise constant in time.
t (int or slice) – Optional time index by which the input was sliced, as in rate_grid[..., t]. This can be an integer for sequential models or slice(None) for vectorized models.
- Returns
Likelihood p(coal_times | leaf_times, rate_grid), or a part of that likelihood corresponding to a single time step.
- Return type
- bio_phylo_to_times(tree, *, get_time=None)[source]¶
Extracts coalescent summary statistics from a phylogeny, suitable for use with CoalescentRateLikelihood.
- Parameters
tree (Bio.Phylo.BaseTree.Clade) – A phylogenetic tree.
get_time (callable) – Optional function to extract the time point of each sub-Clade. If absent, times will be computed from the cumulative .branch_length.
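The default behaviour, where times come from cumulative branch lengths, can be sketched without Biopython. The `Node` class and the `tree_to_times` function below are hypothetical stand-ins rather than the Bio.Phylo or Pyro API, and the sketch assumes the root's own branch length counts toward every time.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a clade: a branch length plus child nodes."""
    branch_length: float = 0.0
    children: List["Node"] = field(default_factory=list)


def tree_to_times(root: Node) -> Tuple[List[float], List[float]]:
    """Collect node times as cumulative branch length from the root."""
    leaf_times, coal_times = [], []
    stack = [(root, root.branch_length)]
    while stack:
        node, time = stack.pop()
        if node.children:
            # A node with children marks a coalescence.
            coal_times.append(time)
            for child in node.children:
                stack.append((child, time + child.branch_length))
        else:
            # A node without children is a sampling event.
            leaf_times.append(time)
    return sorted(leaf_times), sorted(coal_times)


# A small binary tree with three leaves.
tree = Node(0.0, [Node(1.0), Node(0.5, [Node(2.0), Node(1.5)])])
leaf_times, coal_times = tree_to_times(tree)
print(leaf_times, coal_times)
```

For a binary tree this yields one fewer coalescent time than leaf times, which is the shape CoalescentRateLikelihood expects.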
- Returns
A pair of Tensors (leaf_times, coal_times) where leaf_times are times of sampling events (leaf nodes in the phylogenetic tree) and coal_times are times of coalescences (internal nodes in the phylogenetic binary tree).
- Return type