# Source code for pyro.contrib.epidemiology.compartmental

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import logging
import operator
import re
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from functools import reduce
from timeit import default_timer

import torch
from torch.distributions import biject_to, constraints
from torch.distributions.utils import lazy_property

import pyro.distributions as dist
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.distributions.transforms import HaarTransform
from pyro.infer import (
    MCMC,
    NUTS,
    SVI,
    JitTrace_ELBO,
    SMCFilter,
    Trace_ELBO,
    infer_discrete,
)
from pyro.infer.autoguide import (
    AutoLowRankMultivariateNormal,
    AutoMultivariateNormal,
    AutoNormal,
    init_to_generated,
    init_to_value,
)
from pyro.infer.mcmc import ArrowheadMassMatrix
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.infer.util import is_validation_enabled
from pyro.optim import ClippedAdam
from pyro.poutine.util import site_is_factor, site_is_subsample
from pyro.util import warn_if_nan

from .distributions import (
    set_approx_log_prob_tol,
    set_approx_sample_thresh,
    set_relaxed_distributions,
)
from .util import align_samples, cat2, clamp, quantize, quantize_enumerate

logger = logging.getLogger(__name__)


def _require_double_precision():
    """
    Emit a warning unless torch's default dtype is ``torch.float64``.

    Compartmental inference is numerically fragile in reduced precision, so
    this nudges the user towards double precision via a
    :class:`RuntimeWarning` rather than raising.
    """
    current_dtype = torch.get_default_dtype()
    if current_dtype == torch.float64:
        return
    message = (
        "CompartmentalModel is unstable for dtypes less than torch.float64; "
        "try torch.set_default_dtype(torch.float64)"
    )
    warnings.warn(message, RuntimeWarning)


@contextmanager
def _disallow_latent_variables(section_name):
    """
    Context manager forbidding latent (unobserved) sample sites in its body.

    With validation disabled this simply yields. With validation enabled the
    body is traced, and after it completes a :class:`NotImplementedError` is
    raised naming the first unobserved sample site found in the trace.

    :param str section_name: Human-readable name of the guarded section, used
        in the error message.
    """
    if is_validation_enabled():
        with poutine.trace() as tracer:
            yield
        for site_name, node in tracer.trace.nodes.items():
            is_latent = node["type"] == "sample" and not node["is_observed"]
            if is_latent:
                raise NotImplementedError(
                    "{} contained latent variable {}".format(section_name, site_name)
                )
    else:
        yield


