# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import heapq
import itertools
from collections import defaultdict
from numbers import Number
import torch
class LSH:
    """
    Implements locality-sensitive hashing for low-dimensional euclidean space.

    Allows efficient lookup of the neighbours of a point. Provides 2 guarantees:

    -   Difference between coordinates of points not returned by :meth:`nearby`
        and input point is larger than ``radius``.
    -   Difference between coordinates of points returned by :meth:`nearby` and
        input point is smaller than 2 ``radius``.

    Example:

        >>> radius = 1
        >>> lsh = LSH(radius)
        >>> a = torch.tensor([-0.51, -0.51]) # hash(a)=(-1,-1)
        >>> b = torch.tensor([-0.49, -0.49]) # hash(b)=(0,0)
        >>> c = torch.tensor([1.0, 1.0]) # hash(c)=(1,1)
        >>> lsh.add('a', a)
        >>> lsh.add('b', b)
        >>> lsh.add('c', c)
        >>> # c is not reported as a neighbour of a, even though it is within 2radius of a
        >>> lsh.nearby('a') # doctest: +SKIP
        {'b'}
        >>> lsh.nearby('b') # doctest: +SKIP
        {'a', 'c'}
        >>> lsh.remove('b')
        >>> lsh.nearby('a') # doctest: +SKIP
        set()

    :param float radius: Scaling parameter used in hash function. Determines the size of the neighbourhood.
    :raises ValueError: if ``radius`` is not a number greater than 0.
    """

    def __init__(self, radius):
        if not (isinstance(radius, Number) and radius > 0):
            raise ValueError(
                "radius must be float greater than 0, given: {}".format(radius)
            )
        self._radius = radius
        # bin (tuple of ints) -> keys currently stored in that bin
        self._hash_to_key = defaultdict(set)
        # key -> bin the key currently lives in
        self._key_to_hash = {}

    def _hash(self, point):
        # Each coordinate is scaled by the radius and rounded to the nearest
        # integer, so a bin is a cell of side ``radius``.
        coords = (point / self._radius).round()
        return tuple(map(int, coords))

    def add(self, key, point):
        """
        Adds (``key``, ``point``) pair to the hash.

        If ``key`` is already present, its previous point is replaced.

        :param key: Key used to identify ``point``.
        :param torch.Tensor point: data, should be detached and on cpu.
        """
        # Hash first so that a failure here leaves any existing entry intact.
        _hash = self._hash(point)
        if key in self._key_to_hash:
            self.remove(key)
        self._key_to_hash[key] = _hash
        self._hash_to_key[_hash].add(key)

    def remove(self, key):
        """
        Removes ``key`` and corresponding point from the hash.

        Raises :exc:`KeyError` if key is not in hash.

        :param key: key used to identify point.
        """
        _hash = self._key_to_hash.pop(key)
        bucket = self._hash_to_key[_hash]
        bucket.remove(key)
        # Drop emptied bins so the table does not accumulate dead entries.
        if not bucket:
            del self._hash_to_key[_hash]

    def nearby(self, key):
        r"""
        Returns a set of keys which are neighbours of the point identified by ``key``.

        Two points are nearby if difference of each element of their hashes is smaller than 2. In euclidean space, this
        corresponds to all points :math:`\mathbf{p}` where :math:`|\mathbf{p}_k-(\mathbf{p_{key}})_k|<r`,
        and some points (all points not guaranteed) where :math:`|\mathbf{p}_k-(\mathbf{p_{key}})_k|<2r`.

        Raises :exc:`KeyError` if key is not in hash.

        :param key: key used to identify input point.
        :return: a set of keys identifying neighbours of the input point.
        :rtype: set
        """
        _hash = self._key_to_hash[key]
        result = set()
        for nearby_hash in itertools.product(*[(i - 1, i, i + 1) for i in _hash]):
            # ``.get`` instead of indexing: indexing a defaultdict would insert
            # an empty set for every one of the 3**D probed bins.
            result.update(self._hash_to_key.get(nearby_hash, ()))
        result.discard(key)
        return result
