Source code for pyro.contrib.tracking.hashing

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import heapq
import itertools
from collections import defaultdict
from numbers import Number

import torch


class LSH:
    """
    Implements locality-sensitive hashing for low-dimensional euclidean space.

    Allows to efficiently find neighbours of a point. Provides 2 guarantees:

    -   Difference between coordinates of points not returned by :meth:`nearby`
        and input point is larger than ``radius``.
    -   Difference between coordinates of points returned by :meth:`nearby` and
        input point is smaller than 2 ``radius``.

    Example:

        >>> radius = 1
        >>> lsh = LSH(radius)
        >>> a = torch.tensor([-0.51, -0.51]) # hash(a)=(-1,-1)
        >>> b = torch.tensor([-0.49, -0.49]) # hash(b)=(0,0)
        >>> c = torch.tensor([1.0, 1.0]) # hash(c)=(1,1)
        >>> lsh.add('a', a)
        >>> lsh.add('b', b)
        >>> lsh.add('c', c)
        >>> # c is not returned even though it is within 2*radius of a
        >>> lsh.nearby('a') # doctest: +SKIP
        {'b'}
        >>> lsh.nearby('b') # doctest: +SKIP
        {'a', 'c'}
        >>> lsh.remove('b')
        >>> lsh.nearby('a') # doctest: +SKIP
        set()

    :param float radius: Scaling parameter used in hash function. Determines
        the size of the neighbourhood.
    :raises ValueError: if ``radius`` is not a positive number.
    """

    def __init__(self, radius):
        if not (isinstance(radius, Number) and radius > 0):
            raise ValueError(
                "radius must be float greater than 0, given: {}".format(radius)
            )
        self._radius = radius
        # Maps each occupied bucket (tuple of ints) to the set of keys in it.
        self._hash_to_key = defaultdict(set)
        # Maps each key to the bucket it currently lives in.
        self._key_to_hash = {}

    def _hash(self, point):
        """
        Returns the bucket of ``point``: each coordinate divided by the
        radius, rounded, and converted to ``int``.
        """
        coords = (point / self._radius).round()
        return tuple(map(int, coords))

    def add(self, key, point):
        """
        Adds (``key``, ``point``) pair to the hash. If ``key`` is already
        present, its previous point is replaced.

        :param key: Key used to identify ``point``.
        :param torch.Tensor point: data, should be detached and on cpu.
        """
        # Hash first so that a failure here leaves any existing entry intact.
        _hash = self._hash(point)
        if key in self._key_to_hash:
            self.remove(key)
        self._key_to_hash[key] = _hash
        self._hash_to_key[_hash].add(key)

    def remove(self, key):
        """
        Removes ``key`` and corresponding point from the hash.

        Raises :exc:`KeyError` if key is not in hash.

        :param key: key used to identify point.
        """
        _hash = self._key_to_hash.pop(key)
        bucket = self._hash_to_key[_hash]
        bucket.remove(key)
        if not bucket:
            # Drop emptied buckets so the table only tracks occupied cells.
            del self._hash_to_key[_hash]

    def nearby(self, key):
        r"""
        Returns a set of keys which are neighbours of the point identified by
        ``key``.

        Two points are nearby if difference of each element of their hashes is
        smaller than 2. In euclidean space, this corresponds to all points
        :math:`\mathbf{p}` where
        :math:`|\mathbf{p}_k-(\mathbf{p_{key}})_k|<r`, and some points (all
        points not guaranteed) where
        :math:`|\mathbf{p}_k-(\mathbf{p_{key}})_k|<2r`.

        Raises :exc:`KeyError` if key is not in hash.

        :param key: key used to identify input point.
        :return: a set of keys identifying neighbours of the input point.
        :rtype: set
        """
        _hash = self._key_to_hash[key]
        result = set()
        for nearby_hash in itertools.product(*[[i - 1, i, i + 1] for i in _hash]):
            # Use .get() rather than indexing: indexing a defaultdict would
            # insert an empty set for every unoccupied neighbouring bucket.
            bucket = self._hash_to_key.get(nearby_hash)
            if bucket:
                result |= bucket
        # The query key always sits in its own bucket; exclude it.
        result.remove(key)
        return result