class CompartmentalModel(ABC):
    """
    Abstract base class for discrete-time discrete-value stochastic
    compartmental models.

    Derived classes must implement methods :meth:`initialize` and
    :meth:`transition`. Derived classes may optionally implement
    :meth:`global_model`, :meth:`compute_flows`, and :meth:`heuristic`.

    Example usage::

        # First implement a concrete derived class.
        class MyModel(CompartmentalModel):
            def __init__(self, ...): ...
            def global_model(self): ...
            def initialize(self, params): ...
            def transition(self, params, state, t): ...

        # Run inference to fit the model to data.
        model = MyModel(...)
        model.fit_svi(num_samples=100)  # or .fit_mcmc(...)
        R0 = model.samples["R0"]  # An example parameter.
        print("R0 = {:0.3g} \u00B1 {:0.3g}".format(R0.mean(), R0.std()))

        # Predict latent variables.
        samples = model.predict()

        # Forecast forward.
        samples = model.predict(forecast=30)

        # You can assess future interventions (applied after ``duration``) by
        # storing them as attributes that are read by your derived methods.
        model.my_intervention = False
        samples1 = model.predict(forecast=30)
        model.my_intervention = True
        samples2 = model.predict(forecast=30)
        effect = samples2["my_result"].mean() - samples1["my_result"].mean()
        print("average effect = {:0.3g}".format(effect))

    An example workflow is to use cheaper approximate inference while finding
    good model structure and priors, then move to more accurate but more
    expensive inference once the model is plausible.

    1.  Start with ``.fit_svi(guide_rank=0, num_steps=2000)`` for cheap
        inference while you search for a good model.
    2.  Additionally infer long-range correlations by moving to a low-rank
        multivariate normal guide via ``.fit_svi(guide_rank=None,
        num_steps=5000)``.
    3.  Optionally additionally infer non-Gaussian posterior by moving to the
        more expensive (but still approximate via moment matching)
        ``.fit_mcmc(num_quant_bins=1, num_samples=10000, num_chains=2)``.
    4.  Optionally improve fit around small counts by moving to the more
        expensive enumeration-based algorithm ``.fit_mcmc(num_quant_bins=4,
        num_samples=10000, num_chains=2)`` (GPU recommended).

    :ivar dict samples: Dictionary of posterior samples.
    :param list compartments: A list of strings of compartment names.
    :param int duration: The number of discrete time steps in this model.
    :param population: Either the total population of a single-region model
        or a tensor of each region's population in a regional model.
    :type population: int or torch.Tensor
    :param tuple approximate: Names of compartments for which pointwise
        approximations should be provided in :meth:`transition`, e.g. if you
        specify ``approximate=("I",)`` then the ``state["I_approx"]`` will be
        a continuous-valued non-enumerated point estimate of ``state["I"]``.
        Note this must be a tuple (mind the trailing comma). Approximations
        are useful to reduce computational cost. Approximations are
        continuous-valued with support ``(-0.5, population + 0.5)``.
    """

    def __init__(self, compartments, duration, population, *, approximate=()):
        super().__init__()

        assert isinstance(duration, int)
        assert duration >= 1
        self.duration = duration

        if isinstance(population, torch.Tensor):
            assert population.dim() == 1
            assert (population >= 1).all()
            self.is_regional = True
            self.max_plate_nesting = 2  # [time, region]
        else:
            assert isinstance(population, int)
            assert population >= 2
            self.is_regional = False
            self.max_plate_nesting = 1  # [time]
        self.population = population

        compartments = tuple(compartments)
        assert all(isinstance(name, str) for name in compartments)
        assert len(compartments) == len(set(compartments))
        self.compartments = compartments

        assert isinstance(approximate, tuple)
        assert all(name in compartments for name in approximate)
        self.approximate = approximate

        # Inference state.
        self.samples = {}
        self._clear_plates()

    @property
    def time_plate(self):
        """
        A ``pyro.plate`` for the time dimension.
        """
        # Created lazily and cached until the next _clear_plates() call.
        if self._time_plate is None:
            self._time_plate = pyro.plate(
                "time", self.duration, dim=-2 if self.is_regional else -1
            )
        return self._time_plate

    @property
    def region_plate(self):
        """
        Either a ``pyro.plate`` or a trivial ``ExitStack`` depending on whether
        this model ``.is_regional``.
        """
        if self._region_plate is None:
            if self.is_regional:
                self._region_plate = pyro.plate("region", len(self.population), dim=-1)
            else:
                self._region_plate = ExitStack()  # Trivial context manager.
        return self._region_plate

    def _clear_plates(self):
        # Drop cached plates so the next access creates fresh ones.
        self._time_plate = None
        self._region_plate = None

    @lazy_property
    def full_mass(self):
        """
        A list of a single tuple of the names of global random variables.
        """
        with torch.no_grad(), poutine.block(), poutine.trace() as tr:
            self.global_model()
        return [
            tuple(
                name
                for name, site in tr.trace.iter_stochastic_nodes()
                if not site_is_subsample(site)
            )
        ]

    @lazy_property
    def series(self):
        """
        A frozenset of names of sample sites that are sampled each time step.
        """
        # Trace a simple invocation of .transition().
        with torch.no_grad(), poutine.block():
            params = self.global_model()
            prev = self.initialize(params)
            for name in self.approximate:
                prev[name + "_approx"] = prev[name]
            curr = prev.copy()
            with poutine.trace() as tr:
                self.transition(params, curr, 0)
        # Sites at t=0 are named "<series>_0"; strip the suffix.
        return frozenset(
            re.match("(.*)_0", name).group(1)
            for name, site in tr.trace.nodes.items()
            if site["type"] == "sample"
            if not site_is_subsample(site)
        )

    # Overridable attributes and methods ########################################

    def global_model(self):
        """
        Samples and returns any global parameters.

        :returns: An arbitrary object of parameters (e.g. ``None`` or a tuple).
        """
        return None

    # TODO Allow stochastic initialization.
    @abstractmethod
    def initialize(self, params):
        """
        Returns initial counts in each compartment.

        :param params: The global params returned by :meth:`global_model`.
        :returns: A dict mapping compartment name to initial value.
        :rtype: dict
        """
        raise NotImplementedError

    @abstractmethod
    def transition(self, params, state, t):
        """
        Forward generative process for dynamics.

        This inputs a current ``state`` and stochastically updates that
        state in-place.

        Note that this method is called under multiple different
        interpretations, including batched and vectorized interpretations.
        During :meth:`generate` this is called to generate a single sample.
        During :meth:`heuristic` this is called to generate a batch of samples
        for SMC. During :meth:`fit_mcmc` this is called both in vectorized
        form (vectorizing over time) and in sequential form (for a single time
        step); both forms enumerate over discrete latent variables. During
        :meth:`predict` this is called to forecast a batch of samples,
        conditioned on posterior samples for the time interval
        ``[0:duration]``.

        :param params: The global params returned by :meth:`global_model`.
        :param dict state: A dictionary mapping compartment name to current
            tensor value. This should be updated in-place.
        :param t: A time-like index. During inference ``t`` may be either a
            slice (for vectorized inference) or an integer time index. During
            prediction ``t`` will be integer time index.
        :type t: int or slice
        """
        raise NotImplementedError

    def finalize(self, params, prev, curr):
        """
        Optional method for likelihoods that depend on entire time series.

        This should be used only for non-factorizable likelihoods that couple
        states across time. Factorizable likelihoods should instead be added
        to the :meth:`transition` method, thereby enabling their use in
        :meth:`heuristic` initialization. Since this method is called only
        after the last time step, it is not used in :meth:`heuristic`
        initialization.

        .. warning:: This currently does not support latent variables.

        :param params: The global params returned by :meth:`global_model`.
        :param dict prev:
        :param dict curr: Dictionaries mapping compartment name to tensor of
            entire time series. These two parameters are offset by 1 step,
            thereby making it easy to compute time series of fluxes. For
            quantized inference, this uses the approximate point estimates, so
            users must request any needed time series in :meth:`__init__`,
            e.g. by calling ``super().__init__(..., approximate=("I", "E"))``
            if likelihood depends on the ``I`` and ``E`` time series.
        """
        pass

    def compute_flows(self, prev, curr, t):
        """
        Computes flows between compartments, given compartment populations
        before and after time step t.

        The default implementation assumes sequential flows terminating in an
        implicit compartment named "R". For example if::

            compartments = ("S", "E", "I")

        the default implementation computes at time step ``t = 9``::

            flows["S2E_9"] = prev["S"] - curr["S"]
            flows["E2I_9"] = prev["E"] - curr["E"] + flows["S2E_9"]
            flows["I2R_9"] = prev["I"] - curr["I"] + flows["E2I_9"]

        For more complex flows (non-sequential, branching, looping,
        duplicating, etc.), users may override this method.

        :param dict prev: A dictionary mapping compartment name to the
            compartment's value before time step ``t``.
        :param dict curr: A dictionary mapping compartment name to the
            compartment's value after time step ``t``.
        :param t: A time-like index. During inference ``t`` may be either a
            slice (for vectorized inference) or an integer time index. During
            prediction ``t`` will be integer time index.
        :type t: int or slice
        :returns: A dict mapping flow name to tensor value.
        :rtype: dict
        """
        flows = {}
        flow = 0
        for source, destin in zip(self.compartments, self.compartments[1:] + ("R",)):
            # Outflow of ``source`` = its net decrease plus its inflow.
            flow = prev[source] - curr[source] + flow
            flows["{}2{}_{}".format(source, destin, t)] = flow
        return flows

    # Inference interface ########################################

    @torch.no_grad()
    @set_approx_sample_thresh(1000)
    def generate(self, fixed=None):
        """
        Generate data from the prior.

        :param dict fixed: A dictionary of parameters on which to condition.
            These must be top-level parentless nodes, i.e. have
            no upstream stochastic dependencies. Defaults to no conditioning.
        :returns: A dictionary mapping sample site name to sampled value.
        :rtype: dict
        """
        # Avoid a shared mutable default argument.
        if fixed is None:
            fixed = {}
        fixed = {k: torch.as_tensor(v) for k, v in fixed.items()}
        model = self._generative_model
        model = poutine.condition(model, fixed)
        trace = poutine.trace(model).get_trace()
        samples = OrderedDict(
            (name, site["value"])
            for name, site in trace.nodes.items()
            if site["type"] == "sample"
        )

        self._concat_series(samples, trace)
        return samples

    def fit_svi(
        self,
        *,
        num_samples=100,
        num_steps=2000,
        num_particles=32,
        learning_rate=0.1,
        learning_rate_decay=0.01,
        betas=(0.8, 0.99),
        haar=True,
        init_scale=0.01,
        guide_rank=0,
        jit=False,
        log_every=200,
        **options,
    ):
        """
        Runs stochastic variational inference to generate posterior samples.

        This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
        attribute on completion.

        This approximate inference method is useful for quickly iterating on
        probabilistic models.

        :param int num_samples: Number of posterior samples to draw from the
            trained guide. Defaults to 100.
        :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
            particles per step.
        :param float learning_rate: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param float learning_rate_decay: Total multiplicative decay of the
            learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
            is decay over the entire schedule; the per-step decay is
            ``learning_rate_decay ** (1 / num_steps)``.
        :param tuple betas: Momentum parameters for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to True.
        :param guide_rank: Rank of the auto normal guide. If zero (default)
            use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
            positive integer or None, use an
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
            If the string "full", use an
            :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
            latter two require more ``num_steps`` to fit.
        :type guide_rank: int or None or str
        :param float init_scale: Initial scale of the autoguide (passed to
            whichever guide ``guide_rank`` selects).
        :param bool jit: Whether to use a jit compiled ELBO.
        :param int log_every: How often to log svi losses.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: Time series of SVI losses (useful to diagnose convergence).
        :rtype: list
        """
        # Save configuration for .predict().
        self.relaxed = True
        self.num_quant_bins = 1

        # Setup Haar wavelet transform.
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, _) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(0, self.duration, dims, supports)

        # Heuristically initialize to feasible latents.
        heuristic_options = {
            k.replace("heuristic_", ""): options.pop(k)
            for k in list(options)
            if k.startswith("heuristic_")
        }
        assert not options, "unrecognized options: {}".format(", ".join(options))
        init_strategy = self._heuristic(haar, **heuristic_options)

        # Configure variational inference.
        logger.info("Running inference...")
        model = self._relaxed_model
        if haar:
            model = haar.reparam(model)
        if guide_rank == 0:
            guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
        elif guide_rank == "full":
            guide = AutoMultivariateNormal(
                model, init_loc_fn=init_strategy, init_scale=init_scale
            )
        elif guide_rank is None or isinstance(guide_rank, int):
            guide = AutoLowRankMultivariateNormal(
                model, init_loc_fn=init_strategy, init_scale=init_scale, rank=guide_rank
            )
        else:
            raise ValueError("Invalid guide_rank: {}".format(guide_rank))
        Elbo = JitTrace_ELBO if jit else Trace_ELBO
        elbo = Elbo(
            max_plate_nesting=self.max_plate_nesting,
            num_particles=num_particles,
            vectorize_particles=True,
            ignore_jit_warnings=True,
        )
        optim = ClippedAdam(
            {
                "lr": learning_rate,
                "betas": betas,
                # Spread the total decay evenly over all steps.
                "lrd": learning_rate_decay ** (1 / num_steps),
            }
        )
        svi = SVI(model, guide, optim, elbo)

        # Run inference.
        start_time = default_timer()
        losses = []
        for step in range(1 + num_steps):
            loss = svi.step() / self.duration
            if step % log_every == 0:
                logger.info("step {} loss = {:0.4g}".format(step, loss))
            losses.append(loss)
        elapsed = default_timer() - start_time
        logger.info(
            "SVI took {:0.1f} seconds, {:0.1f} step/sec".format(
                elapsed, (1 + num_steps) / elapsed
            )
        )

        # Draw posterior samples.
        with torch.no_grad():
            particle_plate = pyro.plate(
                "particles", num_samples, dim=-1 - self.max_plate_nesting
            )
            guide_trace = poutine.trace(particle_plate(guide)).get_trace()
            model_trace = poutine.trace(
                poutine.replay(particle_plate(model), guide_trace)
            ).get_trace()
            self.samples = {
                name: site["value"]
                for name, site in model_trace.nodes.items()
                if site["type"] == "sample"
                if not site["is_observed"]
                if not site_is_subsample(site)
            }
            if haar:
                haar.aux_to_user(self.samples)
        assert all(v.size(0) == num_samples for v in self.samples.values()), {
            k: tuple(v.shape) for k, v in self.samples.items()
        }

        return losses

    @set_approx_log_prob_tol(0.1)
    def fit_mcmc(self, **options):
        r"""
        Runs NUTS inference to generate posterior samples.

        This uses the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel to run
        :class:`~pyro.infer.mcmc.api.MCMC`, setting the ``.samples``
        attribute on completion.

        This uses an asymptotically exact enumeration-based model when
        ``num_quant_bins > 1``, and a cheaper moment-matched approximate model
        when ``num_quant_bins == 1``.

        :param \*\*options: Options passed to
            :class:`~pyro.infer.mcmc.api.MCMC`. The remaining options are
            pulled out and have special meaning.
        :param int num_samples: Number of posterior samples to draw via mcmc.
            Defaults to 100.
        :param int num_chains: Number of MCMC chains. Defaults to 1.
        :param int max_tree_depth: (Default 5). Max tree depth of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
        :param float target_accept_prob: Target acceptance probability of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to 0.8.
        :param full_mass: Specification of mass matrix of the
            :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. Defaults to full mass
            over global random variables.
        :param bool arrowhead_mass: Whether to treat ``full_mass`` as the head
            of an arrowhead matrix versus simply as a block. Defaults to False.
        :param int num_quant_bins: If greater than 1, use asymptotically exact
            inference via local enumeration over this many quantization bins.
            If equal to 1, use continuous-valued relaxed approximate inference.
            Note that computational cost is exponential in `num_quant_bins`.
            Defaults to 1 for relaxed inference.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
            Defaults to False.
        :param int haar_full_mass: Number of low frequency Haar components to
            include in the full mass matrix. If ``haar=False`` then this is
            ignored. Defaults to 10.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
        :rtype: ~pyro.infer.mcmc.api.MCMC
        """
        _require_double_precision()

        # Parse options, saving some for use in .predict().
        num_samples = options.setdefault("num_samples", 100)
        num_chains = options.setdefault("num_chains", 1)
        self.num_quant_bins = options.pop("num_quant_bins", 1)
        assert isinstance(self.num_quant_bins, int)
        assert self.num_quant_bins >= 1
        self.relaxed = self.num_quant_bins == 1

        # Setup Haar wavelet transform.
        haar = options.pop("haar", False)
        haar_full_mass = options.pop("haar_full_mass", 10)
        full_mass = options.pop("full_mass", self.full_mass)
        assert isinstance(haar, bool)
        assert isinstance(haar_full_mass, int) and haar_full_mass >= 0
        assert isinstance(full_mass, (bool, list))
        haar_full_mass = min(haar_full_mass, self.duration)
        if not haar:
            haar_full_mass = 0
        if full_mass is True:
            haar_full_mass = 0  # No need to split.
        elif haar_full_mass >= self.duration:
            full_mass = True  # Effectively full mass.
            haar_full_mass = 0
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, _) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(haar_full_mass, self.duration, dims, supports)
        if haar_full_mass:
            # haar_full_mass is nonzero only when haar is enabled, so dims exists.
            assert full_mass and isinstance(full_mass, list)
            full_mass = full_mass[:]  # Copy before mutating the caller's list.
            full_mass[0] += tuple(name + "_haar_split_0" for name in sorted(dims))

        # Heuristically initialize to feasible latents.
        heuristic_options = {
            k.replace("heuristic_", ""): options.pop(k)
            for k in list(options)
            if k.startswith("heuristic_")
        }
        init_strategy = init_to_generated(
            generate=functools.partial(self._heuristic, haar, **heuristic_options)
        )

        # Configure a kernel.
        logger.info("Running inference...")
        model = self._relaxed_model if self.relaxed else self._quantized_model
        if haar:
            model = haar.reparam(model)
        kernel = NUTS(
            model,
            full_mass=full_mass,
            init_strategy=init_strategy,
            max_plate_nesting=self.max_plate_nesting,
            jit_compile=options.pop("jit_compile", False),
            jit_options=options.pop("jit_options", None),
            ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
            target_accept_prob=options.pop("target_accept_prob", 0.8),
            max_tree_depth=options.pop("max_tree_depth", 5),
        )
        if options.pop("arrowhead_mass", False):
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()

        # Run mcmc.
        options.setdefault("disable_validation", None)
        mcmc = MCMC(kernel, **options)
        mcmc.run()
        self.samples = mcmc.get_samples()
        if haar:
            haar.aux_to_user(self.samples)

        # Unsqueeze samples to align particle dim for use in poutine.condition.
        # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
        model = self._relaxed_model if self.relaxed else self._quantized_model
        self.samples = align_samples(
            self.samples, model, particle_dim=-1 - self.max_plate_nesting
        )
        assert all(
            v.size(0) == num_samples * num_chains for v in self.samples.values()
        ), {k: tuple(v.shape) for k, v in self.samples.items()}

        return mcmc  # E.g. so user can run mcmc.summary().

    @torch.no_grad()
    @set_approx_log_prob_tol(0.1)
    @set_approx_sample_thresh(10000)
    def predict(self, forecast=0):
        """
        Predict latent variables and optionally forecast forward.

        This may be run only after :meth:`fit_svi` or :meth:`fit_mcmc`, and
        draws the same number of samples as were stored in ``.samples`` by
        that fit.

        :param int forecast: The number of time steps to forecast forward.
        :returns: A dictionary mapping sample site name (or compartment name)
            to a tensor whose first dimension corresponds to sample batching.
        :rtype: dict
        :raises RuntimeError: If no posterior samples are available yet.
        """
        # Check for samples first: .num_quant_bins and .relaxed are only set
        # by the fit methods, so checking them first would raise AttributeError.
        if not self.samples:
            raise RuntimeError("Missing samples, try running .fit_mcmc() first")
        if self.num_quant_bins > 1:
            _require_double_precision()

        samples = self.samples
        num_samples = len(next(iter(samples.values())))
        particle_plate = pyro.plate(
            "particles", num_samples, dim=-1 - self.max_plate_nesting
        )

        # Sample discrete auxiliary variables conditioned on the continuous
        # variables sampled by _quantized_model. This samples only time steps
        # [0:duration]. Here infer_discrete runs a forward-filter
        # backward-sample algorithm.
        logger.info(
            "Predicting latent variables for {} time steps...".format(self.duration)
        )
        model = self._sequential_model
        model = poutine.condition(model, samples)
        model = particle_plate(model)
        if not self.relaxed:
            model = infer_discrete(
                model, first_available_dim=-2 - self.max_plate_nesting
            )
        trace = poutine.trace(model).get_trace()
        samples = OrderedDict(
            (name, site["value"].expand(site["fn"].shape()))
            for name, site in trace.nodes.items()
            if site["type"] == "sample"
            if not site_is_subsample(site)
            if not site_is_factor(site)
        )
        assert all(v.size(0) == num_samples for v in samples.values()), {
            k: tuple(v.shape) for k, v in samples.items()
        }

        # Optionally forecast with the forward _generative_model. This samples
        # time steps [duration:duration+forecast].
        if forecast:
            logger.info("Forecasting {} steps ahead...".format(forecast))
            model = self._generative_model
            model = poutine.condition(model, samples)
            model = particle_plate(model)
            trace = poutine.trace(model).get_trace(forecast)
            samples = OrderedDict(
                (name, site["value"])
                for name, site in trace.nodes.items()
                if site["type"] == "sample"
                if not site_is_subsample(site)
                if not site_is_factor(site)
            )

            self._concat_series(samples, trace, forecast)
            assert all(v.size(0) == num_samples for v in samples.values()), {
                k: tuple(v.shape) for k, v in samples.items()
            }

        return samples

    @torch.no_grad()
    @set_approx_log_prob_tol(0.1)
    @set_approx_sample_thresh(100)  # This is robust to gross approximation.
    def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
        """
        Finds an initial feasible guess of all latent variables, consistent
        with observed data. This is needed because not all hypotheses are
        feasible and HMC needs to start at a feasible solution to progress.

        The default implementation attempts to find a feasible state using
        :class:`~pyro.infer.smcfilter.SMCFilter` with proposals from the prior.
        However this method may be overridden in cases where SMC performs
        poorly e.g. in high-dimensional models.

        :param int num_particles: Number of particles used for SMC.
        :param float ess_threshold: Effective sample size threshold for SMC.
        :param int retries: Maximum number of SMC attempts; the final
            :class:`~pyro.infer.smcfilter.SMCFailed` is re-raised if all fail.
            Must be at least 1.
        :returns: A dictionary mapping sample site name to tensor value.
        :rtype: dict
        :raises ValueError: If ``retries < 1``.
        """
        # Without at least one attempt ``smc`` below would be unbound.
        if retries < 1:
            raise ValueError("Expected retries >= 1, but got {}".format(retries))

        # Run SMC.
        model = _SMCModel(self)
        guide = _SMCGuide(self)
        for attempt in range(1, 1 + retries):
            smc = SMCFilter(
                model,
                guide,
                num_particles=num_particles,
                ess_threshold=ess_threshold,
                max_plate_nesting=self.max_plate_nesting,
            )
            try:
                smc.init()
                for t in range(1, self.duration):
                    smc.step()
                break
            except SMCFailed as e:
                if attempt == retries:
                    raise
                logger.info("{}. Retrying...".format(e))

        # Select the most probable hypothesis.
        # Note this ignores the .finalize() likelihood.
        i = int(smc.state._log_weights.max(0).indices)
        init = {key: value[i, 0] for key, value in smc.state.items()}

        # Fill in sample site values.
        init = self.generate(init)
        aux = torch.stack([init[name] for name in self.compartments], dim=0)
        init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)
        return init

    # Internal helpers ########################################

    def _heuristic(self, haar, **options):
        """
        Runs :meth:`heuristic` with effects blocked and wraps its result in an
        ``init_to_value`` strategy with no fallback.

        :param haar: The Haar reparameterizer (or ``False``) chosen by the
            caller. NOTE(review): this parameter is not used in this body.
        :param options: Forwarded to :meth:`heuristic`.
        """
        with poutine.block():
            init_values = self.heuristic(**options)
        assert isinstance(init_values, dict)
        assert "auxiliary" in init_values, ".heuristic() did not define auxiliary value"
        logger.info(
            "Heuristic init: {}".format(
                ", ".join(
                    "{}={:0.3g}".format(k, v.item())
                    for k, v in sorted(init_values.items())
                    if v.numel() == 1
                )
            )
        )
        return init_to_value(values=init_values, fallback=None)

    def _concat_series(self, samples, trace, forecast=0):
        """
        Concatenate sequential time series into tensors, in-place.

        Per-step entries named ``"<name>_<t>"`` are popped from ``samples`` and
        replaced by a single ``"<name>"`` entry.

        :param dict samples: A dictionary of samples.
        :param trace: The trace the samples came from; used to look up each
            series' ``event_dim``.
        :param int forecast: Number of forecast steps beyond ``duration``.
        """
        time_dim = -2 if self.is_regional else -1
        for name in set(self.compartments).union(self.series):
            # Anchor the whole key and escape the name so e.g. series "x" does
            # not swallow sites of another series named "x_1" ("x_1_0", ...).
            pattern = re.compile(re.escape(name) + "_[0-9]+")
            series = []
            for key in list(samples):
                if pattern.fullmatch(key):
                    series.append(samples.pop(key))
            if not series:
                continue
            assert len(series) == self.duration + forecast
            series = torch.broadcast_tensors(*map(torch.as_tensor, series))
            dim = time_dim - trace.nodes[name + "_0"]["fn"].event_dim
            if series[0].dim() >= -dim:
                samples[name] = torch.cat(series, dim=dim)
            else:
                samples[name] = torch.stack(series)

    @lazy_property
    @torch.no_grad()
    def _non_compartmental(self):
        """
        A dict mapping name -> (distribution, is_regional) for all
        non-compartmental sites in :meth:`transition`. For simple models this
        is often empty; for time-heterogeneous models this may contain
        time-local latent variables.
        """
        # Trace a simple invocation of .transition().
        with torch.no_grad(), poutine.block():
            params = self.global_model()
            prev = self.initialize(params)
            for name in self.approximate:
                prev[name + "_approx"] = prev[name]
            curr = prev.copy()
            with poutine.trace() as tr:
                self.transition(params, curr, 0)
            flows = self.compute_flows(prev, curr, 0)

        # Extract latent variables that are not compartmental flows.
        result = OrderedDict()
        for name, site in tr.trace.iter_stochastic_nodes():
            if name in flows or site_is_subsample(site):
                continue
            assert name.endswith("_0"), name
            name = name[:-2]
            assert name in self.series, name
            # TODO This supports only the region_plate. For full plate support,
            # this could be replaced by a self.plate() method as in EasyGuide.
            is_regional = any(f.name == "region" for f in site["cond_indep_stack"])
            result[name] = site["fn"], is_regional
        return result

    def _sample_auxiliary(self):
        """
        Sample both compartmental and non-compartmental auxiliary variables.
        """
        C = len(self.compartments)
        T = self.duration
        R_shape = getattr(self.population, "shape", ())  # Region shape.

        # Sample the compartmental continuous reparameterizing variable.
        shape = (C, T) + R_shape
        auxiliary = pyro.sample(
            "auxiliary",
            dist.Uniform(-0.5, self.population + 0.5)
            .mask(False)
            .expand(shape)
            .to_event(),
        )
        extra_dims = auxiliary.dim() - len(shape)

        # Sample any non-compartmental time series in batch.
        non_compartmental = OrderedDict()
        for name, (fn, is_regional) in self._non_compartmental.items():
            fn = dist.ImproperUniform(fn.support, fn.batch_shape, fn.event_shape)
            shape = (T,)
            if self.is_regional:
                shape += R_shape if is_regional else (1,)
            # Manually expand, avoiding plates to enable HaarReparam and SplitReparam.
            non_compartmental[name] = pyro.sample(name, fn.expand(shape).to_event())

        # Move event dims to time_plate and region_plate dims.
        if extra_dims:  # If inside particle_plate.
            shape = auxiliary.shape[:1] + auxiliary.shape[extra_dims:]
            auxiliary = auxiliary.reshape(shape)
            for name, value in non_compartmental.items():
                shape = value.shape[:1] + value.shape[extra_dims:]
                non_compartmental[name] = value.reshape(shape)

        return auxiliary, non_compartmental

    def _transition_bwd(self, params, prev, curr, t):
        """
        Helper to collect probability factors from .transition() conditioned
        on previous and current enumerated states.
        """
        # Run .transition() conditioned on computed flows.
        cond_data = {"{}_{}".format(k, t): v for k, v in curr.items()}
        cond_data.update(self.compute_flows(prev, curr, t))
        with poutine.condition(data=cond_data):
            state = prev.copy()
            self.transition(params, state, t)  # Mutates state.

        # Validate that .transition() matches .compute_flows().
        if is_validation_enabled():
            for key in self.compartments:
                if not torch.allclose(state[key], curr[key]):
                    raise ValueError(
                        "Incorrect state['{}'] update in .transition(), "
                        "check that .transition() matches .compute_flows().".format(key)
                    )

    def _generative_model(self, forecast=0):
        """
        Forward generative model used for simulation and forecasting.
        """
        # Sample global parameters.
        params = self.global_model()

        # Sample initial values.
        state = self.initialize(params)
        state = {
            k: v if isinstance(v, torch.Tensor) else torch.tensor(float(v))
            for k, v in state.items()
        }

        # Sequentially transition.
        for t in range(self.duration + forecast):
            for name in self.approximate:
                state[name + "_approx"] = state[name]
            self.transition(params, state, t)
            with self.region_plate:
                for name in self.compartments:
                    pyro.deterministic(
                        "{}_{}".format(name, t), state[name], event_dim=0
                    )

        self._clear_plates()

    def _sequential_model(self):
        """
        Sequential model used to sample latents in the interval [0:duration].
        This is compatible with both quantized and relaxed inference.
        This method is called only inside particle_plate.
        This method is used only for prediction.
        """
        C = len(self.compartments)
        T = self.duration
        R_shape = getattr(self.population, "shape", ())  # Region shape.
        num_samples = len(next(iter(self.samples.values())))

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()

        # Reshape to accommodate the time_plate below.
        assert auxiliary.shape == (num_samples, C, T) + R_shape, (
            auxiliary.shape,
            (num_samples, C, T) + R_shape,
        )
        aux = [aux.unbind(2) for aux in auxiliary.unsqueeze(1).unbind(2)]

        # Sequentially transition.
        curr = self.initialize(params)
        for t in poutine.markov(range(T)):
            with self.region_plate:
                prev, curr = curr, {}

                # Extract any non-compartmental variables.
                for name, value in non_compartmental.items():
                    curr[name] = value[:, t : t + 1]

                # Extract and enumerate all compartmental variables.
                for c, name in enumerate(self.compartments):
                    curr[name] = quantize(
                        "{}_{}".format(name, t),
                        aux[c][t],
                        min=0,
                        max=self.population,
                        num_quant_bins=self.num_quant_bins,
                    )
                    # Enable approximate inference by using aux as a
                    # non-enumerated proxy for enumerated compartment values.
                    if name in self.approximate:
                        curr[name + "_approx"] = aux[c][t]
                        prev.setdefault(name + "_approx", prev[name])

            # Like the vectorized models, run .transition() outside region_plate.
            self._transition_bwd(params, prev, curr, t)

        self._clear_plates()

    def _quantized_model(self):
        """
        Quantized vectorized model used for parallel-scan enumerated inference.
        This method is called only outside particle_plate.
        """
        C = len(self.compartments)
        T = self.duration
        Q = self.num_quant_bins
        R_shape = getattr(self.population, "shape", ())  # Region shape.

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()

        # Manually enumerate.
        curr, logp = quantize_enumerate(
            auxiliary, min=0, max=self.population, num_quant_bins=self.num_quant_bins
        )
        curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
        logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
        curr.update(non_compartmental)

        # Truncate final value from the right then pad initial value onto the left.
        init = self.initialize(params)
        prev = {}
        for name, value in init.items():
            if name in self.compartments:
                if isinstance(value, torch.Tensor):
                    value = value[..., None]  # Because curr is enumerated on the right.
                prev[name] = cat2(
                    value, curr[name][:-1], dim=-3 if self.is_regional else -2
                )
            else:  # non-compartmental
                prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim())

        # Reshape to support broadcasting, similar to EnumMessenger.
        def enum_reshape(tensor, position):
            assert tensor.size(-1) == Q
            assert tensor.dim() <= self.max_plate_nesting + 2
            tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
            shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
            shape.extend(tensor.shape[1:])
            return tensor.reshape(shape)

        for e, name in enumerate(self.compartments):
            curr[name] = enum_reshape(curr[name], e)
            logp[name] = enum_reshape(logp[name], e)
            prev[name] = enum_reshape(prev[name], e + C)

        # Enable approximate inference by using aux as a non-enumerated proxy
        # for enumerated compartment values.
        for name in self.approximate:
            aux = auxiliary[self.compartments.index(name)]
            curr[name + "_approx"] = aux
            prev[name + "_approx"] = cat2(
                init[name], aux[:-1], dim=-2 if self.is_regional else -1
            )

        # Record transition factors.
        with poutine.block(), poutine.trace() as tr:
            with self.time_plate:
                t = slice(0, T, 1)  # Used to slice data tensors.
                self._transition_bwd(params, prev, curr, t)
        tr.trace.compute_log_prob()
        for name, site in tr.trace.nodes.items():
            if site["type"] == "sample":
                log_prob = site["log_prob"]
                if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
                    pyro.factor("transition_" + name, site["log_prob_sum"])
                    continue
                if self.is_regional and log_prob.shape[-1:] != R_shape:
                    # Poor man's tensor variable elimination.
                    log_prob = (
                        log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
                    )
                # Use the (possibly region-expanded) log_prob computed above.
                logp[name] = log_prob

        # Manually perform variable elimination.
        logp = reduce(operator.add, logp.values())
        logp = logp.reshape(Q**C, Q**C, T, -1)  # prev, curr, T, batch
        logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
        logp = pyro.distributions.hmm._sequential_logmatmulexp(
            logp
        )  # batch, prev, curr
        logp = logp.reshape(-1, Q**C * Q**C).logsumexp(-1).sum()
        warn_if_nan(logp)
        pyro.factor("transition", logp)

        # Apply final likelihood.
        prev = {name: prev[name + "_approx"] for name in self.approximate}
        curr = {name: curr[name + "_approx"] for name in self.approximate}
        with _disallow_latent_variables(".finalize()"):
            self.finalize(params, prev, curr)

        self._clear_plates()

    @set_relaxed_distributions()
    def _relaxed_model(self):
        """
        Relaxed vectorized model used for continuous inference.
        This method may be called either inside or outside particle_plate.
        """
        T = self.duration

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()
        particle_dims = auxiliary.dim() - (3 if self.is_regional else 2)
        assert particle_dims in (0, 1)

        # Split tensors into current state.
        curr = dict(zip(self.compartments, auxiliary.unbind(particle_dims)))
        curr.update(non_compartmental)

        # Truncate final value from the right then pad initial value onto the left.
        prev = {}
        for name, value in self.initialize(params).items():
            dim = particle_dims - curr[name].dim()
            t = (slice(None),) * particle_dims + (slice(0, -1),)
            prev[name] = cat2(value, curr[name][t], dim=dim)

        # Enable approximate inference.
        for name in self.approximate:
            curr[name + "_approx"] = curr[name]
            prev[name + "_approx"] = prev[name]

        # Transition.
        with self.time_plate:
            t = slice(0, T, 1)  # Used to slice data tensors.
            self._transition_bwd(params, prev, curr, t)

        # Apply final likelihood.
        with _disallow_latent_variables(".finalize()"):
            self.finalize(params, prev, curr)

        self._clear_plates()
