# Source code for pyro.contrib.timeseries.base

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.nn import PyroModule, pyro_method


class TimeSeriesModel(PyroModule):
    """
    Base class for univariate and multivariate time series models.

    Every method defined on this base class raises ``NotImplementedError``;
    concrete time series models provide the actual implementations.
    """

    @pyro_method
    def log_prob(self, targets):
        """
        Log probability function.

        :param torch.Tensor targets: A 2-dimensional tensor of real-valued
            targets of shape ``(T, obs_dim)``, where ``T`` is the length of
            the time series and ``obs_dim`` is the dimension of the
            real-valued ``targets`` at each time step.
        :returns torch.Tensor: A 0-dimensional log probability for the case
            of properly multivariate time series models in which the output
            dimensions are correlated; otherwise returns a 1-dimensional
            tensor of log probabilities for batched univariate time series
            models.
        :raises NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError

    @pyro_method
    def forecast(self, targets, dts):
        """
        Forecast the time series at the future times ``dts``.

        :param torch.Tensor targets: A 2-dimensional tensor of real-valued
            targets of shape ``(T, obs_dim)``, where ``T`` is the length of
            the time series and ``obs_dim`` is the dimension of the
            real-valued targets at each time step. These represent the
            training data that are conditioned on for the purpose of making
            forecasts.
        :param torch.Tensor dts: A 1-dimensional tensor of times to forecast
            into the future, with zero corresponding to the time of the
            final target ``targets[-1]``.
        :returns torch.distributions.Distribution: Returns a predictive
            distribution with batch shape ``(S,)`` and event shape
            ``(obs_dim,)``, where ``S`` is the size of ``dts``. That is, the
            resulting predictive distributions do not encode correlations
            between distinct times in ``dts``.
        :raises NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError

    def get_dist(self):
        """
        Get a :class:`~pyro.distributions.Distribution` object corresponding
        to this time series model. Often this is a
        :class:`~pyro.distributions.GaussianHMM`.

        :raises NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError