Source code for pyro.ops.indexing

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch


def _is_batched(arg):
    # A tensor counts as batched iff it has at least one dim;
    # 0-dim tensors and non-tensors are treated as scalars.
    return isinstance(arg, torch.Tensor) and arg.dim()


def _flatten(args, out):
    # Recursively flatten nested tuples of index args into the list ``out``.
    if isinstance(args, tuple):
        for arg in args:
            _flatten(arg, out)
    else:
        # Combine consecutive Ellipsis.
        if args is Ellipsis and out and out[-1] is Ellipsis:
            return
        out.append(args)


def index(tensor, args):
    """
    Indexing with nested tuples.

    See also the convenience wrapper :class:`Index`.

    This is useful for writing indexing code that is compatible with multiple
    interpretations, e.g. scalar evaluation, vectorized evaluation, or
    reshaping.

    For example suppose ``x`` is a parameter with ``x.dim() == 2`` and we wish
    to generalize the expression ``x[..., t]`` where ``t`` can be any of:

    - a scalar ``t=1`` as in ``x[..., 1]``;
    - a slice ``t=slice(None)`` equivalent to ``x[..., :]``; or
    - a reshaping operation ``t=(Ellipsis, None)`` equivalent to
      ``x.unsqueeze(-1)``.

    While naive indexing would work for the first two, the third example would
    result in a nested tuple ``(Ellipsis, (Ellipsis, None))``. This helper
    flattens that nested tuple and combines consecutive ``Ellipsis``.

    :param torch.Tensor tensor: A tensor to be indexed.
    :param tuple args: An index, as args to ``__getitem__``.
    :returns: A flattened interpretation of ``tensor[args]``.
    :rtype: torch.Tensor
    """
    if not isinstance(args, tuple):
        return tensor[args]
    if not args:
        return tensor

    # Flatten.
    flat = []
    _flatten(args, flat)
    args = tuple(flat)

    return tensor[args]
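
A minimal usage sketch, with illustrative shapes that are assumptions rather
than anything from the source, exercising the three interpretations of ``t``
described in the docstring::

    import torch

    x = torch.randn(3, 4)

    # Scalar index, same as x[..., 1].
    assert index(x, (Ellipsis, 1)).shape == (3,)

    # Trivial slice, same as x[..., :].
    assert index(x, (Ellipsis, slice(None))).shape == (3, 4)

    # Reshaping op: t=(Ellipsis, None) yields the nested tuple
    # (Ellipsis, (Ellipsis, None)), which index() flattens to x[..., None].
    assert index(x, (Ellipsis, (Ellipsis, None))).shape == (3, 4, 1)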
class Index:
    """
    Convenience wrapper around :func:`index`.

    The following are equivalent::

        Index(x)[..., i, j, :]
        index(x, (Ellipsis, i, j, slice(None)))

    :param torch.Tensor tensor: A tensor to be indexed.
    :return: An object with a special :meth:`__getitem__` method.
    """

    def __init__(self, tensor):
        self._tensor = tensor

    def __getitem__(self, args):
        return index(self._tensor, args)
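
A quick check of the equivalence claimed in the docstring, using assumed
placeholder values for ``x``, ``i``, ``j``::

    import torch

    x = torch.randn(2, 3, 4)
    i, j = 0, 2
    assert (Index(x)[..., i, j, :]
            == index(x, (Ellipsis, i, j, slice(None)))).all()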
def vindex(tensor, args):
    """
    Vectorized advanced indexing with broadcasting semantics.

    See also the convenience wrapper :class:`Vindex`.

    This is useful for writing indexing code that is compatible with batching
    and enumeration, especially for selecting mixture components with discrete
    random variables.

    For example suppose ``x`` is a parameter with ``x.dim() == 3`` and we wish
    to generalize the expression ``x[i, :, j]`` from integer ``i,j`` to tensors
    ``i,j`` with batch dims and enum dims (but no event dims). Then we can
    write the generalized version using :class:`Vindex` ::

        xij = Vindex(x)[i, :, j]

        batch_shape = broadcast_shape(i.shape, j.shape)
        event_shape = (x.size(1),)
        assert xij.shape == batch_shape + event_shape

    To handle the case when ``x`` may also contain batch dimensions (e.g. if
    ``x`` was sampled in a plated context as when using vectorized particles),
    :func:`vindex` uses the special convention that ``Ellipsis`` denotes batch
    dimensions (hence ``...`` can appear only on the left, never in the middle
    or on the right). Suppose ``x`` has event dim 3. Then we can write::

        old_batch_shape = x.shape[:-3]
        old_event_shape = x.shape[-3:]

        xij = Vindex(x)[..., i, :, j]   # The ... denotes unknown batch shape.

        new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
        new_event_shape = (x.size(-2),)
        assert xij.shape == new_batch_shape + new_event_shape

    Note that this special handling of ``Ellipsis`` differs from the NEP [1].

    Formally, this function assumes:

    1. Each arg is either ``Ellipsis``, ``slice(None)``, an integer, or a
       batched ``torch.LongTensor`` (i.e. with empty event shape). This
       function does not support nontrivial slices or ``torch.BoolTensor``
       masks. ``Ellipsis`` can only appear on the left as ``args[0]``.
    2. If ``args[0] is not Ellipsis`` then ``tensor`` is not batched, and its
       event dim is equal to ``len(args)``.
    3. If ``args[0] is Ellipsis`` then ``tensor`` is batched and its event dim
       is equal to ``len(args[1:])``. Dims of ``tensor`` to the left of the
       event dims are considered batch dims and will be broadcast with dims of
       tensor args.

    Note that if none of the args is a tensor with ``.dim() > 0``, then this
    function behaves like standard indexing::

        if not any(isinstance(a, torch.Tensor) and a.dim() for a in args):
            assert Vindex(x)[args] == x[args]

    **References**

    [1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html
        introduces ``vindex`` as a helper for vectorized indexing. The Pyro
        implementation is similar to the proposed notation ``x.vindex[]``
        except for slightly different handling of ``Ellipsis``.

    :param torch.Tensor tensor: A tensor to be indexed.
    :param tuple args: An index, as args to ``__getitem__``.
    :returns: A nonstandard interpretation of ``tensor[args]``.
    :rtype: torch.Tensor
    """
    if not isinstance(args, tuple):
        return tensor[args]
    if not args:
        return tensor

    # Compute event dim before and after indexing.
    if args[0] is Ellipsis:
        args = args[1:]
        if not args:
            return tensor
        old_event_dim = len(args)
        args = (slice(None),) * (tensor.dim() - len(args)) + args
    else:
        args = args + (slice(None),) * (tensor.dim() - len(args))
        old_event_dim = len(args)
    assert len(args) == tensor.dim()
    if any(a is Ellipsis for a in args):
        raise NotImplementedError("Non-leading Ellipsis is not supported")

    # In simple cases, standard advanced indexing broadcasts correctly.
    is_standard = True
    if tensor.dim() > old_event_dim and _is_batched(args[0]):
        is_standard = False
    elif any(_is_batched(a) for a in args[1:]):
        is_standard = False
    if is_standard:
        return tensor[args]

    # Convert args to use broadcasting semantics.
    new_event_dim = sum(isinstance(a, slice) for a in args[-old_event_dim:])
    new_dim = 0
    args = list(args)
    for i, arg in reversed(list(enumerate(args))):
        if isinstance(arg, slice):
            # Convert slices to torch.arange()s.
            if arg != slice(None):
                raise NotImplementedError("Nontrivial slices are not supported")
            arg = torch.arange(tensor.size(i), dtype=torch.long, device=tensor.device)
            arg = arg.reshape((-1,) + (1,) * new_dim)
            new_dim += 1
        elif _is_batched(arg):
            # Reshape nontrivial tensors.
            arg = arg.reshape(arg.shape + (1,) * new_event_dim)
        args[i] = arg
    args = tuple(args)

    return tensor[args]
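
A concrete instance of the first docstring example, with shapes chosen purely
for illustration: ``i`` and ``j`` broadcast to a common batch shape, while the
middle slice is preserved as an event dim::

    import torch

    x = torch.randn(5, 4, 6)        # event dims only (assumption 2 above)
    i = torch.tensor([0, 2, 4])     # shape (3,), indexes the first dim
    j = torch.randint(6, (2, 1))    # shape (2, 1), indexes the last dim

    xij = vindex(x, (i, slice(None), j))
    # batch_shape = broadcast of i.shape and j.shape = (2, 3)
    # event_shape = (x.size(1),) = (4,)
    assert xij.shape == (2, 3, 4)
    assert (xij[1, 2] == x[i[2], :, j[1, 0]]).all()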
class Vindex:
    """
    Convenience wrapper around :func:`vindex`.

    The following are equivalent::

        Vindex(x)[..., i, j, :]
        vindex(x, (Ellipsis, i, j, slice(None)))

    :param torch.Tensor tensor: A tensor to be indexed.
    :return: An object with a special :meth:`__getitem__` method.
    """

    def __init__(self, tensor):
        self._tensor = tensor

    def __getitem__(self, args):
        return vindex(self._tensor, args)
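
And a sketch of the batched ``Ellipsis`` convention, again with assumed
shapes: here ``x`` carries one batch dim of size 7 on the left, which
broadcasts against the index tensors::

    import torch

    x = torch.randn(7, 5, 4, 6)     # batch shape (7,), event dim 3
    i = torch.randint(5, (2, 1))    # shape (2, 1)
    j = torch.randint(6, (7,))      # shape (7,), aligned with the batch dim

    xij = Vindex(x)[..., i, :, j]
    # new_batch_shape = broadcast of (7,), (2, 1), (7,) = (2, 7)
    # new_event_shape = (x.size(-2),) = (4,)
    assert xij.shape == (2, 7, 4)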