Forecasting¶
pyro.contrib.forecast is a lightweight framework for experimenting with a restricted class of time series models and inference algorithms using familiar Pyro modeling syntax and PyTorch neural networks.

Models include hierarchical multivariate heavy-tailed time series of ~1000 time steps and ~1000 separate series. Inference combines subsample-compatible variational inference with Gaussian variable elimination based on the GaussianHMM class. Inference using Hamiltonian Monte Carlo sampling is also supported with HMCForecaster. Forecasts are in the form of joint posterior samples at multiple future time steps.

Hierarchical models use the familiar plate syntax for general hierarchical modeling in Pyro. Plates can be subsampled, enabling training of joint models over thousands of time series. Multivariate observations are handled via multivariate likelihoods like MultivariateNormal, GaussianHMM, or LinearHMM. Heavy-tailed models are possible by using StudentT or Stable likelihoods, possibly together with LinearHMM and reparameterizers including StudentTReparam, StableReparam, and LinearHMMReparam. Seasonality can be handled using the helpers periodic_repeat(), periodic_cumsum(), and periodic_features().
See pyro.contrib.timeseries for ways to construct temporal Gaussian processes useful as likelihoods.

For example usage, see the Pyro forecasting tutorials.
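As a quick orientation, here is a minimal end-to-end sketch: subclass ForecastingModel, fit a Forecaster by variational inference, and draw joint forecast samples. The toy data tensor, the periodic_features() covariates, and the SeasonalModel class below are illustrative assumptions, not part of the library.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.contrib.forecast import ForecastingModel, Forecaster
    from pyro.ops.tensor_utils import periodic_features

    class SeasonalModel(ForecastingModel):  # hypothetical example model
        def model(self, zero_data, covariates):
            data_dim = zero_data.size(-1)
            feature_dim = covariates.size(-1)
            # Global regression coefficients, drawn outside the time plate.
            bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
            weight = pyro.sample("weight", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))
            prediction = bias + (weight * covariates).sum(-1, keepdim=True)
            # A simple homoskedastic noise distribution around the prediction.
            noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
            self.predict(dist.Normal(0, noise_scale), prediction)

    T1, T2 = 30, 37                        # observe 30 steps, forecast 7 more
    data = torch.randn(T1, 1).cumsum(-2)   # toy univariate series, time dim -2
    covariates = periodic_features(T2, 7)  # seasonal features with periods up to 7 steps
    forecaster = Forecaster(SeasonalModel(), data, covariates[:T1], num_steps=501)
    samples = forecaster(data, covariates, num_samples=100)  # joint samples over the 7 future steps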
Forecaster Interface¶
- class ForecastingModel[source]¶
Bases: pyro.nn.module.PyroModule
Abstract base class for forecasting models. Derived classes must implement the model() method.
- abstract model(zero_data, covariates)[source]¶
Generative model definition.
Implementations must call the predict() method exactly once. Implementations must draw all time-dependent noise inside the time_plate(). The prediction passed to predict() must be a deterministic function of noise tensors that are independent over time. This requirement is slightly more general than state space models.
- Parameters
zero_data (Tensor) – A zero tensor like the input data, but extended to the duration of the time_plate(). This allows models to depend on the shape and device of data but not its value.
covariates (Tensor) – A tensor of covariates with time dimension -2.
- Returns
Return value is ignored.
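As an illustration of this contract, here is a sketch of a model() body for a univariate random-walk model: a global drift scale is drawn outside the time plate, per-step drift is drawn inside time_plate() (independent over time), and the cumulative sum of the drift forms the prediction passed to a single predict() call. The class name and priors are illustrative assumptions.

    import pyro
    import pyro.distributions as dist
    from pyro.contrib.forecast import ForecastingModel

    class RandomWalkModel(ForecastingModel):  # hypothetical example model
        def model(self, zero_data, covariates):
            data_dim = zero_data.size(-1)
            # Global (time-independent) variables are drawn outside the time plate.
            drift_scale = pyro.sample(
                "drift_scale", dist.LogNormal(-5, 5).expand([data_dim]).to_event(1))
            # All time-dependent noise is drawn inside the time plate and is
            # independent across time steps.
            with self.time_plate:
                drift = pyro.sample(
                    "drift", dist.Normal(zero_data, drift_scale).to_event(1))
            # The prediction is a deterministic function of that time-local noise.
            prediction = drift.cumsum(dim=-2)
            # predict() is called exactly once, outside the time plate.
            noise_scale = pyro.sample(
                "noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
            self.predict(dist.Normal(0, noise_scale), prediction)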
- property time_plate¶
Helper to create a pyro.plate over time.
- Returns
A plate named “time” with size covariates.size(-2) and dim=-1. This is available only during model execution.
- Return type
plate
- predict(noise_dist, prediction)[source]¶
Prediction function, to be called by model() implementations.
This should be called outside of the time_plate().
This is similar to an observe statement in Pyro:
pyro.sample("residual", noise_dist, obs=(data - prediction))
but with (1) additional reshaping logic to allow time-dependent noise_dist (most often a GaussianHMM or variant); and (2) additional logic to allow only a partial observation and forecast the remaining data.
- Parameters
noise_dist (Distribution) – A noise distribution with .event_dim in {0,1,2}. noise_dist is typically zero-mean, zero-median, zero-mode, or otherwise centered.
prediction (Tensor) – A prediction for the data. This should have the same shape as data, but broadcastable to the full duration of the covariates.
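For a time-dependent noise_dist, a GaussianHMM constructed with duration equal to the covariate duration can be passed to predict(). The sketch below, a hypothetical univariate local-level model with illustrative priors, shows one way to do this.

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.contrib.forecast import ForecastingModel

    class LocalLevelModel(ForecastingModel):  # hypothetical example model
        def model(self, zero_data, covariates):
            duration, data_dim = zero_data.shape[-2:]
            assert data_dim == 1  # this sketch assumes a univariate series

            # A constant level; all dynamics are captured by the noise distribution.
            level = pyro.sample("level", dist.Normal(0, 10).expand([data_dim]).to_event(1))
            prediction = zero_data + level

            # A mean-reverting latent state observed with noise, as a GaussianHMM.
            init_dist = dist.Normal(0, 10).expand([1]).to_event(1)
            timescale = pyro.sample("timescale", dist.LogNormal(1, 1).expand([1]).to_event(1))
            trans_matrix = (-1 / timescale).exp().unsqueeze(-1)  # shape (..., 1, 1)
            trans_scale = pyro.sample("trans_scale", dist.LogNormal(0, 1).expand([1]).to_event(1))
            trans_dist = dist.Normal(0, trans_scale).to_event(1)
            obs_matrix = torch.tensor([[1.0]])
            obs_scale = pyro.sample("obs_scale", dist.LogNormal(-2, 1).expand([1]).to_event(1))
            obs_dist = dist.Normal(0, obs_scale).to_event(1)
            noise_dist = dist.GaussianHMM(init_dist, trans_matrix, trans_dist,
                                          obs_matrix, obs_dist, duration=duration)
            self.predict(noise_dist, prediction)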
- class Forecaster(model, data, covariates, *, guide=None, init_loc_fn=<function init_to_sample>, init_scale=0.1, create_plates=None, optim=None, learning_rate=0.01, betas=(0.9, 0.99), learning_rate_decay=0.1, clip_norm=10.0, time_reparam=None, dct_gradients=False, subsample_aware=False, num_steps=1001, num_particles=1, vectorize_particles=True, warm_start=False, log_every=100)[source]¶
Bases: torch.nn.modules.module.Module
Forecaster for a ForecastingModel using variational inference.
On initialization, this fits a distribution using variational inference over latent variables and exact inference over the noise distribution, typically a GaussianHMM or variant.
After construction this can be called to generate sample forecasts.
- Variables
losses (list) – A list of losses recorded during training, typically used to debug convergence. Defined by loss = -elbo / data.numel().
- Parameters
model (ForecastingModel) – A forecasting model subclass instance.
data (Tensor) – A tensor dataset with time dimension -2.
covariates (Tensor) – A tensor of covariates with time dimension -2. For models not using covariates, pass a shaped empty tensor torch.empty(duration, 0).
guide (PyroModule) – Optional guide instance. Defaults to an AutoNormal.
init_loc_fn (callable) – A per-site initialization function for the AutoNormal guide. Defaults to init_to_sample(). See the Initialization section for available functions.
init_scale (float) – Initial uncertainty scale of the AutoNormal guide.
create_plates (callable) – An optional function to create plates for subsampling with the AutoNormal guide.
optim (PyroOptim) – An optional Pyro optimizer. Defaults to a freshly constructed DCTAdam.
learning_rate (float) – Learning rate used by DCTAdam.
betas (tuple) – Coefficients for running averages used by DCTAdam.
learning_rate_decay (float) – Learning rate decay used by DCTAdam. Note this is the total decay over all num_steps, not the per-step decay factor.
clip_norm (float) – Norm used for gradient clipping during optimization. Defaults to 10.0.
time_reparam (str) – If not None (default), reparameterize all time-dependent variables via the Haar wavelet transform (if “haar”) or the discrete cosine transform (if “dct”).
dct_gradients (bool) – Whether to discrete cosine transform gradients in DCTAdam. Defaults to False.
subsample_aware (bool) – Whether to update gradient statistics only for those elements that appear in a subsample. This is used by DCTAdam.
num_steps (int) – Number of SVI training steps.
num_particles (int) – Number of particles used to compute the ELBO.
vectorize_particles (bool) – If num_particles > 1, determines whether to vectorize computation of the ELBO. Defaults to True. Set to False for models with dynamic control flow.
warm_start (bool) – Whether to warm start parameters from a smaller time window. Note this may introduce statistical leakage; it is recommended for model exploration only and should be disabled when publishing metrics.
log_every (int) – Number of training steps between logging messages.
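For example, a sketch of fitting the hypothetical RandomWalkModel from above with a few non-default options; the data and covariates tensors are the illustrative ones from the first sketch.

    forecaster = Forecaster(
        RandomWalkModel(),
        data,                       # training data, shape (T1, 1), time dim -2
        covariates[:T1],            # training covariates, same duration as data
        learning_rate=0.1,
        learning_rate_decay=0.05,   # total decay over all num_steps
        num_steps=2000,
        time_reparam="haar",        # Haar-reparameterize time-dependent variables
        log_every=500,
    )
    print(forecaster.losses[-1])    # final loss, i.e. -elbo / data.numel()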
- __call__(data, covariates, num_samples, batch_size=None)[source]¶
Samples forecasted values of data for time steps in [t1, t2), where t1 = data.size(-2) is the duration of observed data and t2 = covariates.size(-2) is the extended duration of covariates. For example, to forecast 7 days forward conditioned on 30 days of observations, set t1=30 and t2=37.
- Parameters
data (Tensor) – A tensor dataset with time dimension -2.
covariates (Tensor) – A tensor of covariates with time dimension -2. For models not using covariates, pass a shaped empty tensor torch.empty(duration, 0).
num_samples (int) – The number of samples to generate.
batch_size (int) – Optional batch size for sampling. This is useful for generating many samples from models with large memory footprint. Defaults to num_samples.
- Returns
A batch of joint posterior samples of shape (num_samples, 1, ..., 1) + data.shape[:-2] + (t2 - t1, data.size(-1)), where the 1’s are inserted to avoid conflict with model plates.
- Return type
torch.Tensor
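Continuing the sketches above (a forecaster trained on 30 observed steps, with covariates extending 7 steps further), a hedged example of the resulting sample shape:

    samples = forecaster(data, covariates, num_samples=100, batch_size=10)
    assert samples.shape[0] == 100        # num_samples on the left
    assert samples.shape[-2:] == (7, 1)   # (t2 - t1, data.size(-1)) on the right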
- class HMCForecaster(model, data, covariates=None, *, num_warmup=1000, num_samples=1000, num_chains=1, time_reparam=None, dense_mass=False, jit_compile=False, max_tree_depth=10)[source]¶
Bases: torch.nn.modules.module.Module
Forecaster for a ForecastingModel using Hamiltonian Monte Carlo.
On initialization, this runs the NUTS sampler to get posterior samples of the model.
After construction, this can be called to generate sample forecasts.
- Parameters
model (ForecastingModel) – A forecasting model subclass instance.
data (Tensor) – A tensor dataset with time dimension -2.
covariates (Tensor) – A tensor of covariates with time dimension -2. For models not using covariates, pass a shaped empty tensor torch.empty(duration, 0).
num_warmup (int) – Number of MCMC warmup steps.
num_samples (int) – Number of MCMC samples.
num_chains (int) – Number of parallel MCMC chains.
dense_mass (bool) – A flag to control whether the mass matrix is dense or diagonal. Defaults to False.
time_reparam (str) – If not None (default), reparameterize all time-dependent variables via the Haar wavelet transform (if “haar”) or the discrete cosine transform (if “dct”).
jit_compile (bool) – Whether to use the PyTorch JIT to trace the log density computation, and use this optimized executable trace in the integrator. Defaults to False.
max_tree_depth (int) – Max depth of the binary tree created during the doubling scheme of the NUTS sampler. Defaults to 10.
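A sketch of NUTS-based inference for the hypothetical RandomWalkModel above, reusing the illustrative data and covariates tensors:

    from pyro.contrib.forecast import HMCForecaster

    hmc_forecaster = HMCForecaster(
        RandomWalkModel(),
        data,                  # shape (T1, 1), time dim -2
        covariates[:T1],       # training covariates
        num_warmup=500,
        num_samples=500,
        max_tree_depth=8,
    )
    samples = hmc_forecaster(data, covariates, num_samples=100)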
- __call__(data, covariates, num_samples, batch_size=None)[source]¶
Samples forecasted values of data for time steps in [t1, t2), where t1 = data.size(-2) is the duration of observed data and t2 = covariates.size(-2) is the extended duration of covariates. For example, to forecast 7 days forward conditioned on 30 days of observations, set t1=30 and t2=37.
- Parameters
data (Tensor) – A tensor dataset with time dimension -2.
covariates (Tensor) – A tensor of covariates with time dimension -2. For models not using covariates, pass a shaped empty tensor torch.empty(duration, 0).
num_samples (int) – The number of samples to generate.
batch_size (int) – Optional batch size for sampling. This is useful for generating many samples from models with large memory footprint. Defaults to num_samples.
- Returns
A batch of joint posterior samples of shape (num_samples, 1, ..., 1) + data.shape[:-2] + (t2 - t1, data.size(-1)), where the 1’s are inserted to avoid conflict with model plates.
- Return type
torch.Tensor
Evaluation¶
- eval_mae(pred, truth)[source]¶
Evaluate mean absolute error, using sample median as point estimate.
- Parameters
pred (torch.Tensor) – Forecasted samples.
truth (torch.Tensor) – Ground truth.
- Return type
float
- eval_rmse(pred, truth)[source]¶
Evaluate root mean squared error, using sample mean as point estimate.
- Parameters
pred (torch.Tensor) – Forecasted samples.
truth (torch.Tensor) – Ground truth.
- Return type
float
- eval_crps(pred, truth)[source]¶
Evaluate continuous ranked probability score, averaged over all data elements.
References
- [1] Tilmann Gneiting, Adrian E. Raftery (2007), Strictly Proper Scoring Rules, Prediction, and Estimation. https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
- Parameters
pred (torch.Tensor) – Forecasted samples.
truth (torch.Tensor) – Ground truth.
- Return type
float
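For example, given forecast samples from one of the sketches above and a held-out ground truth tensor of matching trailing shape (the truth tensor here is a placeholder for illustration):

    import torch
    from pyro.contrib.forecast import eval_mae, eval_rmse, eval_crps

    truth = torch.randn(7, 1)            # placeholder held-out data, shape (t2 - t1, data_dim)
    print("MAE ", eval_mae(samples, truth))
    print("RMSE", eval_rmse(samples, truth))
    print("CRPS", eval_crps(samples, truth))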
- backtest(data, covariates, model_fn, *, forecaster_fn=<class 'pyro.contrib.forecast.forecaster.Forecaster'>, metrics=None, transform=None, train_window=None, min_train_window=1, test_window=None, min_test_window=1, stride=1, seed=1234567890, num_samples=100, batch_size=None, forecaster_options={})[source]¶
Backtest a forecasting model on a moving window of (train,test) data.
- Parameters
data (Tensor) – A tensor dataset with time dimension -2.
covariates (Tensor) – A tensor of covariates with time dimension -2. For models not using covariates, pass a shaped empty tensor torch.empty(duration, 0).
model_fn (callable) – Function that returns a ForecastingModel object.
forecaster_fn (callable) – Function that returns a forecaster object (for example, Forecaster or HMCForecaster) given arguments model, training data, training covariates and keyword arguments defined in forecaster_options.
metrics (dict) – A dictionary mapping metric name to metric function. The metric function should input a forecast pred and ground truth and can output anything, often a number. Example metrics include: eval_mae(), eval_rmse(), and eval_crps().
transform (callable) – An optional transform to apply before computing metrics. If provided this will be applied as pred, truth = transform(pred, truth).
train_window (int) – Size of the training window. By default trains from the beginning of the data. This must be None if the forecaster is Forecaster and forecaster_options["warm_start"] is true.
min_train_window (int) – If train_window is None, this specifies the min training window size. Defaults to 1.
test_window (int) – Size of the test window. By default forecasts to the end of the data.
min_test_window (int) – If test_window is None, this specifies the min test window size. Defaults to 1.
stride (int) – Optional stride for test/train split. Defaults to 1.
seed (int) – Random number seed.
num_samples (int) – Number of samples for forecast. Defaults to 100.
batch_size (int) – Batch size for forecast sampling. Defaults to num_samples.
forecaster_options (dict or callable) – Options dict to pass to the forecaster, or a callable inputting the time window t0, t1, t2 and returning such a dict. See Forecaster for details.
- Returns
A list of dictionaries of evaluation data. The caller is responsible for aggregating the per-window metrics. Dictionary keys include: train begin time "t0", train/test split time "t1", test end time "t2", "seed", "num_samples", "train_walltime", "test_walltime", and one key for each metric.
- Return type
list
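A sketch of backtesting the hypothetical RandomWalkModel from above on a toy series with no covariates, evaluating MAE and CRPS on a moving window; the data, window sizes, and forecaster options are illustrative assumptions.

    import torch
    from pyro.contrib.forecast import backtest, eval_crps, eval_mae

    data = torch.randn(100, 1).cumsum(-2)   # toy series for illustration, time dim -2
    covariates = torch.zeros(len(data), 0)  # no covariates
    windows = backtest(
        data, covariates, RandomWalkModel,  # the class itself serves as model_fn
        min_train_window=60,
        test_window=10,
        stride=10,
        metrics={"mae": eval_mae, "crps": eval_crps},
        num_samples=100,
        forecaster_options={"num_steps": 500, "learning_rate": 0.1, "log_every": 250},
    )
    for window in windows:
        print(window["t0"], window["t1"], window["t2"], window["mae"], window["crps"])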