# Source code for pyro.distributions.one_two_matching

# Copyright Contributors to the Pyro project.

import logging
import math
import warnings

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from .torch import Categorical
from .torch_distribution import TorchDistribution

# Module-level logger, named after this module so records can be filtered by
# the standard logging hierarchy.
logger = logging.getLogger(name=__name__)

class OneTwoMatchingConstraint(constraints.Constraint):
    """
    Constraint satisfied by one-two matchings from ``2 * num_destins`` sources
    to ``num_destins`` destinations.

    A valid value is a tensor whose rightmost dimension has size
    ``2 * num_destins``, whose entries lie in ``{0, ..., num_destins - 1}``,
    and in which every destination index appears exactly twice.

    :param int num_destins: The number of destinations N.
    """

    # Values are integer destination indices, and validity is a property of
    # the whole rightmost dimension rather than of individual entries.
    is_discrete = True
    event_dim = 1

    def __init__(self, num_destins):
        self.num_destins = num_destins
        self.num_sources = 2 * num_destins

    def check(self, value):
        """
        Check whether ``value`` is a valid one-two matching.

        A warning describing the failure is emitted for invalid input.

        :param torch.Tensor value: A tensor of destination indices with shape
            ``batch_shape + (2 * num_destins,)``.
        :returns: ``torch.tensor(False)`` if ``value`` has the wrong shape or
            out-of-range entries; otherwise a boolean tensor of shape
            ``batch_shape`` that is ``True`` exactly where every destination
            is matched by exactly two sources.
        :rtype: torch.Tensor
        """
        if value.dim() == 0:
            warnings.warn("Invalid event_shape: ()")
            return torch.tensor(False)
        batch_shape, event_shape = value.shape[:-1], value.shape[-1:]
        if event_shape != (self.num_sources,):
            warnings.warn("Invalid event_shape: {}".format(event_shape))
            return torch.tensor(False)
        # min()/max() raise on empty tensors, so only bounds-check when there
        # is at least one element (num_destins == 0 admits only empty values).
        if value.numel() > 0 and (
            value.min() < 0 or value.max() >= self.num_destins
        ):
            warnings.warn("Value out of bounds")
            return torch.tensor(False)
        # Count how many sources are matched to each destination. Allocate on
        # value's device/dtype; scatter_add_ needs an int64 index and a src
        # tensor with the same dtype as the destination.
        counts = value.new_zeros(batch_shape + (self.num_destins,))
        counts.scatter_add_(-1, value.long(), torch.ones_like(value))
        valid = (counts == 2).all(-1)
        if not valid.all():
            warnings.warn(
                "Invalid matching: each destination must match exactly two sources"
            )
        return valid
class OneTwoMatching(TorchDistribution):
    r"""
    Random matching from ``2*N`` sources to ``N`` destinations where each
    source matches exactly **one** destination and each destination matches
    exactly **two** sources.

    Samples are represented as long tensors of shape ``(2*N,)`` taking values
    in ``{0,...,N-1}`` and satisfying the above one-two constraint. The log
    probability of a sample ``v`` is the sum of edge logits, up to the log
    partition function ``log Z``:

    .. math::

        \log p(v) = \sum_s \text{logits}[s, v[s]] - \log Z

    Exact computations are expensive. To enable tractable approximations, set
    a number of belief propagation iterations via the ``bp_iters`` argument.
    With ``bp_iters`` set, the :meth:`log_partition_function` and
    :meth:`log_prob` methods use a Bethe approximation [1,2,3,4].

    **References:**

    [1] Michael Chertkov, Lukas Kroc, Massimo Vergassola (2008)
        "Belief propagation and beyond for particle tracking"
        https://arxiv.org/pdf/0806.1199.pdf

    [2] Bert Huang, Tony Jebara (2009)
        "Approximating the Permanent with Belief Propagation"
        https://arxiv.org/pdf/0908.1769.pdf

    [3] Pascal O. Vontobel (2012)
        "The Bethe Permanent of a Non-Negative Matrix"
        https://arxiv.org/pdf/1107.4196.pdf

    [4] M Chertkov, AB Yedidia (2013)
        "Approximating the permanent with fractional belief propagation"
        http://www.jmlr.org/papers/volume14/chertkov13a/chertkov13a.pdf

    :param Tensor logits: A ``(2 * N, N)``-shaped tensor of edge logits.
    :param int bp_iters: Optional number of belief propagation iterations. If
        unspecified or ``None``, expensive exact algorithms will be used.
    :raises NotImplementedError: If ``logits`` is not 2-dimensional.
    :raises ValueError: If ``logits`` does not have shape ``(2 * N, N)`` or
        ``bp_iters`` is neither ``None`` nor a positive int.
    """

    arg_constraints = {"logits": constraints.real}
    has_enumerate_support = True

    def __init__(self, logits, *, bp_iters=None, validate_args=None):
        if logits.dim() != 2:
            raise NotImplementedError("OneTwoMatching does not support batching")
        # Explicit checks instead of ``assert`` so validation survives
        # ``python -O``; ``bool`` is an ``int`` subclass but not a valid count.
        if bp_iters is not None and (
            isinstance(bp_iters, bool)
            or not isinstance(bp_iters, int)
            or bp_iters <= 0
        ):
            raise ValueError(
                "bp_iters must be None or a positive int, but got {!r}".format(
                    bp_iters
                )
            )
        self.num_sources, self.num_destins = logits.shape
        if self.num_sources != 2 * self.num_destins:
            raise ValueError(
                "Expected logits of shape (2 * N, N), but got shape {}".format(
                    tuple(logits.shape)
                )
            )
        self.logits = logits
        # Batching is not supported, so a single sample is one full matching.
        batch_shape = ()
        event_shape = (self.num_sources,)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self.bp_iters = bp_iters

    @constraints.dependent_property
    def support(self):
        # Depends on the instance's number of destinations, hence dependent.
        return OneTwoMatchingConstraint(self.num_destins)

    def mode(self):
        """
        Computes a maximum probability matching.

        .. note:: This requires the `lap <https://pypi.org/project/lap/>`_
            package and runs on CPU.
        """
        return maximum_weight_matching(self.logits)