# Source code for pyro.distributions.unit

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class Unit(TorchDistribution):
    """
    Trivial nonnormalized distribution representing the unit type.

    The unit type has a single value with no data, i.e. ``value.numel() == 0``.
    This is used for :func:`pyro.factor` statements.

    :param log_factor: the (unnormalized) log density carried by this
        pseudo-distribution; its shape becomes the batch shape.
    :param bool has_rsample: optional override declaring whether samples are
        reparameterized; when ``None`` the class default is kept.
    """

    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, *, has_rsample=None, validate_args=None):
        self.log_factor = torch.as_tensor(log_factor)
        if has_rsample is not None:
            # Set an instance-level override only when explicitly requested,
            # so expand() can detect and propagate it via __dict__.
            self.has_rsample = has_rsample
        # An event shape of (0,) guarantees value.numel() == 0: events carry
        # no data, which is exactly the unit type.
        super().__init__(
            self.log_factor.shape,
            torch.Size((0,)),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new :class:`Unit` with ``log_factor`` expanded to ``batch_shape``."""
        target_shape = torch.Size(batch_shape)
        new = self._get_checked_instance(Unit, _instance)
        new.log_factor = self.log_factor.expand(target_shape)
        # Propagate an instance-level has_rsample override, if one was set.
        if "has_rsample" in self.__dict__:
            new.has_rsample = self.has_rsample
        super(Unit, new).__init__(target_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Return an empty tensor of shape ``sample_shape + batch_shape + (0,)``.

        The contents are irrelevant since the event dimension has size zero.
        """
        return self.log_factor.new_empty(sample_shape + self.shape())

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sampling; identical to :meth:`sample` for the unit type."""
        return self.log_factor.new_empty(sample_shape + self.shape())

    def log_prob(self, value):
        """Return ``log_factor``, broadcast against ``value``'s batch dimensions.

        ``value`` itself carries no information (its trailing event dimension
        has size zero), so only its batch shape matters.
        """
        batch = broadcast_shape(self.batch_shape, value.shape[:-1])
        return self.log_factor.expand(batch)