# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import re
import torch
from torch.nn.functional import pad
import pyro
import pyro.distributions as dist
from .compartmental import CompartmentalModel
from .distributions import binomial_dist, infection_dist
[docs]class SimpleSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with three
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, recovery_time, data):
compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau, rho
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1}
def transition(self, params, state, t):
R0, tau, rho = params
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho),
obs=self.data[t] if t_is_observed else None,
)
[docs]class SimpleSEIRModel(CompartmentalModel):
"""
Susceptible-Exposed-Infected-Recovered model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "E" for exposed, "I" for infected,
and "R" for recovered individuals (the recovered individuals are
implicit: ``R = population - S - E - I``) with transitions
``S -> E -> I -> R``.
:param int population: Total ``population = S + E + I + R``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, incubation_time, recovery_time, data):
compartments = ("S", "E", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau_e, tau_i, rho
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1}
def transition(self, params, state, t):
R0, tau_e, tau_i, rho = params
# Sample flows between compartments.
S2E = pyro.sample(
"S2E_{}".format(t),
infection_dist(
individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
),
)
E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e))
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i))
# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2E, rho),
obs=self.data[t] if t_is_observed else None,
)
[docs]class SimpleSEIRDModel(CompartmentalModel):
"""
Susceptible-Exposed-Infected-Recovered-Dead model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "E" for exposed, "I" for infected, "D"
for deceased individuals, and "R" for recovered individuals (the recovered
individuals are implicit: ``R = population - S - E - I - D``) with
transitions ``S -> E -> I -> R`` and ``I -> D``.
Because the transitions are not simple linear succession, this model
implements a custom :meth:`compute_flows()` method.
:param int population: Total ``population = S + E + I + R + D``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param float mortality_rate: Portion of infections resulting in death.
Must be in the open interval ``(0, 1)``.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""
def __init__(
self, population, incubation_time, recovery_time, mortality_rate, data
):
compartments = ("S", "E", "I", "D") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
assert isinstance(mortality_rate, float)
assert 0 < mortality_rate < 1
self.mortality_rate = mortality_rate
self.data = data
def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
mu = self.mortality_rate
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau_e, tau_i, mu, rho
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1, "D": 0}
def transition(self, params, state, t):
R0, tau_e, tau_i, mu, rho = params
# Sample flows between compartments.
S2E = pyro.sample(
"S2E_{}".format(t),
infection_dist(
individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
),
)
E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e))
# Of the 1/tau_i expected recoveries-or-deaths, a portion mu die and
# the remaining recover. Alternatively we could model this with a
# Multinomial distribution I2_ and extract the two components I2D and
# I2R, however the Multinomial distribution does not currently
# implement overdispersion or moment matching.
I2D = pyro.sample("I2D_{}".format(t), binomial_dist(state["I"], mu / tau_i))
I2R = pyro.sample(
"I2R_{}".format(t), binomial_dist(state["I"] - I2D, 1 / tau_i)
)
# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R - I2D
state["D"] = state["D"] + I2D
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2E, rho),
obs=self.data[t] if t_is_observed else None,
)
def compute_flows(self, prev, curr, t):
S2E = prev["S"] - curr["S"] # S can only go to E.
I2D = curr["D"] - prev["D"] # D can only have come from I.
# We deduce the remaining flows by conservation of mass:
# curr - prev = inflows - outflows
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I - I2D
return {
"S2E_{}".format(t): S2E,
"E2I_{}".format(t): E2I,
"I2D_{}".format(t): I2D,
"I2R_{}".format(t): I2R,
}
[docs]class OverdispersedSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling
overdispersion of the transition and observation distributions. See
:func:`~pyro.contrib.epidemiology.distributions.binomial_dist` and
:func:`~pyro.contrib.epidemiology.distributions.beta_binomial_dist` for
distributional details. For prior work incorporating overdispersed
distributions see [1,2,3,4].
**References:**
[1] D. Champredon, M. Li, B. Bolker. J. Dushoff (2018)
"Two approaches to forecast Ebola synthetic epidemics"
https://www.sciencedirect.com/science/article/pii/S1755436517300233
[2] Carrie Reed et al. (2015)
"Estimating Influenza Disease Burden from Population-Based Surveillance
Data in the United States"
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
[3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017)
"Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing
Data, with an Application to Human Influenza A Virus"
https://jvi.asm.org/content/jvi/91/14/e00171-17.full.pdf
[4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020)
"Mobility trends provide a leading indicator of changes in
SARS-CoV-2 transmission"
https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, recovery_time, data):
compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
od = pyro.sample("od", dist.Beta(2, 6))
return R0, tau, rho, od
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1}
def transition(self, params, state, t):
R0, tau, rho, od = params
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
overdispersion=od,
),
)
I2R = pyro.sample(
"I2R_{}".format(t), binomial_dist(state["I"], 1 / tau, overdispersion=od)
)
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho, overdispersion=od),
obs=self.data[t] if t_is_observed else None,
)
[docs]class OverdispersedSEIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSEIRModel` with overdispersed distributions.
To customize this model we recommend forking and editing this class.
This adds a single global overdispersion parameter controlling
overdispersion of the transition and observation distributions. See
:func:`~pyro.contrib.epidemiology.distributions.binomial_dist` and
:func:`~pyro.contrib.epidemiology.distributions.beta_binomial_dist` for
distributional details. For prior work incorporating overdispersed
distributions see [1,2,3,4].
**References:**
[1] D. Champredon, M. Li, B. Bolker. J. Dushoff (2018)
"Two approaches to forecast Ebola synthetic epidemics"
https://www.sciencedirect.com/science/article/pii/S1755436517300233
[2] Carrie Reed et al. (2015)
"Estimating Influenza Disease Burden from Population-Based Surveillance
Data in the United States"
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4349859/
[3] A. Leonard, D. Weissman, B. Greenbaum, E. Ghedin, K. Koelle (2017)
"Transmission Bottleneck Size Estimation from Pathogen Deep-Sequencing
Data, with an Application to Human Influenza A Virus"
https://jvi.asm.org/content/jvi/91/14/e00171-17.full.pdf
[4] A. Miller, N. Foti, J. Lewnard, N. Jewell, C. Guestrin, E. Fox (2020)
"Mobility trends provide a leading indicator of changes in
SARS-CoV-2 transmission"
https://www.medrxiv.org/content/medrxiv/early/2020/05/11/2020.05.07.20094441.full.pdf
:param int population: Total ``population = S + E + I + R``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, incubation_time, recovery_time, data):
compartments = ("S", "E", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
od = pyro.sample("od", dist.Beta(2, 6))
return R0, tau_e, tau_i, rho, od
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1}
def transition(self, params, state, t):
R0, tau_e, tau_i, rho, od = params
# Sample flows between compartments.
S2E = pyro.sample(
"S2E_{}".format(t),
infection_dist(
individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
overdispersion=od,
),
)
E2I = pyro.sample(
"E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e, overdispersion=od)
)
I2R = pyro.sample(
"I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i, overdispersion=od)
)
# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2E, rho, overdispersion=od),
obs=self.data[t] if t_is_observed else None,
)
[docs]class SuperspreadingSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual
reproductive number) by assuming each infected individual infects
BetaBinomial-many susceptible individuals, where the BetaBinomial
distribution acts as an overdispersed Binomial distribution, adapting the
more standard NegativeBinomial distribution that acts as an overdispersed
Poisson distribution [1,2] to the setting of finite populations. To
preserve Markov structure, we follow [2] and assume all infections by a
single individual occur on the single time step where that individual makes
an ``I -> R`` transition. That is, whereas the :class:`SimpleSIRModel`
assumes infected individuals infect `Binomial(S,R/tau)`-many susceptible
individuals during each infected time step (over `tau`-many steps on
average), this model assumes they infect `BetaBinomial(k,...,S)`-many
susceptible individuals but only on the final time step before recovering.
**References**
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
"Superspreading and the effect of individual variation on disease
emergence"
https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
"Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies
and Incidence Time Series"
https://academic.oup.com/mbe/article/34/11/2982/3952784
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, recovery_time, data):
compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
k = pyro.sample("k", dist.Exponential(1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, k, tau, rho
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1}
def transition(self, params, state, t):
R0, k, tau, rho = params
# Sample flows between compartments.
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
concentration=k,
),
)
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho),
obs=self.data[t] if t_is_observed else None,
)
[docs]class SuperspreadingSEIRModel(CompartmentalModel):
r"""
Generalizes :class:`SimpleSEIRModel` by adding superspreading effects.
To customize this model we recommend forking and editing this class.
This model accounts for superspreading (overdispersed individual
reproductive number) by assuming each infected individual infects
BetaBinomial-many susceptible individuals, where the BetaBinomial
distribution acts as an overdispersed Binomial distribution, adapting the
more standard NegativeBinomial distribution that acts as an overdispersed
Poisson distribution [1,2] to the setting of finite populations. To
preserve Markov structure, we follow [2] and assume all infections by a
single individual occur on the single time step where that individual makes
an ``I -> R`` transition. That is, whereas the :class:`SimpleSEIRModel`
assumes infected individuals infect `Binomial(S,R/tau)`-many susceptible
individuals during each infected time step (over `tau`-many steps on
average), this model assumes they infect `BetaBinomial(k,...,S)`-many
susceptible individuals but only on the final time step before recovering.
This model also adds an optional likelihood for observed phylogenetic data
in the form of coalescent times. These are provided as a pair
``(leaf_times, coal_times)`` of times at which genomes are sequenced and
lineages coalesce, respectively. We incorporate this data using the
:class:`~pyro.distributions.CoalescentRateLikelihood` with base coalescence
rate computed from the ``S`` and ``I`` populations. This likelihood is
independent across time and preserves the Markov propert needed for
inference.
**References**
[1] J. O. Lloyd-Smith, S. J. Schreiber, P. E. Kopp, W. M. Getz (2005)
"Superspreading and the effect of individual variation on disease
emergence"
https://www.nature.com/articles/nature04153.pdf
[2] Lucy M. Li, Nicholas C. Grassly, Christophe Fraser (2017)
"Quantifying Transmission Heterogeneity Using Both Pathogen Phylogenies
and Incidence Time Series"
https://academic.oup.com/mbe/article/34/11/2982/3952784
:param int population: Total ``population = S + E + I + R``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""
def __init__(
self,
population,
incubation_time,
recovery_time,
data,
*,
leaf_times=None,
coal_times=None
):
compartments = ("S", "E", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
assert (leaf_times is None) == (coal_times is None)
if leaf_times is None:
self.coal_likelihood = None
else:
self.coal_likelihood = dist.CoalescentRateLikelihood(
leaf_times, coal_times, duration
)
def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
k = pyro.sample("k", dist.Exponential(1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, k, tau_e, tau_i, rho
def initialize(self, params):
# Start with a single exposure.
return {"S": self.population - 1, "E": 0, "I": 1}
def transition(self, params, state, t):
R0, k, tau_e, tau_i, rho = params
# Sample flows between compartments.
E2I = pyro.sample("E2I_{}".format(t), binomial_dist(state["E"], 1 / tau_e))
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau_i))
S2E = pyro.sample(
"S2E_{}".format(t),
infection_dist(
individual_rate=R0,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
concentration=k,
),
)
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2E, rho),
obs=self.data[t] if t_is_observed else None,
)
if self.coal_likelihood is not None:
R = R0 * state["S"] / self.population
coal_rate = R * (1.0 + 1.0 / k) / (tau_i * state["I"] + 1e-8)
pyro.factor(
"coalescent_{}".format(t),
(
self.coal_likelihood(coal_rate, t)
if t_is_observed
else torch.tensor(0.0)
),
)
# Update compartements with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R
[docs]class HeterogeneousSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` by allowing ``Rt`` and ``rho`` to vary
in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate ``rho`` is piecewise constant with unknown
value over three pieces. The reproductive number ``Rt`` is a product of a
constant ``R0`` with a factor ``beta`` that drifts via Brownian motion in
log space. Both ``rho`` and ``Rt`` are available as time series.
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, recovery_time, data):
compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
# Let's consider a piecewise constant response rate, say low rate for
# two weeks, then intermediate rate as testing capacity increases, and
# finally high rate for a few months (as far into the future as we'd
# like to forecast). We don't know exactly what the rates are, but we
# can specify increasingly informative priors.
rho0 = pyro.sample("rho0", dist.Beta(2, 4))
rho1 = pyro.sample("rho1", dist.Beta(4, 4))
rho2 = pyro.sample("rho2", dist.Beta(8, 4))
# Later .transition() will index into this time series as rho[..., t].
rho = torch.cat(
[
rho0.unsqueeze(-1).expand(rho0.shape + (14,)),
rho1.unsqueeze(-1).expand(rho1.shape + (7,)),
rho2.unsqueeze(-1).expand(rho2.shape + (60,)),
],
dim=-1,
)
# We can also save the time series for output in self.samples.
pyro.deterministic("rho", rho, event_dim=1)
return R0, tau, rho
def initialize(self, params):
R0, tau, rho = params
# Start with a single infection.
# We also store the initial beta value in the state dict.
return {"S": self.population - 1, "I": 1, "beta": torch.tensor(1.0)}
def transition(self, params, state, t):
R0, tau, rho = params
# Sample heterogeneous variables.
# This assumes beta slowly drifts via Brownian motion in log space.
beta = pyro.sample(
"beta_{}".format(t), dist.LogNormal(state["beta"].log(), 0.1)
)
Rt = pyro.deterministic("Rt_{}".format(t), R0 * beta)
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=Rt / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
# Update compartments and heterogeneous variables.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
state["beta"] = beta # We store the latest beta value in the state dict.
# Condition on observations.
# Note that, since rho may be batched over particles or samples, we
# need to index it via rho[..., t] rather than a simple rho[t].
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho[..., t]),
obs=self.data[t] if t_is_observed else None,
)
[docs]class SparseSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` to allow sparsely observed infections.
To customize this model we recommend forking and editing this class.
This model allows observations of **cumulative** infections at uneven time
intervals. To preserve Markov structure (and hence tractable inference)
this model adds an auxiliary compartment ``O`` denoting the fully-observed
cumulative number of observations at each time point. At observed times
(when ``mask[t] == True``) ``O`` must exactly match the provided data;
between observed times ``O`` stochastically imputes the provided data.
This model demonstrates how to implement a custom :meth:`compute_flows`
method. A custom method is needed in this model because inhabitants of the
``S`` compartment can transition to both the ``I`` and ``O`` compartments,
allowing duplication.
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of **cumulative** observed infections.
Whenever ``mask[t] == True``, ``data[t]`` corresponds to an
observation; otherwise ``data[t]`` can be arbitrary, e.g. NAN.
:param iterable mask: Boolean time series denoting whether an observation
is made at each time step. Should satisfy ``len(mask) == len(data)``.
"""
def __init__(self, population, recovery_time, data, mask):
assert len(data) == len(mask)
duration = len(data)
compartments = ("S", "I", "O") # O is auxiliary, R is implicit.
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
self.data = data
self.mask = mask
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau, rho
def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "I": 1, "O": 0}
def transition(self, params, state, t):
R0, tau, rho = params
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
S2O = pyro.sample("S2O_{}".format(t), binomial_dist(S2I, rho))
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
state["O"] = state["O"] + S2O
# Condition on cumulative observations.
t_is_observed = isinstance(t, slice) or t < self.duration
mask_t = self.mask[t] if t_is_observed else False
data_t = self.data[t] if t_is_observed else None
pyro.sample(
"obs_{}".format(t),
# FIXME Delta is incompatible with relaxed inference.
dist.Delta(state["O"]).mask(mask_t),
obs=data_t,
)
def compute_flows(self, prev, curr, t):
# Reverse the flow computation.
S2I = prev["S"] - curr["S"]
I2R = prev["I"] - curr["I"] + S2I
S2O = curr["O"] - prev["O"]
return {
"S2I_{}".format(t): S2I,
"I2R_{}".format(t): I2R,
"S2O_{}".format(t): S2O,
}
[docs]class UnknownStartSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` by allowing unknown date of first
infection.
To customize this model we recommend forking and editing this class.
This model demonstrates:
1. How to incorporate spontaneous infections from external sources;
2. How to incorporate time-varying piecewise ``rho`` by supporting
forecasting in :meth:`transition`.
3. How to override the :meth:`predict` method to compute extra
statistics.
:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param int pre_obs_window: Number of time steps before beginning ``data``
where the initial infection may have occurred. Must be positive.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""
def __init__(self, population, recovery_time, pre_obs_window, data):
compartments = ("S", "I") # R is implicit.
duration = pre_obs_window + len(data)
super().__init__(compartments, duration, population)
assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time
assert isinstance(pre_obs_window, int) and pre_obs_window > 0
self.pre_obs_window = pre_obs_window
self.post_obs_window = len(data)
# We set a small time-constant external infecton rate such that on
# average there is a single external infection during the
# pre_obs_window. This allows unknown time of initial infection
# without introducing long-range coupling across time.
self.external_rate = 1 / pre_obs_window
# Prepend data with zeros.
if isinstance(data, list):
data = [0.0] * self.pre_obs_window + data
else:
data = pad(data, (self.pre_obs_window, 0), value=0.0)
self.data = data
def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
# Assume two different response rates: rho0 before any observations
# were made (in pre_obs_window), followed by a higher response rate rho1
# after observations were made (in post_obs_window).
rho0 = pyro.sample("rho0", dist.Beta(10, 10))
rho1 = pyro.sample("rho1", dist.Beta(10, 10))
# Whereas each of rho0,rho1 are scalars (possibly batched over samples),
# we construct a time series rho with an extra time dim on the right.
rho = torch.cat(
[
rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_obs_window,)),
rho1.unsqueeze(-1).expand(rho1.shape + (self.post_obs_window,)),
],
dim=-1,
)
# Model external infections as an infectious pseudo-individual added
# to num_infectious when sampling S2I below.
X = self.external_rate * tau / R0
return R0, X, tau, rho
def initialize(self, params):
# Start with no internal infections.
return {"S": self.population, "I": 0}
def transition(self, params, state, t):
R0, X, tau, rho = params
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"] + X,
population=self.population,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
# In .transition() t will always be an integer but may lie outside
# of [0,self.duration) when forecasting.
t_is_observed = isinstance(t, slice) or t < self.duration
rho_t = rho[..., t] if t_is_observed else rho[..., -1]
data_t = self.data[t] if t_is_observed else None
# Condition on observations.
pyro.sample("obs_{}".format(t), binomial_dist(S2I, rho_t), obs=data_t)
def predict(self, forecast=0):
"""
Augments
:meth:`~pyro.contrib.epidemiology.compartmental.Compartmental.predict`
with samples of ``first_infection`` i.e. the first time index at which
the infection ``I`` becomes nonzero. Note this is measured from the
beginning of ``pre_obs_window``, not the beginning of data.
:param int forecast: The number of time steps to forecast forward.
:returns: A dictionary mapping sample site name (or compartment name)
to a tensor whose first dimension corresponds to sample batching.
:rtype: dict
"""
samples = super().predict(forecast)
# Extract the time index of the first infection (samples["I"] > 0)
# for each sample trajectory in the samples["I"] tensor.
samples["first_infection"] = samples["I"].cumsum(-1).eq(0).sum(-1)
return samples
[docs]class RegionalSIRModel(CompartmentalModel):
r"""
Generalizes :class:`SimpleSIRModel` to simultaneously model multiple
regions with weak coupling across regions.
To customize this model we recommend forking and editing this class.
Regions are coupled by a ``coupling`` matrix with entries in ``[0,1]``.
The all ones matrix is equivalent to a single region. The identity matrix
is equivalent to a set of independent regions. This need not be symmetric,
but symmetric matrices are probably more physically plausible. The expected
number of new infections each time step ``S2I`` is Binomial distributed
with mean::
E[S2I] = S (1 - (1 - R0 / (population @ coupling)) ** (I @ coupling))
≈ R0 S (I @ coupling) / (population @ coupling) # for small I
Thus in a nearly entirely susceptible population, a single infected
individual infects approximately ``R0`` new individuals on average,
independent of ``coupling``.
This model demonstrates:
1. How to create a regional model with a ``population`` vector.
2. How to model both homogeneous parameters (here ``R0``) and
heterogeneous parameters with hierarchical structure (here ``rho``)
using ``self.region_plate``.
3. How to approximately couple regions in :meth:`transition` using
``state["I_approx"]``.
:param torch.Tensor population: Tensor of per-region populations, defining
``population = S + I + R``.
:param torch.Tensor coupling: Pairwise coupling matrix. Entries should be
in ``[0,1]``.
:param float recovery_time: Mean recovery time (duration in state ``I``).
Must be greater than 1.
:param iterable data: Time x Region sized tensor of new observed
infections. Each time step is vector of Binomials distributed between
0 and the number of ``S -> I`` transitions. This allows false negative
but no false positives.
"""
def __init__(self, population, coupling, recovery_time, data):
duration = len(data)
(num_regions,) = population.shape
assert coupling.shape == (num_regions, num_regions)
assert (0 <= coupling).all()
assert (coupling <= 1).all()
assert isinstance(recovery_time, float)
assert recovery_time > 1
if isinstance(data, torch.Tensor):
# Data tensors should be oriented as (time, region).
assert data.shape == (duration, num_regions)
compartments = ("S", "I") # R is implicit.
# We create a regional model by passing a vector of populations.
super().__init__(compartments, duration, population, approximate=("I",))
self.coupling = coupling
self.recovery_time = recovery_time
self.data = data
def global_model(self):
# Assume recovery time is a known constant.
tau = self.recovery_time
# Assume reproductive number is unknown but homogeneous.
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
# Assume response rate is heterogeneous and model it with a
# hierarchical Gamma-Beta prior.
rho_c1 = pyro.sample("rho_c1", dist.Gamma(10, 1))
rho_c0 = pyro.sample("rho_c0", dist.Gamma(10, 1))
with self.region_plate:
rho = pyro.sample("rho", dist.Beta(rho_c1, rho_c0))
return R0, tau, rho
def initialize(self, params):
# Start with a single infection in region 0.
I = torch.zeros_like(self.population)
I[0] += 1
S = self.population - I
return {"S": S, "I": I}
def transition(self, params, state, t):
R0, tau, rho = params
# Account for infections from all regions. This uses approximate (point
# estimate) counts I_approx for infection from other regions, but uses
# the exact (enumerated) count I for infections from one's own region.
I_coupled = state["I_approx"] @ self.coupling
I_coupled = I_coupled + (state["I"] - state["I_approx"]) * self.coupling.diag()
I_coupled = I_coupled.clamp(min=0) # In case I_approx is negative.
pop_coupled = self.population @ self.coupling
with self.region_plate:
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=I_coupled,
population=pop_coupled,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho),
obs=self.data[t] if t_is_observed else None,
)
[docs]class HeterogeneousRegionalSIRModel(CompartmentalModel):
"""
Generalizes :class:`RegionalSIRModel` by allowing ``Rt`` and ``rho`` to vary
in time.
To customize this model we recommend forking and editing this class.
In this model, the response rate ``rho`` varies across time and region,
whereas the reproductive number ``Rt`` varies in time but is shared among
regions. Both parameters drift according to transformed Brownian motion
with learned drift rate.
This model demonstrates how to model hierarchical latent time series,
other than compartmental variables.
:param torch.Tensor population: Tensor of per-region populations, defining
``population = S + I + R``.
:param torch.Tensor coupling: Pairwise coupling matrix. Entries should be
in ``[0,1]``.
:param float recovery_time: Mean recovery time (duration in state ``I``).
Must be greater than 1.
:param iterable data: Time x Region sized tensor of new observed
infections. Each time step is vector of Binomials distributed between
0 and the number of ``S -> I`` transitions. This allows false negative
but no false positives.
"""
def __init__(self, population, coupling, recovery_time, data):
duration = len(data)
(num_regions,) = population.shape
assert coupling.shape == (num_regions, num_regions)
assert (0 <= coupling).all()
assert (coupling <= 1).all()
assert isinstance(recovery_time, float)
assert recovery_time > 1
if isinstance(data, torch.Tensor):
# Data tensors should be oriented as (time, region).
assert data.shape == (duration, num_regions)
compartments = ("S", "I") # R is implicit.
# We create a regional model by passing a vector of populations.
super().__init__(compartments, duration, population, approximate=("I",))
self.coupling = coupling
self.recovery_time = recovery_time
self.data = data
def global_model(self):
tau = self.recovery_time
# Assume reproductive number is heterogeneous but shared among regions.
R0 = pyro.sample("R0", dist.LogNormal(0.0, 1.0))
R_drift = pyro.sample("R_drift", dist.LogNormal(-3.0, 1.0))
# Assume response rate is heterogeneous in time and region.
with self.region_plate:
rho0 = pyro.sample("rho0", dist.Beta(4, 4))
rho_drift = pyro.sample("rho_drift", dist.LogNormal(-3.0, 1.0))
return tau, R0, R_drift, rho0, rho_drift
def initialize(self, params):
# Start with a single infection in region 0.
I = torch.zeros_like(self.population)
I[0] += 1
S = self.population - I
return {
"S": S,
"I": I,
"R_factor": torch.tensor(1.0),
"rho_shift": torch.tensor(0.0),
}
def transition(self, params, state, t):
tau, R0, R_drift, rho0, rho_drift = params
# Account for infections from all regions. This uses approximate (point
# estimate) counts I_approx for infection from other regions, but uses
# the exact (enumerated) count I for infections from one's own region.
I_coupled = state["I_approx"] @ self.coupling
I_coupled = I_coupled + (state["I"] - state["I_approx"]) * self.coupling.diag()
I_coupled = I_coupled.clamp(min=0) # In case I_approx is negative.
pop_coupled = self.population @ self.coupling
# Sample region-global time-heterogeneous variables.
R_factor = pyro.sample(
"R_factor_{}".format(t), dist.LogNormal(state["R_factor"].log(), R_drift)
)
Rt = pyro.deterministic("Rt_{}".format(t), R0 * R_factor)
with self.region_plate:
# Sample region-local time-heterogeneous variables.
rho_shift = pyro.sample(
"rho_shift_{}".format(t), dist.Normal(state["rho_shift"], rho_drift)
)
rho = pyro.deterministic(
"rho_{}".format(t), (rho0.log() - (-rho0).log1p() + rho_shift).sigmoid()
)
# Sample flows between compartments.
S2I = pyro.sample(
"S2I_{}".format(t),
infection_dist(
individual_rate=Rt / tau,
num_susceptible=state["S"],
num_infectious=I_coupled,
population=pop_coupled,
),
)
I2R = pyro.sample("I2R_{}".format(t), binomial_dist(state["I"], 1 / tau))
# Update compartments and heterogeneous variables.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R
state["R_factor"] = R_factor
state["rho_shift"] = rho_shift
# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample(
"obs_{}".format(t),
binomial_dist(S2I, rho),
obs=self.data[t] if t_is_observed else None,
)
# Create sphinx documentation.
__all__ = []
for _name, _Model in list(locals().items()):
if isinstance(_Model, type) and issubclass(_Model, CompartmentalModel):
if _Model is not CompartmentalModel:
__all__.append(_name)
__all__.sort(
key=lambda name, vals=locals(): vals[name].__init__.__code__.co_firstlineno
)
__doc__ = "\n\n".join(
[
"""
{}
----------------------------------------------------------------
.. autoclass:: pyro.contrib.epidemiology.models.{}
""".format(
re.sub("([A-Z][a-z]+)", r"\1 ", _name[:-5]), _name
)
for _name in __all__
]
)