Epidemiology
Warning
Code in pyro.contrib.epidemiology is under development. This code makes no guarantee about maintaining backwards compatibility.
pyro.contrib.epidemiology provides a modeling language for a class of stochastic discrete-time discrete-count compartmental models. This module implements black-box inference (both Stochastic Variational Inference and Hamiltonian Monte Carlo), prediction of latent variables, and forecasting of future trajectories.
For example usage see the following tutorials:
Base Compartmental Model

class CompartmentalModel(compartments, duration, population, *, approximate=())
Bases: abc.ABC
Abstract base class for discrete-time discrete-value stochastic compartmental models.
Derived classes must implement the methods initialize() and transition(). Derived classes may optionally implement global_model(), compute_flows(), and heuristic().
Example usage:
    # First implement a concrete derived class.
    class MyModel(CompartmentalModel):
        def __init__(self, ...):
            ...
        def global_model(self):
            ...
        def initialize(self, params):
            ...
        def transition(self, params, state, t):
            ...

    # Run inference to fit the model to data.
    model = MyModel(...)
    model.fit_svi(num_samples=100)  # or .fit_mcmc(...)
    R0 = model.samples["R0"]  # An example parameter.
    print("R0 = {:0.3g} ± {:0.3g}".format(R0.mean(), R0.std()))

    # Predict latent variables.
    samples = model.predict()

    # Forecast forward.
    samples = model.predict(forecast=30)

    # You can assess future interventions (applied after ``duration``) by
    # storing them as attributes that are read by your derived methods.
    model.my_intervention = False
    samples1 = model.predict(forecast=30)
    model.my_intervention = True
    samples2 = model.predict(forecast=30)
    effect = samples2["my_result"].mean() - samples1["my_result"].mean()
    print("average effect = {:0.3g}".format(effect))
An example workflow is to use cheaper approximate inference while finding good model structure and priors, then move to more accurate but more expensive inference once the model is plausible.
1. Start with .fit_svi(guide_rank=0, num_steps=2000) for cheap inference while you search for a good model.
2. Additionally infer long-range correlations by moving to a low-rank multivariate normal guide via .fit_svi(guide_rank=None, num_steps=5000).
3. Optionally, also infer a non-Gaussian posterior by moving to the more expensive (but still approximate, via moment matching) .fit_mcmc(num_quant_bins=1, num_samples=10000, num_chains=2).
4. Optionally improve the fit around small counts by moving to the more expensive enumeration-based algorithm .fit_mcmc(num_quant_bins=4, num_samples=10000, num_chains=2) (GPU recommended).
Variables: samples (dict) – Dictionary of posterior samples.
Parameters:
- compartments (list) – A list of strings of compartment names.
- duration (int) – The number of discrete time steps in this model.
- population (int or torch.Tensor) – Either the total population of a single-region model or a tensor of each region’s population in a regional model.
- approximate (tuple) – Names of compartments for which pointwise approximations should be provided in transition(); e.g. if you specify approximate=("I",) then state["I_approx"] will be a continuous-valued, non-enumerated point estimate of state["I"]. Approximations are useful to reduce computational cost. Approximations are continuous-valued with support (-0.5, population + 0.5).

time_plate
A pyro.plate for the time dimension.

region_plate
Either a pyro.plate or a trivial ExitStack, depending on whether this model .is_regional.

global_model()
Samples and returns any global parameters.
Returns: An arbitrary object of parameters (e.g. None or a tuple).

initialize(params)
Returns initial counts in each compartment.
Parameters: params – The global params returned by global_model().
Returns: A dict mapping compartment name to initial value.
Return type: dict

transition(params, state, t)
Forward generative process for dynamics.
This inputs a current state and stochastically updates that state in-place.
Note that this method is called under multiple different interpretations, including batched and vectorized interpretations. During generate() this is called to generate a single sample. During heuristic() this is called to generate a batch of samples for SMC. During fit_mcmc() this is called both in vectorized form (vectorizing over time) and in sequential form (for a single time step); both forms enumerate over discrete latent variables. During predict() this is called to forecast a batch of samples, conditioned on posterior samples for the time interval [0:duration].
Parameters:
- params – The global params returned by global_model().
- state (dict) – A dictionary mapping compartment name to current tensor value. This should be updated in-place.
- t (int or slice) – A time-like index. During inference t may be either a slice (for vectorized inference) or an integer time index. During prediction t will be an integer time index.

compute_flows(prev, curr, t)
Computes flows between compartments, given compartment populations before and after time step t.
The default implementation assumes sequential flows terminating in an implicit compartment named “R”. For example if:
    compartment_names = ("S", "E", "I")
the default implementation computes at time step t = 9:
    flows["S2E_9"] = prev["S"] - curr["S"]
    flows["E2I_9"] = prev["E"] - curr["E"] + flows["S2E_9"]
    flows["I2R_9"] = prev["I"] - curr["I"] + flows["E2I_9"]
For more complex flows (non-sequential, branching, looping, duplicating, etc.), users may override this method.
Parameters:
- prev (dict) – A dict mapping compartment name to population before time step t.
- curr (dict) – A dict mapping compartment name to population after time step t.
- t (int) – A time index.
Returns: A dict mapping flow name to tensor value.
Return type: dict
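The default sequential rule can be mimicked on plain numbers. The sketch below is a hypothetical stand-alone helper (sequential_flows is not part of Pyro); it assumes the compartments are listed in flow order and that the last flow goes into the implicit “R” compartment, as described above.

```python
def sequential_flows(compartments, prev, curr, t):
    """Hypothetical helper mirroring the default sequential flow rule.

    Each compartment's outflow equals its population decrease plus the
    inflow it received from the preceding compartment (conservation of mass).
    """
    flows = {}
    inflow = 0
    # Destination of each flow: the next compartment, or the implicit "R".
    targets = list(compartments[1:]) + ["R"]
    for src, dst in zip(compartments, targets):
        outflow = prev[src] - curr[src] + inflow
        flows["{}2{}_{}".format(src, dst, t)] = outflow
        inflow = outflow
    return flows


prev = {"S": 990, "E": 5, "I": 5}
curr = {"S": 985, "E": 7, "I": 6}
print(sequential_flows(("S", "E", "I"), prev, curr, t=9))
# → {'S2E_9': 5, 'E2I_9': 3, 'I2R_9': 2}
```

The final flow I2R_9 equals the total population that left the explicit compartments, which is exactly what the implicit “R” compartment gains.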

generate(fixed={})
Generate data from the prior.
Parameters: fixed (dict) – A dictionary of parameters on which to condition. These must be top-level parentless nodes, i.e. have no upstream stochastic dependencies.
Returns: A dictionary mapping sample site name to sampled value.
Return type: dict

fit_svi(*, num_samples=100, num_steps=2000, num_particles=32, learning_rate=0.1, learning_rate_decay=0.01, betas=(0.8, 0.99), haar=True, init_scale=0.01, guide_rank=0, jit=False, log_every=200, **options)
Runs stochastic variational inference to generate posterior samples.
This runs SVI, setting the .samples attribute on completion.
This approximate inference method is useful for quickly iterating on probabilistic models.
Parameters:
- num_samples (int) – Number of posterior samples to draw from the trained guide. Defaults to 100.
- num_steps (int) – Number of SVI steps.
- num_particles (int) – Number of SVI particles per step.
- learning_rate (float) – Learning rate for the ClippedAdam optimizer.
- learning_rate_decay (float) – Learning rate decay for the ClippedAdam optimizer. Note this is decay over the entire schedule, not per-step decay.
- betas (tuple) – Momentum parameters for the ClippedAdam optimizer.
- haar (bool) – Whether to use a Haar wavelet reparameterizer.
- guide_rank (int) – Rank of the auto normal guide. If zero (default), use an AutoNormal guide. If a positive integer or None, use an AutoLowRankMultivariateNormal guide. If the string “full”, use an AutoMultivariateNormal guide. These latter two require more num_steps to fit.
- init_scale (float) – Initial scale of the AutoLowRankMultivariateNormal guide.
- jit (bool) – Whether to use a jit-compiled ELBO.
- log_every (int) – How often to log SVI losses.
- heuristic_num_particles (int) – Passed to heuristic() as num_particles. Defaults to 1024.
Returns: Time series of SVI losses (useful to diagnose convergence).
Return type: list
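Because learning_rate_decay describes the decay over the whole schedule, a per-step factor has to be derived from it. The sketch below assumes the total decay is spread geometrically over num_steps, i.e. a per-step factor of learning_rate_decay ** (1 / num_steps); the helper names are hypothetical and this only illustrates the arithmetic, not Pyro's internals.

```python
def per_step_decay(learning_rate_decay, num_steps):
    # Assumed geometric schedule: the num_steps-th power of this factor
    # equals the total decay over the whole run.
    return learning_rate_decay ** (1.0 / num_steps)


def learning_rate_at(step, learning_rate=0.1, learning_rate_decay=0.01, num_steps=2000):
    # Learning rate after `step` multiplicative per-step decays.
    return learning_rate * per_step_decay(learning_rate_decay, num_steps) ** step


print(per_step_decay(0.01, 2000))  # roughly 0.9977 per step
print(learning_rate_at(0))         # 0.1 at the start
print(learning_rate_at(2000))      # roughly 0.001 at the end (0.1 * 0.01)
```

With the defaults, the learning rate therefore ends 100 times smaller than it started, regardless of how many steps are run.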

fit_mcmc(**options)
Runs NUTS inference to generate posterior samples.
This uses the NUTS kernel to run MCMC, setting the .samples attribute on completion.
This uses an asymptotically exact enumeration-based model when num_quant_bins > 1, and a cheaper moment-matched approximate model when num_quant_bins == 1.
Parameters:
- **options – Options passed to MCMC. The remaining options are pulled out and have special meaning.
- num_samples (int) – Number of posterior samples to draw via MCMC. Defaults to 100.
- max_tree_depth (int) – Max tree depth of the NUTS kernel. Defaults to 5.
- full_mass – Specification of the mass matrix of the NUTS kernel. Defaults to full mass over global random variables.
- arrowhead_mass (bool) – Whether to treat full_mass as the head of an arrowhead matrix versus simply as a block. Defaults to False.
- num_quant_bins (int) – If greater than 1, use asymptotically exact inference via local enumeration over this many quantization bins. If equal to 1, use continuous-valued relaxed approximate inference. Note that computational cost is exponential in num_quant_bins. Defaults to 1 for relaxed inference.
- haar (bool) – Whether to use a Haar wavelet reparameterizer. Defaults to True.
- haar_full_mass (int) – Number of low-frequency Haar components to include in the full mass matrix. If haar=False then this is ignored. Defaults to 10.
- heuristic_num_particles (int) – Passed to heuristic() as num_particles. Defaults to 1024.
Returns: An MCMC object for diagnostics, e.g. MCMC.summary().
Return type: MCMC

predict(forecast=0)
Predict latent variables and optionally forecast forward.
This may be run only after fit_svi() or fit_mcmc(), and draws the same num_samples as were passed to that method.
Parameters: forecast (int) – The number of time steps to forecast forward.
Returns: A dictionary mapping sample site name (or compartment name) to a tensor whose first dimension corresponds to sample batching.
Return type: dict

heuristic(num_particles=1024, ess_threshold=0.5, retries=10)
Finds an initial feasible guess of all latent variables, consistent with observed data. This is needed because not all hypotheses are feasible and HMC needs to start at a feasible solution to progress.
The default implementation attempts to find a feasible state using SMCFilter with proposals from the prior. However, this method may be overridden in cases where SMC performs poorly, e.g. in high-dimensional models.
Returns: A dictionary mapping sample site name to tensor value.
Return type: dict
Example Models
Simple SIR

class SimpleSIRModel(population, recovery_time, data)
Susceptible-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three compartments: “S” for susceptible, “I” for infected, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - I) with transitions S -> I -> R.
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
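To make the generative process concrete, here is a minimal forward simulation using only the Python standard library. It is a sketch, not the class's source: it assumes a per-step infection probability of 1 - (1 - R0 / (tau * population)) ** I (the finite-population case described under infection_dist()), a per-step recovery probability of 1 / tau, and a hypothetical reporting probability rho. The names simulate_sir, R0, tau, and rho are illustrative.

```python
import random


def binomial(n, p, rng):
    # Exact Binomial draw as a sum of n Bernoulli trials (fine for small n).
    return sum(rng.random() < p for _ in range(n))


def simulate_sir(population, R0, tau, rho, duration, initial_infected=1, seed=0):
    """Hypothetical forward simulation of a discrete-time SIR model."""
    rng = random.Random(seed)
    S, I = population - initial_infected, initial_infected
    trajectory, observed = [], []
    for t in range(duration):
        # Assumed infection model: each infectious individual independently
        # infects a given susceptible with probability R0 / (tau * population).
        p_infect = 1 - (1 - R0 / (tau * population)) ** I
        S2I = binomial(S, p_infect, rng)
        # Each currently infected individual recovers with probability 1 / tau.
        I2R = binomial(I, 1 / tau, rng)
        S, I = S - S2I, I + S2I - I2R
        trajectory.append((S, I, population - S - I))  # R is implicit.
        # Each new infection is reported with probability rho, so there are
        # false negatives but never false positives.
        observed.append(binomial(S2I, rho, rng))
    return trajectory, observed


traj, obs = simulate_sir(population=1000, R0=2.0, tau=5.0, rho=0.5, duration=30)
```

Each observed count is bounded by the number of S -> I transitions at that step, matching the description of data above.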
Simple SEIR

class SimpleSEIRModel(population, incubation_time, recovery_time, data)
Susceptible-Exposed-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four compartments: “S” for susceptible, “E” for exposed, “I” for infected, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - E - I) with transitions S -> E -> I -> R.
Parameters:
- population (int) – Total population = S + E + I + R.
- incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Simple SEIRD

class SimpleSEIRDModel(population, incubation_time, recovery_time, mortality_rate, data)
Susceptible-Exposed-Infected-Recovered-Dead model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with five compartments: “S” for susceptible, “E” for exposed, “I” for infected, “D” for deceased individuals, and “R” for recovered individuals (the recovered individuals are implicit: R = population - S - E - I - D) with transitions S -> E -> I -> R and I -> D.
Because the transitions are not a simple linear succession, this model implements a custom compute_flows() method.
Parameters:
- population (int) – Total population = S + E + I + R + D.
- incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- mortality_rate (float) – Portion of infections resulting in death. Must be in the open interval (0, 1).
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
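A custom compute_flows() for this branching structure can recover every flow from conservation of mass. The sketch below is a hypothetical stand-alone function on plain numbers, not the class's actual method: it relies on S flowing only to E and D being entered only from I, then deduces E2I and I2R from the population changes.

```python
def seird_flows(prev, curr, t):
    """Hypothetical SEIRD flow computation on plain numbers."""
    # S can only go to E, so its decrease is the S -> E flow.
    S2E = prev["S"] - curr["S"]
    # D can only be entered from I, so its increase is the I -> D flow.
    I2D = curr["D"] - prev["D"]
    # Conservation of mass (change = inflow - outflow) gives the rest.
    E2I = prev["E"] - curr["E"] + S2E
    I2R = prev["I"] - curr["I"] + E2I - I2D
    return {
        "S2E_{}".format(t): S2E,
        "E2I_{}".format(t): E2I,
        "I2D_{}".format(t): I2D,
        "I2R_{}".format(t): I2R,
    }


prev = {"S": 900, "E": 40, "I": 50, "D": 10}
curr = {"S": 880, "E": 45, "I": 52, "D": 12}
print(seird_flows(prev, curr, t=3))
# → {'S2E_3': 20, 'E2I_3': 15, 'I2D_3': 2, 'I2R_3': 11}
```

As a sanity check, the explicit compartments lose 11 individuals in total, which is exactly the I2R flow into the implicit “R” compartment.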
Overdispersed SIR

class OverdispersedSIRModel(population, recovery_time, data)
Generalizes SimpleSIRModel with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling overdispersion of the transition and observation distributions. See binomial_dist() and beta_binomial_dist() for distributional details. For prior work incorporating overdispersed distributions see [1,2,3,4].
References:
[1] D. Champredon, M. Li, B. Bolker, J. Dushoff (2018) “Two approaches to forecast Ebola synthetic epidemics” https://www.sciencedirect.com/science/article/pii/S1755436517300233
[2] Carrie Reed et al. (2015) “Estimating Influenza Disease Burden from Population-Based Surveillance Data in the United States” https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
[3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017) “Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing Data, with an Application to Human Influenza A Virus” https://jvi.asm.org/content/jvi/91/14/e0017117.full.pdf
[4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020) “Mobility trends provide a leading indicator of changes in SARS-CoV-2 transmission” https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Overdispersed SEIR

class OverdispersedSEIRModel(population, incubation_time, recovery_time, data)
Generalizes SimpleSEIRModel with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling overdispersion of the transition and observation distributions. See binomial_dist() and beta_binomial_dist() for distributional details. For prior work incorporating overdispersed distributions see [1,2,3,4].
References:
[1] D. Champredon, M. Li, B. Bolker, J. Dushoff (2018) “Two approaches to forecast Ebola synthetic epidemics” https://www.sciencedirect.com/science/article/pii/S1755436517300233
[2] Carrie Reed et al. (2015) “Estimating Influenza Disease Burden from Population-Based Surveillance Data in the United States” https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
[3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017) “Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing Data, with an Application to Human Influenza A Virus” https://jvi.asm.org/content/jvi/91/14/e0017117.full.pdf
[4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020) “Mobility trends provide a leading indicator of changes in SARS-CoV-2 transmission” https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
Parameters:
- population (int) – Total population = S + E + I + R.
- incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Superspreading SIR

class SuperspreadingSIRModel(population, recovery_time, data)
Generalizes SimpleSIRModel by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual reproductive number) by assuming each infected individual infects BetaBinomial-many susceptible individuals, where the BetaBinomial distribution acts as an overdispersed Binomial distribution, adapting the more standard NegativeBinomial distribution that acts as an overdispersed Poisson distribution [1,2] to the setting of finite populations. To preserve Markov structure, we follow [2] and assume all infections by a single individual occur on the single time step where that individual makes an I -> R transition. That is, whereas the SimpleSIRModel assumes infected individuals infect Binomial(S,R/tau)-many susceptible individuals during each infected time step (over tau-many steps on average), this model assumes they infect BetaBinomial(k,…,S)-many susceptible individuals but only on the final time step before recovering.
References:
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005) “Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017) “Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Superspreading SEIR

class SuperspreadingSEIRModel(population, incubation_time, recovery_time, data, *, leaf_times=None, coal_times=None)
Generalizes SimpleSEIRModel by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual reproductive number) by assuming each infected individual infects BetaBinomial-many susceptible individuals, where the BetaBinomial distribution acts as an overdispersed Binomial distribution, adapting the more standard NegativeBinomial distribution that acts as an overdispersed Poisson distribution [1,2] to the setting of finite populations. To preserve Markov structure, we follow [2] and assume all infections by a single individual occur on the single time step where that individual makes an I -> R transition. That is, whereas the SimpleSEIRModel assumes infected individuals infect Binomial(S,R/tau)-many susceptible individuals during each infected time step (over tau-many steps on average), this model assumes they infect BetaBinomial(k,…,S)-many susceptible individuals but only on the final time step before recovering.
This model also adds an optional likelihood for observed phylogenetic data in the form of coalescent times. These are provided as a pair (leaf_times, coal_times) of times at which genomes are sequenced and lineages coalesce, respectively. We incorporate this data using the CoalescentRateLikelihood with base coalescence rate computed from the S and I populations. This likelihood is independent across time and preserves the Markov property needed for inference.
References:
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005) “Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017) “Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
Parameters:
- population (int) – Total population = S + E + I + R.
- incubation_time (float) – Mean incubation time (duration in state E). Must be greater than 1.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> E transitions. This allows false negatives but no false positives.
Heterogeneous SIR

class HeterogeneousSIRModel(population, recovery_time, data)
Generalizes SimpleSIRModel by allowing Rt and rho to vary in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate rho is piecewise constant with unknown value over three pieces. The reproductive number Rt is a product of a constant R0 with a factor beta that drifts via Brownian motion in log space. Both rho and Rt are available as time series.
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
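The structure Rt = R0 * beta, with log(beta) following a random walk, can be pictured with a short simulation. This sketch assumes a Gaussian random walk on log(beta) starting at beta = 1 with a hypothetical step size sigma; it only illustrates the structure described above and is not the model's actual prior.

```python
import math
import random


def sample_Rt(R0, duration, sigma=0.1, seed=0):
    """Hypothetical draw of a time-varying reproductive number Rt = R0 * beta."""
    rng = random.Random(seed)
    log_beta = 0.0  # Assumption: beta starts at 1, so Rt starts near R0.
    Rt = []
    for _ in range(duration):
        # Brownian motion in log space: add an independent Gaussian increment.
        log_beta += rng.gauss(0.0, sigma)
        Rt.append(R0 * math.exp(log_beta))
    return Rt


Rt = sample_Rt(R0=2.5, duration=50)
```

Working in log space keeps Rt strictly positive however far beta drifts, which is one reason to prefer it over a random walk on beta directly.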
Sparse SIR

class SparseSIRModel(population, recovery_time, data, mask)
Generalizes SimpleSIRModel to allow sparsely observed infections.
To customize this model we recommend forking and editing this class.
This model allows observations of cumulative infections at uneven time intervals. To preserve Markov structure (and hence tractable inference) this model adds an auxiliary compartment O denoting the fully-observed cumulative number of observations at each time point. At observed times (when mask[t] == True) O must exactly match the provided data; between observed times O stochastically imputes the provided data.
This model demonstrates how to implement a custom compute_flows() method. A custom method is needed in this model because inhabitants of the S compartment can transition to both the I and O compartments, allowing duplication.
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time series of cumulative observed infections. Whenever mask[t] == True, data[t] corresponds to an observation; otherwise data[t] can be arbitrary, e.g. NAN.
- mask (iterable) – Boolean time series denoting whether an observation is made at each time step. Should satisfy len(mask) == len(data).
Unknown Start SIR

class UnknownStartSIRModel(population, recovery_time, pre_obs_window, data)
Generalizes SimpleSIRModel by allowing an unknown date of first infection.
To customize this model we recommend forking and editing this class.
This model demonstrates:
- How to incorporate spontaneous infections from external sources;
- How to incorporate time-varying piecewise rho by supporting forecasting in transition();
- How to override the predict() method to compute extra statistics.
Parameters:
- population (int) – Total population = S + I + R.
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- pre_obs_window (int) – Number of time steps before beginning data where the initial infection may have occurred. Must be positive.
- data (iterable) – Time series of new observed infections. Each time step is Binomial distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Regional SIR

class RegionalSIRModel(population, coupling, recovery_time, data)
Generalizes SimpleSIRModel to simultaneously model multiple regions with weak coupling across regions.
To customize this model we recommend forking and editing this class.
Regions are coupled by a coupling matrix with entries in [0,1]. The all-ones matrix is equivalent to a single region. The identity matrix is equivalent to a set of independent regions. This need not be symmetric, but symmetric matrices are probably more physically plausible. The number of new infections S2I at each time step is Binomial distributed with mean:
    E[S2I] = S (1 - (1 - R0 / (population @ coupling)) ** (I @ coupling))
           ≈ R0 S (I @ coupling) / (population @ coupling)  # for small I
Thus in a nearly entirely susceptible population, a single infected individual infects approximately R0 new individuals on average, independent of coupling.
This model demonstrates:
- How to create a regional model with a population vector.
- How to model both homogeneous parameters (here R0) and heterogeneous parameters with hierarchical structure (here rho) using self.region_plate.
- How to approximately couple regions in transition() using state["I_approx"].
Parameters:
- population (torch.Tensor) – Tensor of per-region populations, defining population = S + I + R.
- coupling (torch.Tensor) – Pairwise coupling matrix. Entries should be in [0,1].
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time x Region sized tensor of new observed infections. Each time step is a vector of Binomials distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
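The mean formula above can be evaluated with plain lists to see how coupling changes it. The sketch below is a hypothetical helper (expected_new_infections, using lists of floats instead of tensors) that computes E[S2I] per region exactly as written above, together with its small-I approximation.

```python
def matvec(vector, matrix):
    # Row vector times matrix, i.e. vector @ matrix.
    return [sum(vector[i] * matrix[i][j] for i in range(len(vector)))
            for j in range(len(matrix[0]))]


def expected_new_infections(S, I, population, coupling, R0):
    """Hypothetical per-region evaluation of E[S2I] and its small-I approximation."""
    coupled_pop = matvec(population, coupling)  # population @ coupling
    coupled_I = matvec(I, coupling)             # I @ coupling
    exact = [s * (1 - (1 - R0 / n) ** i)
             for s, n, i in zip(S, coupled_pop, coupled_I)]
    approx = [R0 * s * i / n for s, n, i in zip(S, coupled_pop, coupled_I)]
    return exact, approx


population = [1000.0, 2000.0]
S = [999.0, 2000.0]
I = [1.0, 0.0]
identity = [[1.0, 0.0], [0.0, 1.0]]
exact, approx = expected_new_infections(S, I, population, identity, R0=2.0)
```

With the identity matrix, the uninfected second region expects no new infections, as for independent regions. With the all-ones matrix, every region sees the pooled population and the pooled infections, as for a single region.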
Heterogeneous Regional SIR

class HeterogeneousRegionalSIRModel(population, coupling, recovery_time, data)
Generalizes RegionalSIRModel by allowing Rt and rho to vary in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate rho varies across time and region, whereas the reproductive number Rt varies in time but is shared among regions. Both parameters drift according to transformed Brownian motion with a learned drift rate.
This model demonstrates how to model hierarchical latent time series, other than compartmental variables.
Parameters:
- population (torch.Tensor) – Tensor of per-region populations, defining population = S + I + R.
- coupling (torch.Tensor) – Pairwise coupling matrix. Entries should be in [0,1].
- recovery_time (float) – Mean recovery time (duration in state I). Must be greater than 1.
- data (iterable) – Time x Region sized tensor of new observed infections. Each time step is a vector of Binomials distributed between 0 and the number of S -> I transitions. This allows false negatives but no false positives.
Distributions

set_approx_sample_thresh(thresh)
EXPERIMENTAL Context manager / decorator to temporarily set the global default value of Binomial.approx_sample_thresh, thereby decreasing the computational complexity of sampling from Binomial, BetaBinomial, ExtendedBinomial, ExtendedBetaBinomial, and distributions returned by infection_dist().
This is useful for sampling from very large total_count.
This is used internally by CompartmentalModel.
Parameters: thresh (int or float) – New temporary threshold.
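The “context manager / decorator” behaviour follows a common Python pattern for temporarily overriding a global default. The sketch below uses a hypothetical module-level setting _APPROX_SAMPLE_THRESH (not Pyro's actual attribute) and contextlib.contextmanager, whose generated objects can be used both in a with statement and as a function decorator.

```python
import contextlib
import math

_APPROX_SAMPLE_THRESH = math.inf  # Hypothetical global default.


@contextlib.contextmanager
def temporarily_set_thresh(thresh):
    """Hypothetical stand-in for a set-a-global-temporarily helper."""
    global _APPROX_SAMPLE_THRESH
    old = _APPROX_SAMPLE_THRESH
    _APPROX_SAMPLE_THRESH = thresh
    try:
        yield
    finally:
        # Restore the previous default even if the body raises.
        _APPROX_SAMPLE_THRESH = old


def current_thresh():
    return _APPROX_SAMPLE_THRESH


# Usage as a context manager.
with temporarily_set_thresh(100):
    inside = current_thresh()


# Usage as a decorator: the override applies only while run() executes.
@temporarily_set_thresh(1000)
def run():
    return current_thresh()
```

The try/finally is the important design choice: it guarantees the global default is restored even when sampling code inside the block raises.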

set_approx_log_prob_tol(tol)
EXPERIMENTAL Context manager / decorator to temporarily set the global default value of Binomial.approx_log_prob_tol and BetaBinomial.approx_log_prob_tol, thereby decreasing the computational complexity of scoring Binomial and BetaBinomial distributions.
This is used internally by CompartmentalModel.
Parameters: tol (int or float) – New temporary tolerance.

binomial_dist(total_count, probs, *, overdispersion=0.0)
Returns a BetaBinomial distribution that is an overdispersed version of a Binomial distribution, according to a parameter overdispersion, typically set in the range 0.1 to 0.5.
This is useful for (1) fitting real data that is overdispersed relative to a Binomial distribution, and (2) relaxing models of large populations to improve inference. In particular, the overdispersion parameter lower bounds the relative uncertainty in stochastic models such that increasing population leads to a limiting scale-free dynamical system with bounded stochasticity, in contrast to Binomial-based SDEs that converge to deterministic ODEs in the large population limit.
This parameterization satisfies the following properties:
1. Variance increases monotonically in overdispersion.
2. overdispersion = 0 results in a Binomial distribution.
3. overdispersion lower bounds the relative uncertainty std_dev / (total_count * p * q), where probs = p = 1 - q, and serves as an asymptote for relative uncertainty as total_count → ∞. This contrasts with the Binomial, whose relative uncertainty tends to zero.
4. If X ~ binomial_dist(n, p, overdispersion=σ) then in the large population limit n → ∞, the scaled random variable X / n converges in distribution to LogitNormal(log(p/(1-p)), σ).
To achieve these properties we set p = probs, q = 1 - p, and:
    concentration = 1 / (p * q * overdispersion**2) - 1
Parameters:
- total_count (int or torch.Tensor) – Number of Bernoulli trials.
- probs (float or torch.Tensor) – Event probabilities.
- overdispersion (float or torch.tensor) – Amount of overdispersion, in the half-open interval [0,2). Defaults to zero.
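Property 3 can be checked numerically from the concentration formula above. This sketch assumes the mean-preserving split concentration1 = concentration * p and concentration0 = concentration * q (an assumption; Pyro's implementation may differ in numerical details) and uses the standard Beta-Binomial variance n p q (a + b + n) / (a + b + 1). Under these assumptions the relative uncertainty works out to overdispersion * sqrt(1 + concentration / n), which is at least overdispersion and approaches it as n grows.

```python
import math


def relative_uncertainty(n, p, overdispersion):
    """std_dev / (n * p * q) of the assumed Beta-Binomial parameterization."""
    q = 1 - p
    # Concentration from the formula above.
    concentration = 1 / (p * q * overdispersion ** 2) - 1
    # Assumed mean-preserving split into the two Beta parameters.
    a, b = concentration * p, concentration * q
    # Standard Beta-Binomial variance with alpha = a, beta = b.
    variance = n * p * q * (a + b + n) / (a + b + 1)
    return math.sqrt(variance) / (n * p * q)


for n in (10, 1000, 10 ** 6):
    # Decreases toward the overdispersion value 0.2 as n grows.
    print(n, relative_uncertainty(n, p=0.3, overdispersion=0.2))
```

The same formula also explains the [0,2) range: since p * q ≤ 1/4, any overdispersion below 2 keeps the concentration positive for every p.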

beta_binomial_dist(concentration1, concentration0, total_count, *, overdispersion=0.0)
Returns a BetaBinomial distribution that is an overdispersed version of the usual BetaBinomial distribution, according to an extra parameter overdispersion, typically set in the range 0.1 to 0.5.
Parameters:
- concentration1 (float or torch.Tensor) – 1st concentration parameter (alpha) for the Beta distribution.
- concentration0 (float or torch.Tensor) – 2nd concentration parameter (beta) for the Beta distribution.
- total_count (float or torch.Tensor) – Number of Bernoulli trials.
- overdispersion (float or torch.tensor) – Amount of overdispersion, in the half-open interval [0,2). Defaults to zero.

infection_dist(*, individual_rate, num_infectious, num_susceptible=inf, population=inf, concentration=inf, overdispersion=0.0)
Create a Distribution over the number of new infections at a discrete time step.
This returns a Poisson, NegativeBinomial, Binomial, or BetaBinomial distribution depending on whether population and concentration are finite. In Pyro models, the population is usually finite. In the limit population → ∞ and num_susceptible/population → 1, the Binomial converges to Poisson and the BetaBinomial converges to NegativeBinomial. In the limit concentration → ∞, the NegativeBinomial converges to Poisson and the BetaBinomial converges to Binomial.
The overdispersed distributions (NegativeBinomial and BetaBinomial, returned when concentration < ∞) are useful for modeling superspreader individuals [1,2]. The finitely supported distributions Binomial and BetaBinomial are useful in small populations and in probabilistic programming systems where truncation or censoring are expensive [3].
References:
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005) “Superspreading and the effect of individual variation on disease emergence” https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017) “Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies and Incidence Time Series” https://academic.oup.com/mbe/article/34/11/2982/3952784
[3] Lawrence Murray et al. (2018) “Delayed Sampling and Automatic Rao-Blackwellization of Probabilistic Programs” https://arxiv.org/pdf/1708.07787.pdf
Parameters:
- individual_rate – The mean number of infections per infectious individual per time step in the limit of large population, equal to R0 / tau, where R0 is the basic reproductive number and tau is the mean duration of infectiousness.
- num_infectious – The number of infectious individuals at this time step, sometimes I, sometimes E+I.
- num_susceptible – The number S of susceptible individuals at this time step. This defaults to an infinite population.
- population – The total number of individuals in a population. This defaults to an infinite population.
- concentration – The concentration or dispersion parameter k in overdispersed models of superspreaders [1,2]. This defaults to minimum variance, concentration = ∞.
- overdispersion (float or torch.tensor) – Amount of overdispersion, in the half-open interval [0,2). Defaults to zero.
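The choice among the four families can be summarized by family name and mean. The sketch below is a hypothetical pure-Python helper (infection_family, not Pyro's API). It assumes that in a finite population each infectious individual infects a given susceptible with probability individual_rate / population, combined over num_infectious individuals as 1 - (1 - p) ** num_infectious. It also assumes each overdispersed family keeps the mean of its non-overdispersed limit, and it ignores the overdispersion argument entirely.

```python
import math


def infection_family(individual_rate, num_infectious, num_susceptible=math.inf,
                     population=math.inf, concentration=math.inf):
    """Hypothetical summary: (family name, mean number of new infections)."""
    R, I, S, N, k = (individual_rate, num_infectious, num_susceptible,
                     population, concentration)
    if N == math.inf:
        # Infinite population: unbounded counts with mean R * I.
        family = "Poisson" if k == math.inf else "NegativeBinomial"
        return family, R * I
    # Finite population (assumed model): probability that a given susceptible
    # is infected by at least one of the I infectious individuals this step.
    combined_p = 1 - (1 - R / N) ** I
    family = "Binomial" if k == math.inf else "BetaBinomial"
    return family, S * combined_p


print(infection_family(0.4, 5))                    # ('Poisson', 2.0)
print(infection_family(0.4, 5, concentration=0.5))  # ('NegativeBinomial', 2.0)
# Binomial, mean close to 0.4 * 5 * 900 / 1000 = 1.8:
print(infection_family(0.4, 5, num_susceptible=900, population=1000))
```

The finite-population mean stays slightly below the Poisson-style value R * I * S / N because the Binomial cannot infect the same susceptible twice.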

class CoalescentRateLikelihood(leaf_times, coal_times, duration, *, validate_args=None)
Bases: object
EXPERIMENTAL This is not a Distribution, but acts as a transposed version of CoalescentTimesWithRate, making the elements of rate_grid independent and thus compatible with plate and poutine.markov. For non-batched inputs the following are all equivalent likelihoods:
    # Version 1.
    pyro.sample("coalescent",
                CoalescentTimesWithRate(leaf_times, rate_grid),
                obs=coal_times)

    # Version 2. using pyro.plate
    likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
    with pyro.plate("time", len(rate_grid)):
        pyro.factor("coalescent", likelihood(rate_grid))

    # Version 3. using pyro.markov
    likelihood = CoalescentRateLikelihood(leaf_times, coal_times, len(rate_grid))
    for t in pyro.markov(range(len(rate_grid))):
        pyro.factor("coalescent_{}".format(t), likelihood(rate_grid[t], t))
The third version is useful for e.g. SMCFilter, where rate_grid might be computed sequentially.
Parameters:
- leaf_times (torch.Tensor) – Tensor of times of sampling events, i.e. leaf nodes in the phylogeny. These can be arbitrary real numbers with arbitrary order and duplicates.
- coal_times (torch.Tensor) – A tensor of coalescent times. These denote sets of size leaf_times.size(-1) - 1 along the trailing dimension and should be sorted along that dimension.
- duration (int) – Size of the rate grid, rate_grid.size(-1).

__call__(rate_grid, t=slice(None, None, None))
Computes the likelihood of [1] equations 7-9 for one or all time points.
References:
[1] A. Popinga, T. Vaughan, T. Statler, A.J. Drummond (2014) “Inferring epidemiological dynamics with Bayesian coalescent inference: The merits of deterministic and stochastic models” https://arxiv.org/pdf/1407.1792.pdf
Parameters:
- rate_grid (torch.Tensor) – Tensor of base coalescent rates (pairwise rate of coalescence). For example, in a simple SIR model this might be beta S / I. The rightmost dimension is time, and this tensor represents a (batch of) rates that are piecewise constant in time.
- t (int or slice) – Optional time index by which the input was sliced, as in rate_grid[..., t]. This can be an integer for sequential models or slice(None) for vectorized models.
Returns: Likelihood p(coal_times | leaf_times, rate_grid), or a part of that likelihood corresponding to a single time step.
Return type: torch.Tensor

bio_phylo_to_times(tree, *, get_time=None)
Extracts coalescent summary statistics from a phylogeny, suitable for use with CoalescentRateLikelihood.
Parameters:
- tree (Bio.Phylo.BaseTree.Clade) – A phylogenetic tree.
- get_time (callable) – Optional function to extract the time point of each sub-Clade. If absent, times will be computed by cumulative .branch_length.
Returns: A pair of Tensors (leaf_times, coal_times), where leaf_times are times of sampling events (leaf nodes in the phylogenetic tree) and coal_times are times of coalescences (internal nodes in the phylogenetic binary tree).
Return type: tuple
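The extraction of these summary statistics can be sketched without Biopython. The example below assumes a hypothetical nested-tuple tree format (branch_length, children) instead of Bio.Phylo clades. It measures each node's time as the cumulative branch length from the root, mirroring the get_time=None default above, and counts a node with m children as m - 1 binary coalescences. Time-orientation conventions that CoalescentRateLikelihood may expect are not covered here.

```python
def tree_to_times(tree):
    """Return sorted (leaf_times, coal_times) for a hypothetical (branch_length, children) tree."""
    leaf_times, coal_times = [], []
    stack = [(tree, 0.0)]
    while stack:
        (branch_length, children), parent_time = stack.pop()
        time = parent_time + branch_length  # Cumulative branch length.
        if children:
            # Internal node: m children count as m - 1 binary coalescences.
            coal_times.extend([time] * (len(children) - 1))
            stack.extend((child, time) for child in children)
        else:
            leaf_times.append(time)  # Leaf: a sampling event.
    return sorted(leaf_times), sorted(coal_times)


# Newick-like shape ((A:1, B:2):1, C:3), rooted at time 0.
tree = (0.0, [(1.0, [(1.0, []), (2.0, [])]), (3.0, [])])
print(tree_to_times(tree))
# → ([2.0, 3.0, 3.0], [0.0, 1.0])
```

For a tree with n leaves this always yields n - 1 coalescence times, matching the size requirement on coal_times stated for CoalescentRateLikelihood.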