Source code for pyro.distributions.improper_uniform

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints

from .torch_distribution import TorchDistribution
from .util import broadcast_shape

class ImproperUniform(TorchDistribution):
    """
    Improper distribution with zero :meth:`log_prob` and undefined
    :meth:`sample`.

    This is useful for transforming a model from generative dag form to factor
    graph form for use in HMC. For example the following are equal in
    distribution::

        # Version 1. a generative dag
        x = pyro.sample("x", Normal(0, 1))
        y = pyro.sample("y", Normal(x, 1))
        z = pyro.sample("z", Normal(y, 1))

        # Version 2. a factor graph
        xyz = pyro.sample("xyz", ImproperUniform(constraints.real, (), (3,)))
        x, y, z = xyz.unbind(-1)
        pyro.sample("x", Normal(0, 1), obs=x)
        pyro.sample("y", Normal(x, 1), obs=y)
        pyro.sample("z", Normal(y, 1), obs=z)

    Note this distribution errors when :meth:`sample` is called. To create a
    similar distribution that instead samples from a specified distribution
    consider using ``.mask(False)`` as in::

        xyz = dist.Normal(0, 1).expand([3]).to_event(1).mask(False)

    :param support: The support of the distribution.
    :type support: ~torch.distributions.constraints.Constraint
    :param torch.Size batch_shape: The batch shape.
    :param torch.Size event_shape: The event shape.
    """

    # The distribution has no tensor-valued parameters to validate.
    arg_constraints = {}

    def __init__(self, support, batch_shape, event_shape):
        """
        Store ``support`` and initialise the base distribution with the given
        batch and event shapes.

        Fails with ``AssertionError`` when ``support`` is not a
        :class:`~torch.distributions.constraints.Constraint`.
        """
        assert isinstance(support, constraints.Constraint)
        self._support = support
        super().__init__(batch_shape, event_shape)

    @constraints.dependent_property
    def support(self):
        """The constraint object that was passed to the constructor."""
        return self._support

    def expand(self, batch_shape, _instance=None):
        """
        Return an :class:`ImproperUniform` with the same support and event
        shape but with ``batch_shape`` as its batch shape.

        :param batch_shape: The desired batch shape (converted to
            :class:`torch.Size`).
        :param _instance: Optional pre-allocated instance, forwarded to
            ``_get_checked_instance``.
        :return: The instance obtained from ``_get_checked_instance``, with
            its support and shapes set.
        """
        batch_shape = torch.Size(batch_shape)
        new = self._get_checked_instance(ImproperUniform, _instance)
        new._support = self._support
        # Call the parent initialiser directly on ``new`` so that
        # ``ImproperUniform.__init__`` (and its support check) is not re-run.
        super(ImproperUniform, new).__init__(batch_shape, self.event_shape)
        return new

    def log_prob(self, value):
        """
        Return a tensor of zeros shaped like the broadcast of the batch
        dimensions of ``value`` with this distribution's ``batch_shape``.

        The result lives on the same device as ``value`` so that it can be
        combined with other terms computed from ``value``.

        :param torch.Tensor value: Tensor whose trailing ``event_dim``
            dimensions are treated as event dimensions.
        :return: A zero tensor, expanded (not materialised) to the broadcast
            batch shape.
        :rtype: torch.Tensor
        """
        # Slice with an explicit end index rather than ``[:-event_dim]`` so
        # that ``event_dim == 0`` keeps every dimension.
        batch_shape = value.shape[: value.dim() - self.event_dim]
        batch_shape = broadcast_shape(batch_shape, self.batch_shape)
        # ``torch.zeros(())`` without ``device`` allocates on the default
        # device, which need not be where ``value`` lives; pin it explicitly.
        zero = torch.zeros((), device=value.device)
        return zero.expand(batch_shape)

    def sample(self, sample_shape=torch.Size()):
        """
        Sampling is undefined for this distribution.

        :raises NotImplementedError: Always, regardless of ``sample_shape``.
        """
        raise NotImplementedError("ImproperUniform does not support sampling")