class ApproxSet:
    """
    Approximate occupancy test for points in low-dimensional euclidean space.

    Space is divided into bins by the same hash as :class:`LSH`; each bin can
    be claimed by at most one point.

    :param float radius: scaling parameter used in hash function. Determines the size of the bin.
        See :class:`LSH` for details.
    """

    def __init__(self, radius):
        is_number = isinstance(radius, Number)
        if not is_number or not radius > 0:
            message = "radius must be float greater than 0, given: {}"
            raise ValueError(message.format(radius))
        self._bins = set()
        self._radius = radius

    def _hash(self, point):
        # Scale by the radius, round to the nearest integer per coordinate.
        scaled = point / self._radius
        return tuple(int(coord) for coord in scaled.round())

    def try_add(self, point):
        """
        Claims the bin containing ``point`` if it is still free.

        :param torch.Tensor point: Point to be queried, should be detached and on cpu.
        :return: ``True`` if the bin was free and is now occupied, ``False`` if
            the bin already held a point (the set is left unchanged).
        :rtype: bool
        """
        bin_id = self._hash(point)
        is_free = bin_id not in self._bins
        if is_free:
            self._bins.add(bin_id)
        return is_free
def merge_points(points, radius):
    """
    Greedily merge points that are closer than given radius.

    This uses :class:`LSH` to achieve complexity that is linear in the number
    of merged clusters and quadratic in the size of the largest merged cluster.

    The closest pair is merged first; the merged point is the midpoint of the
    two points it replaces and may itself be merged again later.

    :param torch.Tensor points: A tensor of shape ``(K,D)`` where ``K`` is
        the number of points and ``D`` is the number of dimensions.
    :param float radius: The minimum distance nearer than which
        points will be merged.
    :return: A tuple ``(merged_points, groups)`` where ``merged_points`` is a
        tensor of shape ``(J,D)`` where ``J <= K``, and ``groups`` is a list of
        tuples of indices mapping merged points to original points. Note that
        ``len(groups) == J`` and ``sum(len(group) for group in groups) == K``.
        If nothing is merged, ``points`` itself is returned as ``merged_points``.
    :rtype: tuple
    :raises ValueError: if ``points`` is not 2-dimensional or ``radius`` is
        not a positive number.
    """
    if points.dim() != 2:
        raise ValueError(
            "Expected points.shape == (K,D), but got {}".format(points.shape)
        )
    if not (isinstance(radius, Number) and radius > 0):
        raise ValueError(
            "Expected radius to be a positive number, but got {}".format(radius)
        )
    # avoid merging points exactly radius apart, e.g. grid points
    radius = 0.99 * radius
    threshold = radius ** 2

    # setup data structures to cheaply search for nearest pairs
    lsh = LSH(radius)
    priority_queue = []

    def push_close_pairs(key, lookup):
        # Queue (squared distance, neighbour, key) for every point currently in
        # the LSH that lies within the merge threshold of ``key``.
        center = lookup[key]
        for other in lsh.nearby(key):
            d2 = (lookup[other] - center).pow(2).sum().item()
            if d2 < threshold:
                heapq.heappush(priority_queue, (d2, other, key))

    groups = [(i,) for i in range(len(points))]
    for i, point in enumerate(points):
        lsh.add(i, point)
        # Only earlier points are in the LSH, so each pair is queued once.
        push_close_pairs(i, points)
    if not priority_queue:
        return points, groups

    # convert from dense to sparse representation
    next_id = len(points)
    clusters = dict(enumerate(points))
    members = dict(enumerate(groups))

    # greedily merge
    while priority_queue:
        _, first, second = heapq.heappop(priority_queue)
        if first not in clusters or second not in clusters:
            # stale entry: one endpoint has already been merged away
            continue
        merged = next_id
        next_id += 1
        clusters[merged] = (clusters.pop(first) + clusters.pop(second)) / 2
        members[merged] = members.pop(first) + members.pop(second)
        lsh.remove(first)
        lsh.remove(second)
        lsh.add(merged, clusters[merged])
        push_close_pairs(merged, clusters)

    # convert from sparse to dense representation
    ids = sorted(clusters)
    merged_points = torch.stack([clusters[i] for i in ids])
    merged_groups = [members[i] for i in ids]
    return merged_points, merged_groups