class _SMCModel:
    """
    Helper to initialize a CompartmentalModel to a feasible initial state.

    :meth:`init` seeds a caller-supplied ``state`` mapping; each call to
    :meth:`step` then advances the wrapped model by one time step and writes
    the new values back into that same mapping.
    """

    def __init__(self, model):
        assert isinstance(model, CompartmentalModel)
        self.model = model

    def init(self, state):
        """
        Seed ``state`` with sampled global parameters and initial values,
        then take the first transition step.

        :param dict state: Mapping that is updated in place.
        """
        model = self.model
        with poutine.trace() as tracer:
            params = model.global_model()

        # Copy the value of every sample site drawn by the global model.
        for site_name, site in tracer.trace.nodes.items():
            if site["type"] != "sample":
                continue
            state[site_name] = site["value"]

        self.t = 0
        state.update(model.initialize(params))
        # Take one step since model.initialize is deterministic.
        self.step(state)

    def step(self, state):
        """
        Advance the wrapped model by one time step (``self.t``).

        :param dict state: Mapping of current values; updated in place with
            the transitioned values and any unobserved sample sites, after
            which ``self.t`` is incremented.
        """
        model = self.model
        # Re-run the global model conditioned on the values held in state.
        with poutine.block(), poutine.condition(data=state):
            params = model.global_model()

        with poutine.trace() as tracer:
            # Transition a copy that is temporarily extended with
            # "<name>_approx" aliases, then strip the aliases again.
            scratch = dict(state)
            for name in model.approximate:
                scratch[name + "_approx"] = state[name]
            model.transition(params, scratch, self.t)
            for name in model.approximate:
                del scratch[name + "_approx"]
            state.update(scratch)

        # Write back only latent (non-observed) sample sites.
        for site_name, site in tracer.trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            state[site_name] = site["value"]

        self.t += 1


