Source code for pyro.distributions.delta

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import numbers

import torch

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import sum_rightmost

from . import constraints


class Delta(TorchDistribution):
    """
    Degenerate discrete distribution (a single point).

    Discrete distribution that assigns probability one to the single element in
    its support. Delta distribution parameterized by a random choice should not
    be used with MCMC based inference, as doing so produces incorrect results.

    :param torch.Tensor v: The single support element.
    :param torch.Tensor log_density: An optional density for this Delta. This
        is useful to keep the class of :class:`Delta` distributions closed
        under differentiable transformation. A plain Python number is
        broadcast to a tensor of the batch shape.
    :param int event_dim: Optional event dimension, defaults to zero.
    :param validate_args: Forwarded to the base class constructor. The shape
        check on a tensor ``log_density`` only runs when this is truthy.
    :raises ValueError: If ``event_dim`` exceeds ``v.dim()``, or if validation
        is enabled and ``log_density.shape`` differs from the batch shape.
    """

    has_rsample = True
    arg_constraints = {"v": constraints.dependent, "log_density": constraints.real}

    def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
        if event_dim > v.dim():
            raise ValueError(
                "Expected event_dim <= v.dim(), actual {} vs {}".format(
                    event_dim, v.dim()
                )
            )
        # The leading dims of ``v`` are batch dims; the trailing ``event_dim``
        # dims form the event.
        batch_dim = v.dim() - event_dim
        batch_shape = v.shape[:batch_dim]
        event_shape = v.shape[batch_dim:]
        if isinstance(log_density, numbers.Number):
            # Promote a scalar density to a tensor matching ``v``'s dtype and
            # device so later arithmetic and ``expand`` work uniformly.
            log_density = torch.full(
                batch_shape, log_density, dtype=v.dtype, device=v.device
            )
        elif validate_args and log_density.shape != batch_shape:
            # Report the required shape first, then the shape that was given.
            raise ValueError(
                "Expected log_density.shape = {}, actual {}".format(
                    batch_shape, log_density.shape
                )
            )
        self.v = v
        self.log_density = log_density
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        """Real-valued constraint wrapped to cover ``event_dim`` event dims."""
        return constraints.independent(constraints.real, self.event_dim)

    def expand(self, batch_shape, _instance=None):
        """
        Return a :class:`Delta` whose parameters are expanded to ``batch_shape``.

        :param batch_shape: The desired batch shape (converted to
            ``torch.Size``).
        :param _instance: Optional instance to populate, passed through to
            ``_get_checked_instance``.
        :return: The new (or supplied) instance with ``v`` expanded to
            ``batch_shape + event_shape`` and ``log_density`` expanded to
            ``batch_shape``.
        """
        new = self._get_checked_instance(Delta, _instance)
        batch_shape = torch.Size(batch_shape)
        new.v = self.v.expand(batch_shape + self.event_shape)
        new.log_density = self.log_density.expand(batch_shape)
        # Initialise the base class directly (skipping Delta.__init__) with
        # validation off, then carry over this instance's validation flag.
        super(Delta, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """
        Return ``v`` expanded to ``sample_shape + v.shape``.

        Every sample is the single support point, so no randomness is drawn.
        """
        shape = sample_shape + self.v.shape
        return self.v.expand(shape)

    def log_prob(self, x):
        """
        Evaluate the log density at ``x``.

        :param torch.Tensor x: Value(s) to score against the support point.
        :return: ``log_density`` where ``x`` equals ``v`` across the event
            dims, and ``-inf`` where any element differs.
        """
        v = self.v.expand(self.shape())
        # Equality mask cast to x's dtype: log(1) = 0 on a match, log(0) = -inf
        # on a mismatch.
        log_prob = (x == v).type(x.dtype).log()
        # Reduce over the rightmost ``event_dim`` dims via ``sum_rightmost`` so
        # the result lines up with the batch shape of ``log_density``.
        log_prob = sum_rightmost(log_prob, self.event_dim)
        return log_prob + self.log_density

    @property
    def mean(self):
        """The support point ``v``."""
        return self.v

    @property
    def variance(self):
        """A tensor of zeros shaped like ``v``."""
        return torch.zeros_like(self.v)