Source code for pyro.contrib.tracking.assignment

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import itertools
import math
import numbers

import torch

import pyro.distributions as dist
from pyro.util import warn_if_nan


def _product(factors):
    result = 1.
    for factor in factors:
        result = result * factor
    return result


def _exp(value):
    if isinstance(value, numbers.Number):
        return math.exp(value)
    return value.exp()


[docs]class MarginalAssignment: """ Computes marginal data associations between objects and detections. This assumes that each detection corresponds to zero or one object, and each object corresponds to zero or more detections. Specifically this does not assume detections have been partitioned into frames of mutual exclusion as is common in 2-D assignment problems. :param torch.Tensor exists_logits: a tensor of shape ``[num_objects]`` representing per-object factors for existence of each potential object. :param torch.Tensor assign_logits: a tensor of shape ``[num_detections, num_objects]`` representing per-edge factors of assignment probability, where each edge denotes that a given detection associates with a single object. :param int bp_iters: optional number of belief propagation iterations. If unspecified or ``None`` an expensive exact algorithm will be used. :ivar int num_detections: the number of detections :ivar int num_objects: the number of (potentially existing) objects :ivar pyro.distributions.Bernoulli exists_dist: a mean field posterior distribution over object existence. :ivar pyro.distributions.Categorical assign_dist: a mean field posterior distribution over the object (or None) to which each detection associates. This has ``.event_shape == (num_objects + 1,)`` where the final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ def __init__(self, exists_logits, assign_logits, bp_iters=None): assert exists_logits.dim() == 1, exists_logits.shape assert assign_logits.dim() == 2, assign_logits.shape assert assign_logits.shape[-1] == exists_logits.shape[-1] self.num_detections, self.num_objects = assign_logits.shape # Clamp to avoid NANs. exists_logits = exists_logits.clamp(min=-40, max=40) assign_logits = assign_logits.clamp(min=-40, max=40) # This does all the work. if bp_iters is None: exists, assign = compute_marginals(exists_logits, assign_logits) else: exists, assign = compute_marginals_bp(exists_logits, assign_logits, bp_iters) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. padded_assign = torch.nn.functional.pad(assign, (0, 1), "constant", 0.0) self.assign_dist = dist.Categorical(logits=padded_assign) self.exists_dist = dist.Bernoulli(logits=exists)
[docs]class MarginalAssignmentSparse: """ A cheap sparse version of :class:`MarginalAssignment`. :param int num_detections: the number of detections :param int num_objects: the number of (potentially existing) objects :param torch.LongTensor edges: a ``[2, num_edges]``-shaped tensor of (detection, object) index pairs specifying feasible associations. :param torch.Tensor exists_logits: a tensor of shape ``[num_objects]`` representing per-object factors for existence of each potential object. :param torch.Tensor assign_logits: a tensor of shape ``[num_edges]`` representing per-edge factors of assignment probability, where each edge denotes that a given detection associates with a single object. :param int bp_iters: optional number of belief propagation iterations. If unspecified or ``None`` an expensive exact algorithm will be used. :ivar int num_detections: the number of detections :ivar int num_objects: the number of (potentially existing) objects :ivar pyro.distributions.Bernoulli exists_dist: a mean field posterior distribution over object existence. :ivar pyro.distributions.Categorical assign_dist: a mean field posterior distribution over the object (or None) to which each detection associates. This has ``.event_shape == (num_objects + 1,)`` where the final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ def __init__(self, num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters): assert edges.dim() == 2, edges.shape assert edges.shape[0] == 2, edges.shape assert exists_logits.shape == (num_objects,), exists_logits.shape assert assign_logits.shape == edges.shape[1:], assign_logits.shape self.num_objects = num_objects self.num_detections = num_detections self.edges = edges # Clamp to avoid NANs. exists_logits = exists_logits.clamp(min=-40, max=40) assign_logits = assign_logits.clamp(min=-40, max=40) # This does all the work. exists, assign = compute_marginals_sparse_bp( num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. padded_assign = torch.full((num_detections, num_objects + 1), -float('inf'), dtype=assign.dtype, device=assign.device) padded_assign[:, -1] = 0 padded_assign[edges[0], edges[1]] = assign self.assign_dist = dist.Categorical(logits=padded_assign) self.exists_dist = dist.Bernoulli(logits=exists)
[docs]class MarginalAssignmentPersistent: """ This computes marginal distributions of a multi-frame multi-object data association problem with an unknown number of persistent objects. The inputs are factors in a factor graph (existence probabilites for each potential object and assignment probabilities for each object-detection pair), and the outputs are marginal distributions of posterior existence probability of each potential object and posterior assignment probabilites of each object-detection pair. This assumes a shared (maximum) number of detections per frame; to handle variable number of detections, simply set corresponding elements of ``assign_logits`` to ``-float('inf')``. :param torch.Tensor exists_logits: a tensor of shape ``[num_objects]`` representing per-object factors for existence of each potential object. :param torch.Tensor assign_logits: a tensor of shape ``[num_frames, num_detections, num_objects]`` representing per-edge factors of assignment probability, where each edge denotes that at a given time frame a given detection associates with a single object. :param int bp_iters: optional number of belief propagation iterations. If unspecified or ``None`` an expensive exact algorithm will be used. :param float bp_momentum: optional momentum to use for belief propagation. Should be in the interval ``[0,1)``. :ivar int num_frames: the number of time frames :ivar int num_detections: the (maximum) number of detections per frame :ivar int num_objects: the number of (potentially existing) objects :ivar pyro.distributions.Bernoulli exists_dist: a mean field posterior distribution over object existence. :ivar pyro.distributions.Categorical assign_dist: a mean field posterior distribution over the object (or None) to which each detection associates. This has ``.event_shape == (num_objects + 1,)`` where the final element denotes spurious detection, and ``.batch_shape == (num_frames, num_detections)``. """ def __init__(self, exists_logits, assign_logits, bp_iters=None, bp_momentum=0.5): assert exists_logits.dim() == 1, exists_logits.shape assert assign_logits.dim() == 3, assign_logits.shape assert assign_logits.shape[-1] == exists_logits.shape[-1] self.num_frames, self.num_detections, self.num_objects = assign_logits.shape # Clamp to avoid NANs. exists_logits = exists_logits.clamp(min=-40, max=40) assign_logits = assign_logits.clamp(min=-40, max=40) # This does all the work. if bp_iters is None: exists, assign = compute_marginals_persistent(exists_logits, assign_logits) else: exists, assign = compute_marginals_persistent_bp( exists_logits, assign_logits, bp_iters, bp_momentum) # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. padded_assign = torch.nn.functional.pad(assign, (0, 1), "constant", 0.0) self.assign_dist = dist.Categorical(logits=padded_assign) self.exists_dist = dist.Bernoulli(logits=exists) assert self.assign_dist.batch_shape == (self.num_frames, self.num_detections) assert self.exists_dist.batch_shape == (self.num_objects,)
[docs]def compute_marginals(exists_logits, assign_logits): """ This implements exact inference of pairwise marginals via enumeration. This is very expensive and is only useful for testing. See :class:`MarginalAssignment` for args and problem description. """ num_detections, num_objects = assign_logits.shape assert exists_logits.shape == (num_objects,) dtype = exists_logits.dtype device = exists_logits.device exists_probs = torch.zeros(2, num_objects, dtype=dtype, device=device) # [not exist, exist] assign_probs = torch.zeros(num_detections, num_objects + 1, dtype=dtype, device=device) for assign in itertools.product(range(num_objects + 1), repeat=num_detections): assign_part = sum(assign_logits[j, i] for j, i in enumerate(assign) if i < num_objects) for exists in itertools.product(*[[1] if i in assign else [0, 1] for i in range(num_objects)]): exists_part = sum(exists_logits[i] for i, e in enumerate(exists) if e) prob = _exp(exists_part + assign_part) for i, e in enumerate(exists): exists_probs[e, i] += prob for j, i in enumerate(assign): assign_probs[j, i] += prob # Convert from probs to logits. exists = exists_probs.log() assign = assign_probs.log() exists = exists[1] - exists[0] assign = assign[:, :-1] - assign[:, -1:] warn_if_nan(exists, 'exists') warn_if_nan(assign, 'assign') return exists, assign
[docs]def compute_marginals_bp(exists_logits, assign_logits, bp_iters): """ This implements approximate inference of pairwise marginals via loopy belief propagation, adapting the approach of [1]. See :class:`MarginalAssignment` for args and problem description. [1] Jason L. Williams, Roslyn A. Lau (2014) Approximate evaluation of marginal association probabilities with belief propagation https://arxiv.org/abs/1209.6299 """ message_e_to_a = torch.zeros_like(assign_logits) message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): message_e_to_a = -(message_a_to_e - message_a_to_e.sum(0, True) - exists_logits).exp().log1p() joint = (assign_logits + message_e_to_a).exp() message_a_to_e = (assign_logits - torch.log1p(joint.sum(1, True) - joint)).exp().log1p() warn_if_nan(message_e_to_a, 'message_e_to_a iter {}'.format(i)) warn_if_nan(message_a_to_e, 'message_a_to_e iter {}'.format(i)) # Convert from probs to logits. exists = exists_logits + message_a_to_e.sum(0) assign = assign_logits + message_e_to_a warn_if_nan(exists, 'exists') warn_if_nan(assign, 'assign') return exists, assign
[docs]def compute_marginals_sparse_bp(num_objects, num_detections, edges, exists_logits, assign_logits, bp_iters): """ This implements approximate inference of pairwise marginals via loopy belief propagation, adapting the approach of [1]. See :class:`MarginalAssignmentSparse` for args and problem description. [1] Jason L. Williams, Roslyn A. Lau (2014) Approximate evaluation of marginal association probabilities with belief propagation https://arxiv.org/abs/1209.6299 """ exists_factor = exists_logits[edges[1]] def sparse_sum(x, dim, keepdim=False): assert dim in (0, 1) x = (torch.zeros([num_objects, num_detections][dim], dtype=x.dtype, device=x.device) .scatter_add_(0, edges[1 - dim], x)) if keepdim: x = x[edges[1 - dim]] return x message_e_to_a = torch.zeros_like(assign_logits) message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): message_e_to_a = -(message_a_to_e - sparse_sum(message_a_to_e, 0, True) - exists_factor).exp().log1p() joint = (assign_logits + message_e_to_a).exp() message_a_to_e = (assign_logits - torch.log1p(sparse_sum(joint, 1, True) - joint)).exp().log1p() warn_if_nan(message_e_to_a, 'message_e_to_a iter {}'.format(i)) warn_if_nan(message_a_to_e, 'message_a_to_e iter {}'.format(i)) # Convert from probs to logits. exists = exists_logits + sparse_sum(message_a_to_e, 0) assign = assign_logits + message_e_to_a warn_if_nan(exists, 'exists') warn_if_nan(assign, 'assign') return exists, assign
[docs]def compute_marginals_persistent(exists_logits, assign_logits): """ This implements exact inference of pairwise marginals via enumeration. This is very expensive and is only useful for testing. See :class:`MarginalAssignmentPersistent` for args and problem description. """ num_frames, num_detections, num_objects = assign_logits.shape assert exists_logits.shape == (num_objects,) dtype = exists_logits.dtype device = exists_logits.device total = 0 exists_probs = torch.zeros(num_objects, dtype=dtype, device=device) assign_probs = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) for exists in itertools.product([0, 1], repeat=num_objects): exists = [i for i, e in enumerate(exists) if e] exists_part = _exp(sum(exists_logits[i] for i in exists)) # The remaining variables are conditionally independent conditioned on exists. assign_parts = [] assign_sums = [] for t in range(num_frames): assign_map = {} for n in range(1 + min(len(exists), num_detections)): for objects in itertools.combinations(exists, n): for detections in itertools.permutations(range(num_detections), n): assign = tuple(zip(objects, detections)) assign_map[assign] = _exp(sum(assign_logits[t, j, i] for i, j in assign)) assign_parts.append(assign_map) assign_sums.append(sum(assign_map.values())) prob = exists_part * _product(assign_sums) total += prob for i in exists: exists_probs[i] += prob for t in range(num_frames): other_part = exists_part * _product(assign_sums[:t] + assign_sums[t + 1:]) for assign, assign_part in assign_parts[t].items(): prob = other_part * assign_part for i, j in assign: assign_probs[t, j, i] += prob # Convert from probs to logits. exists = exists_probs.log() - (total - exists_probs).log() assign = assign_probs.log() - (total - assign_probs.sum(-1, True)).log() warn_if_nan(exists, 'exists') warn_if_nan(assign, 'assign') return exists, assign
[docs]def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters, bp_momentum=0.5): """ This implements approximate inference of pairwise marginals via loopy belief propagation, adapting the approach of [1], [2]. See :class:`MarginalAssignmentPersistent` for args and problem description. [1] Jason L. Williams, Roslyn A. Lau (2014) Approximate evaluation of marginal association probabilities with belief propagation https://arxiv.org/abs/1209.6299 [2] Ryan Turner, Steven Bottone, Bhargav Avasarala (2014) A Complete Variational Tracker https://papers.nips.cc/paper/5572-a-complete-variational-tracker.pdf """ # This implements forward-backward message passing among three sets of variables: # # a[t,j] ~ Categorical(num_objects + 1), detection -> object assignment # b[t,i] ~ Categorical(num_detections + 1), object -> detection assignment # e[i] ~ Bernonulli, whether each object exists # # Only assign = a and exists = e are returned. assert 0 <= bp_momentum < 1, bp_momentum old, new = bp_momentum, 1 - bp_momentum num_frames, num_detections, num_objects = assign_logits.shape dtype = assign_logits.dtype device = assign_logits.device message_b_to_a = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) message_a_to_b = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) message_b_to_e = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) message_e_to_b = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) for i in range(bp_iters): odds_a = (assign_logits + message_b_to_a).exp() message_a_to_b = (old * message_a_to_b + new * (assign_logits - (odds_a.sum(2, True) - odds_a).log1p())) message_b_to_e = (old * message_b_to_e + new * message_a_to_b.exp().sum(1).log1p()) message_e_to_b = (old * message_e_to_b + new * (exists_logits + message_b_to_e.sum(0) - message_b_to_e)) odds_b = message_a_to_b.exp() message_b_to_a = (old * message_b_to_a - new * ((-message_e_to_b).exp().unsqueeze(1) + (1 + odds_b.sum(1, True) - odds_b)).log()) warn_if_nan(message_a_to_b, 'message_a_to_b iter {}'.format(i)) warn_if_nan(message_b_to_e, 'message_b_to_e iter {}'.format(i)) warn_if_nan(message_e_to_b, 'message_e_to_b iter {}'.format(i)) warn_if_nan(message_b_to_a, 'message_b_to_a iter {}'.format(i)) # Convert from probs to logits. exists = exists_logits + message_b_to_e.sum(0) assign = assign_logits + message_b_to_a warn_if_nan(exists, 'exists') warn_if_nan(assign, 'assign') return exists, assign