class _SMCGuide(_SMCModel):
    """
    Like _SMCModel but does not update state and does not observe.
    """

    def init(self, state):
        # Operate on a copy so the caller's state is not modified.
        snapshot = state.copy()
        super().init(snapshot)

    def step(self, state):
        # Hide observe sites and operate on a copy of state.
        with poutine.block(hide_types=["observe"]):
            snapshot = state.copy()
            super().step(snapshot)


class _HaarSplitReparam:
    """
    Wrapper around ``HaarReparam`` and ``SplitReparam`` to additionally convert
    sample dicts between user-facing and auxiliary coordinates.

    :param int split: Length of the first piece produced by ``SplitReparam``;
        zero disables splitting. Must satisfy ``0 <= split < duration``.
    :param int duration: Length of the time dimension being reparametrized.
    :param dict dims: Maps site name to the tensor dim to transform.
    :param dict supports: Maps site name to the constraint used by
        ``biject_to`` when mapping back to user coordinates.
    """

    def __init__(self, split, duration, dims, supports):
        assert 0 <= split < duration
        self.split, self.duration, self.dims, self.supports = (
            split,
            duration,
            dims,
            supports,
        )

    def __bool__(self):
        # Instances are always truthy, regardless of ``split``.
        return True

    def reparam(self, model):
        """
        Wrap a model with ``poutine.reparam``.

        Every site in ``self.dims`` gets a flipped ``HaarReparam``; when
        ``self.split`` is nonzero the resulting ``<name>_haar`` sites are
        additionally split into pieces of length ``split`` and
        ``duration - split``.
        """
        # Transform to Haar coordinates.
        haar_config = {
            name: HaarReparam(dim=dim, flip=True) for name, dim in self.dims.items()
        }
        model = poutine.reparam(model, haar_config)

        if not self.split:
            return model

        # Split into low- and high-frequency parts.
        sizes = [self.split, self.duration - self.split]
        split_config = {
            name + "_haar": SplitReparam(sizes, dim=dim)
            for name, dim in self.dims.items()
        }
        return poutine.reparam(model, split_config)

    def aux_to_user(self, samples):
        """
        Convert from auxiliary samples to user-facing samples, in-place.

        :param dict samples: Mapping whose auxiliary keys are popped and
            replaced by the corresponding user-facing site names.
        """
        dims = self.dims
        if self.split:
            # Transform back from SplitReparam coordinates.
            for name, dim in dims.items():
                first = samples.pop(name + "_haar_split_0")
                second = samples.pop(name + "_haar_split_1")
                samples[name + "_haar"] = torch.cat([first, second], dim=dim)

        # Transform back from Haar coordinates.
        for name, dim in dims.items():
            coeffs = samples.pop(name + "_haar")
            unconstrained = HaarTransform(dim=dim, flip=True).inv(coeffs)
            samples[name] = biject_to(self.supports[name])(unconstrained)