# Source code for pyro.contrib.tracking.extended_kalman_filter

import torch
from torch.distributions.utils import lazy_property

import pyro.distributions as dist
from pyro.distributions.util import eye_like


class EKFState(object):
    '''
    State-Centric EKF (Extended Kalman Filter) for use with either an NCP
    (Nearly-Constant Position) or NCV (Nearly-Constant Velocity) target dynamic
    model. Stores a target dynamic model, state estimate, and state time.
    Incoming ``Measurement`` provide sensor information for updates.

    A state may be indexed by a continuous ``time``, a discrete ``frame_num``,
    or both. :meth:`predict` and :meth:`update` do not modify the state they
    are called on; they return a new ``EKFState``.

    .. warning:: For efficiency, the dynamic model is stored by reference, not
        copied. Make a deep copy outside as necessary to protect against
        unexpected changes.

    :param dynamic_model: target dynamic model.
    :param mean: mean of target state estimate.
    :param cov: covariance of target state estimate.
    :param time: continuous time of state estimate.
    :param frame_num: discrete time (frame number) of state estimate.
    :raises ValueError: if neither ``time`` nor ``frame_num`` is provided.
    '''

    def __init__(self, dynamic_model, mean, cov, time=None, frame_num=None):
        self._dynamic_model = dynamic_model
        self._mean = mean
        self._cov = cov
        if time is None and frame_num is None:
            raise ValueError('Must provide time or frame_num!')
        self._time = time
        self._frame_num = frame_num

    @property
    def dynamic_model(self):
        '''
        Dynamic model access.
        '''
        return self._dynamic_model

    @property
    def dimension(self):
        '''
        Native state dimension access.
        '''
        return self._dynamic_model.dimension

    @property
    def mean(self):
        '''
        Native state estimate mean access.
        '''
        return self._mean

    @property
    def cov(self):
        '''
        Native state estimate covariance access.
        '''
        return self._cov

    @property
    def dimension_pv(self):
        '''
        PV state dimension access.
        '''
        return self._dynamic_model.dimension_pv

    @lazy_property
    def mean_pv(self):
        '''
        Compute and return cached PV state estimate mean.
        '''
        return self._dynamic_model.mean2pv(self._mean)

    @lazy_property
    def cov_pv(self):
        '''
        Compute and return cached PV state estimate covariance.
        '''
        return self._dynamic_model.cov2pv(self._cov)

    @property
    def time(self):
        '''
        Continuous State time access.
        '''
        return self._time

    @property
    def frame_num(self):
        '''
        Discrete State time access.
        '''
        return self._frame_num

    @staticmethod
    def _solve(A, B):
        '''
        Return ``X`` solving ``A X = B``.

        Prefers ``torch.linalg.solve``; ``torch.solve`` is deprecated and
        absent from newer PyTorch releases. Falls back to ``torch.solve`` on
        builds that do not provide ``torch.linalg.solve``. Note the swapped
        argument order between the two APIs.
        '''
        linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
        if linalg_solve is not None:
            return linalg_solve(A, B)
        return torch.solve(B, A)[0]

    def _assert_time_aligned(self, measurement):
        '''
        Assert that ``measurement`` is aligned with every time index
        (continuous and/or discrete) that this state carries.
        '''
        if self._time is not None:
            assert self._time == measurement.time, \
                'State time and measurement time must be aligned!'
        if self._frame_num is not None:
            assert self._frame_num == measurement.frame_num, \
                'State time and measurement time must be aligned!'

    def _innovation_terms(self, measurement):
        '''
        Compute the quantities shared by :meth:`innovation` and :meth:`update`.

        :return: tuple ``(H, R, dz, S)``: measurement Jacobian restricted to
            the native state dimension, measurement covariance, innovation
            mean, and innovation covariance.
        '''
        x_pv = self._dynamic_model.mean2pv(self._mean)
        # Keep only the Jacobian columns matching the native state dimension.
        H = measurement.jacobian(x_pv)[:, :self.dimension]
        R = measurement.cov
        z = measurement.mean
        z_predicted = measurement(x_pv)
        dz = measurement.geodesic_difference(z, z_predicted)
        S = H.mm(self._cov).mm(H.transpose(-1, -2)) + R  # innovation cov
        return H, R, dz, S

    def predict(self, dt=None, destination_time=None, destination_frame_num=None):
        '''
        Use dynamic model to predict (aka propagate aka integrate) the state
        estimate. This state is left unchanged; a new ``EKFState`` is returned.

        Exactly one of ``dt`` and ``destination_time`` must be given.

        :param dt: time to integrate over. If this state has a continuous
            time, the predicted state's time is this state's time plus ``dt``.
            Using ``destination_time`` may be preferable for prevention of
            roundoff error accumulation.
        :param destination_time: optional value to set continuous state time
            to after integration. Requires this state to have a continuous
            time. If the predicted state would otherwise have no continuous
            time, then ``destination_frame_num`` must be provided.
        :param destination_frame_num: optional value to set discrete state
            time to after integration. If this is not provided, then the
            predicted state must have a continuous time.
        :return: predicted state.
        :rtype: EKFState
        :raises ValueError: if ``destination_time`` is given for a state
            without a continuous time, or if the predicted state would have
            neither a continuous time nor a frame number.
        '''
        assert (dt is None) ^ (destination_time is None)
        if dt is None:
            if self._time is None:
                raise ValueError('Cannot predict to destination_time: '
                                 'state has no continuous time!')
            dt = destination_time - self._time
        elif self._time is not None:
            # Frame-only states have no continuous time to advance.
            destination_time = self._time + dt

        # Validate before doing any model work.
        if destination_time is None and destination_frame_num is None:
            raise ValueError('destination_time or destination_frame_num must be specified!')

        pred_mean = self._dynamic_model(self._mean, dt)

        F = self._dynamic_model.jacobian(dt)
        Q = self._dynamic_model.process_noise_cov(dt)
        pred_cov = F.mm(self._cov).mm(F.transpose(-1, -2)) + Q

        return EKFState(self._dynamic_model, pred_mean, pred_cov,
                        destination_time, destination_frame_num)

    def innovation(self, measurement):
        '''
        Compute and return the innovation that a measurement would induce if
        it were used for an update, but don't actually perform the update.
        Assumes state and measurement are time-aligned. Useful for computing
        Chi^2 stats and likelihoods.

        :param measurement: measurement
        :return: Innovation mean and covariance of hypothetical update.
        :rtype: tuple(``torch.Tensor``, ``torch.Tensor``)
        '''
        self._assert_time_aligned(measurement)
        _, _, dz, S = self._innovation_terms(measurement)
        return dz, S

    def log_likelihood_of_update(self, measurement):
        '''
        Compute and return the log-likelihood of a potential update, but don't
        actually perform the update. Assumes state and measurement are time-
        aligned. Useful for gating and calculating costs in assignment
        problems for data association.

        :param: measurement.
        :return: Log-likelihood of hypothetical update.
        '''
        dz, S = self.innovation(measurement)
        # Zero-mean Gaussian over the innovation, with covariance S.
        zero_mean = torch.zeros(S.size(-1), dtype=S.dtype, device=S.device)
        return dist.MultivariateNormal(zero_mean, S).log_prob(dz)

    def update(self, measurement):
        '''
        Use measurement to compute an updated state estimate and return it
        together with the innovation. This state is left unchanged. The
        innovation is useful, e.g., for evaluating filter consistency or
        updating model likelihoods when the ``EKFState`` is part of an
        ``IMMFState``.

        :param: measurement.
        :returns: EKF State, Innovation mean and covariance.
        '''
        self._assert_time_aligned(measurement)
        H, R, dz, S = self._innovation_terms(measurement)

        # K_prefix = P H^T; the Kalman gain is K = K_prefix S^-1, applied via
        # linear solves rather than an explicit inverse.
        K_prefix = self._cov.mm(H.transpose(-1, -2))
        dx = K_prefix.mm(self._solve(S, dz.unsqueeze(1))).squeeze(1)  # K*dz
        x = self._dynamic_model.geodesic_difference(self._mean, -dx)

        I = eye_like(x, self._dynamic_model.dimension)  # noqa: E741
        ImKH = I - K_prefix.mm(self._solve(S, H))
        # *Joseph form* of covariance update for numerical stability.
        P = ImKH.mm(self._cov).mm(ImKH.transpose(-1, -2)) \
            + K_prefix.mm(self._solve(S, K_prefix.mm(self._solve(S, R)).transpose(-1, -2)))

        state = EKFState(self._dynamic_model, x, P, self._time, self._frame_num)
        return state, (dz, S)