class ApproxSet:
    """
    Tracks which bins of a low-dimensional euclidean space are occupied.

    Points are assigned to bins by dividing each coordinate by ``radius`` and
    rounding; at most one point is accepted per bin.

    :param float radius: scaling parameter used in hash function. Determines
        the size of the bin. See :class:`LSH` for details.
    :raises ValueError: if ``radius`` is not a positive number.
    """

    def __init__(self, radius):
        # ``not radius > 0`` (rather than ``radius <= 0``) also rejects NaN.
        if not isinstance(radius, Number) or not radius > 0:
            message = "radius must be float greater than 0, given: {}".format(radius)
            raise ValueError(message)
        self._bins = set()
        self._radius = radius

    def _hash(self, point):
        """Returns the bin of ``point`` as a tuple of ints."""
        scaled = point / self._radius
        return tuple(int(coord) for coord in scaled.round())

    def try_add(self, point):
        """
        Attempts to add ``point`` to the set. The point is only added if
        there is no point in its bin yet.

        :param torch.Tensor point: Point to be queried, should be detached and
            on cpu.
        :return: ``True`` if point is successfully added, ``False`` if there is
            already a point in ``point``'s bin.
        :rtype: bool
        """
        bin_id = self._hash(point)
        is_new = bin_id not in self._bins
        if is_new:
            self._bins.add(bin_id)
        return is_new
def merge_points(points, radius):
    """
    Greedily merge points that are closer than given radius.

    This uses :class:`LSH` to achieve complexity that is linear in the number
    of merged clusters and quadratic in the size of the largest merged cluster.

    Candidate pairs are kept in a heap ordered by squared distance; the
    closest pair is popped and replaced by the midpoint of its two points,
    which then becomes a candidate for further merging.

    :param torch.Tensor points: A tensor of shape ``(K,D)`` where ``K`` is
        the number of points and ``D`` is the number of dimensions.
    :param float radius: The minimum distance nearer than which
        points will be merged.
    :return: A tuple ``(merged_points, groups)`` where ``merged_points`` is a
        tensor of shape ``(J,D)`` where ``J <= K``, and ``groups`` is a list of
        tuples of indices mapping merged points to original points. Note that
        ``len(groups) == J`` and ``sum(len(group) for group in groups) == K``.
    :rtype: tuple
    :raises ValueError: if ``points`` is not 2-dimensional or ``radius`` is
        not a positive number.
    """
    if points.dim() != 2:
        raise ValueError(
            "Expected points.shape == (K,D), but got {}".format(points.shape)
        )
    if not (isinstance(radius, Number) and radius > 0):
        raise ValueError(
            "Expected radius to be a positive number, but got {}".format(radius)
        )
    # Shrink slightly to avoid merging points exactly radius apart,
    # e.g. grid points.
    radius = 0.99 * radius
    threshold = radius**2

    # Seed the heap with every close-enough pair; the LSH limits the search
    # to points in neighbouring buckets.
    lsh = LSH(radius)
    heap = []
    groups = [(index,) for index in range(len(points))]
    for index, point in enumerate(points):
        lsh.add(index, point)
        for other in lsh.nearby(index):
            sq_dist = (point - points[other]).pow(2).sum().item()
            if sq_dist < threshold:
                heapq.heappush(heap, (sq_dist, other, index))
    if not heap:
        # Nothing is close enough to merge; return the input unchanged.
        return points, groups

    # Switch to id-keyed dicts so merged points can be added and removed.
    next_id = len(points)
    points = dict(enumerate(points))
    groups = dict(enumerate(groups))

    # Greedily merge the closest remaining pair.
    while heap:
        _, left, right = heapq.heappop(heap)
        if not (left in points and right in points):
            # Stale entry: one endpoint was already merged away.
            continue
        merged = next_id
        next_id += 1
        points[merged] = (points.pop(left) + points.pop(right)) / 2
        groups[merged] = groups.pop(left) + groups.pop(right)
        lsh.remove(left)
        lsh.remove(right)
        lsh.add(merged, points[merged])
        for other in lsh.nearby(merged):
            if other == merged:
                continue
            sq_dist = (points[other] - points[merged]).pow(2).sum().item()
            if sq_dist < threshold:
                heapq.heappush(heap, (sq_dist, other, merged))

    # Back to a dense tensor and list, ordered by id.
    ids = sorted(points)
    merged_points = torch.stack([points[i] for i in ids])
    merged_groups = [groups[i] for i in ids]
    return merged_points, merged_groups