# Source code for pyro.distributions.torch_distribution

import warnings

import torch
from torch.distributions import constraints
from torch.distributions.kl import kl_divergence, register_kl

import pyro.distributions.torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.util import broadcast_shape, scale_and_mask


class TorchDistributionMixin(Distribution):
    """
    Mixin to provide Pyro compatibility for PyTorch distributions.

    You should instead use `TorchDistribution` for new distribution classes.

    This is mainly useful for wrapping existing PyTorch distributions for
    use in Pyro.  Derived classes must first inherit from
    :class:`torch.distributions.distribution.Distribution` and then inherit
    from :class:`TorchDistributionMixin`.
    """

    def __call__(self, sample_shape=torch.Size()):
        """
        Samples a random value.

        This is reparameterized whenever possible, calling
        :meth:`~torch.distributions.distribution.Distribution.rsample` for
        reparameterized distributions and
        :meth:`~torch.distributions.distribution.Distribution.sample` for
        non-reparameterized distributions.

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution.
        :type sample_shape: torch.Size
        :return: A random value or batch of random values (if parameters are
            batched). The shape of the result should be
            ``self.shape(sample_shape)``.
        :rtype: torch.Tensor
        """
        return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)

    @property
    def event_dim(self):
        """
        :return: Number of dimensions of individual events.
        :rtype: int
        """
        return len(self.event_shape)

    def shape(self, sample_shape=torch.Size()):
        """
        The tensor shape of samples from this distribution.

        Samples are of shape::

          d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution. Any iterable of ints is accepted.
        :type sample_shape: torch.Size
        :return: Tensor shape of samples.
        :rtype: torch.Size
        """
        # Normalize first (as expand_by does) so that a plain tuple or list
        # still yields a torch.Size instead of a tuple or a TypeError.
        return torch.Size(sample_shape) + self.batch_shape + self.event_shape

    def expand_by(self, sample_shape):
        """
        Expands a distribution by adding ``sample_shape`` to the left side of
        its :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

        To expand internal dims of ``self.batch_shape`` from 1 to something
        larger, use :meth:`expand` instead.

        :param torch.Size sample_shape: The size of the iid batch to be drawn
            from the distribution.
        :return: An expanded version of this distribution, i.e. the result of
            ``self.expand(sample_shape + self.batch_shape)``.
        """
        return self.expand(torch.Size(sample_shape) + self.batch_shape)

    def reshape(self, sample_shape=None, extra_event_dims=None):
        """
        Removed API; always raises.

        Use ``.expand_by(sample_shape=s).to_event(reinterpreted_batch_ndims=n)``
        instead.

        :raises Exception: unconditionally, with a message naming the
            replacement methods.
        """
        raise Exception(
            " .reshape(sample_shape=s, extra_event_dims=n) was renamed and split into "
            ".expand_by(sample_shape=s).to_event(reinterpreted_batch_ndims=n).")

    def to_event(self, reinterpreted_batch_ndims=None):
        """
        Reinterprets the ``n`` rightmost dimensions of this distribution's
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`
        as event dims, adding them to the left side of
        :attr:`~torch.distributions.distribution.Distribution.event_shape`.

        Example:

            .. doctest::
               :hide:

               >>> d0 = dist.Normal(torch.zeros(2, 3, 4, 5), torch.ones(2, 3, 4, 5))
               >>> [d0.batch_shape, d0.event_shape]
               [torch.Size([2, 3, 4, 5]), torch.Size([])]
               >>> d1 = d0.to_event(2)

            >>> [d1.batch_shape, d1.event_shape]
            [torch.Size([2, 3]), torch.Size([4, 5])]
            >>> d2 = d1.to_event(1)
            >>> [d2.batch_shape, d2.event_shape]
            [torch.Size([2]), torch.Size([3, 4, 5])]
            >>> d3 = d1.to_event(2)
            >>> [d3.batch_shape, d3.event_shape]
            [torch.Size([]), torch.Size([2, 3, 4, 5])]

        :param int reinterpreted_batch_ndims: The number of batch dimensions
            to reinterpret as event dimensions. Defaults to all of them
            (``len(self.batch_shape)``) when ``None``.
        :return: A reshaped version of this distribution.
        :rtype: :class:`pyro.distributions.torch.Independent`
        """
        if reinterpreted_batch_ndims is None:
            reinterpreted_batch_ndims = len(self.batch_shape)
        return pyro.distributions.torch.Independent(self, reinterpreted_batch_ndims)

    def independent(self, reinterpreted_batch_ndims=None):
        """
        Deprecated alias of :meth:`to_event`.

        Emits a :class:`DeprecationWarning` and forwards to :meth:`to_event`.
        """
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this wrapper.
        warnings.warn("independent is deprecated; use to_event instead",
                      DeprecationWarning, stacklevel=2)
        return self.to_event(reinterpreted_batch_ndims=reinterpreted_batch_ndims)

    def mask(self, mask):
        """
        Masks a distribution by a zero-one tensor that is broadcastable to the
        distribution's
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

        :param torch.Tensor mask: A zero-one valued tensor.
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        """
        return MaskedDistribution(self, mask)
class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin):
    """
    Base class for PyTorch-compatible distributions with Pyro support.

    Almost every new Pyro distribution should derive from this class.

    .. note:: Parameters and data should be of type :class:`~torch.Tensor`
        and all methods return type :class:`~torch.Tensor` unless
        otherwise noted.

    **Tensor Shapes**:

    TorchDistributions provide a method ``.shape()`` for the tensor shape of
    samples::

      x = d.sample(sample_shape)
      assert x.shape == d.shape(sample_shape)

    Pyro follows the same distribution shape semantics as PyTorch. It
    distinguishes between three different roles for tensor shapes of samples:

    - *sample shape* corresponds to the shape of the iid samples drawn from
      the distribution. This is taken as an argument by the distribution's
      `sample` method.
    - *batch shape* corresponds to non-identical (independent)
      parameterizations of the distribution, inferred from the distribution's
      parameter shapes. This is fixed for a distribution instance.
    - *event shape* corresponds to the event dimensions of the distribution,
      which is fixed for a distribution class. These are collapsed when we
      try to score a sample from the distribution via `d.log_prob(x)`.

    These shapes are related by the equation::

      assert d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

    Distributions provide a vectorized
    :meth:`~torch.distributions.distribution.Distribution.log_prob` method
    that evaluates the log probability density of each event in a batch
    independently, returning a tensor of shape
    ``sample_shape + d.batch_shape``::

      x = d.sample(sample_shape)
      assert x.shape == d.shape(sample_shape)
      log_p = d.log_prob(x)
      assert log_p.shape == sample_shape + d.batch_shape

    **Implementing New Distributions**:

    Derived classes must implement the methods
    :meth:`~torch.distributions.distribution.Distribution.sample`
    (or :meth:`~torch.distributions.distribution.Distribution.rsample` if
    ``.has_rsample == True``) and
    :meth:`~torch.distributions.distribution.Distribution.log_prob`, and must
    implement the properties
    :attr:`~torch.distributions.distribution.Distribution.batch_shape` and
    :attr:`~torch.distributions.distribution.Distribution.event_shape`.
    Discrete classes may also implement the
    :meth:`~torch.distributions.distribution.Distribution.enumerate_support`
    method to improve gradient estimates and set
    ``.has_enumerate_support = True``.
    """
class MaskedDistribution(TorchDistribution):
    """
    Pairs a base distribution with a zero-one mask that is broadcastable to
    the base distribution's
    :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

    Sampling, support, enumeration, mean and variance are delegated to the
    wrapped distribution; ``log_prob`` and ``score_parts`` additionally pass
    their results through ``scale_and_mask`` with the stored mask.

    :param base_dist: The distribution to wrap.
    :param torch.Tensor mask: A zero-one valued tensor. It is converted with
        ``mask.bool()`` before being stored.
    :raises ValueError: if broadcasting ``mask.shape`` against
        ``base_dist.batch_shape`` does not give ``base_dist.batch_shape``.
    """
    arg_constraints = {}

    def __init__(self, base_dist, mask):
        wrapped_batch_shape = base_dist.batch_shape
        if broadcast_shape(mask.shape, wrapped_batch_shape) != wrapped_batch_shape:
            raise ValueError("Expected mask.shape to be broadcastable to base_dist.batch_shape, "
                             "actual {} vs {}".format(mask.shape, wrapped_batch_shape))
        self.base_dist = base_dist
        self._mask = mask.bool()
        super(MaskedDistribution, self).__init__(wrapped_batch_shape, base_dist.event_shape)

    def expand(self, batch_shape, _instance=None):
        """
        Return a copy whose base distribution and mask are both expanded to
        ``batch_shape``.
        """
        expanded = self._get_checked_instance(MaskedDistribution, _instance)
        target_shape = torch.Size(batch_shape)
        expanded.base_dist = self.base_dist.expand(target_shape)
        expanded._mask = self._mask.expand(target_shape)
        # Initialise with validation off, then carry over this instance's
        # validation flag.
        super(MaskedDistribution, expanded).__init__(
            target_shape, self.event_shape, validate_args=False)
        expanded._validate_args = self._validate_args
        return expanded

    @property
    def has_rsample(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.support

    def sample(self, sample_shape=torch.Size()):
        """Draw from the base distribution; the mask is not applied here."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized draw from the base distribution; the mask is not applied here."""
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        """Base log-probability passed through ``scale_and_mask`` with the stored mask."""
        base_log_prob = self.base_dist.log_prob(value)
        return scale_and_mask(base_log_prob, mask=self._mask)

    def score_parts(self, value):
        """Base score parts with ``scale_and_mask(mask=...)`` applied using the stored mask."""
        base_parts = self.base_dist.score_parts(value)
        return base_parts.scale_and_mask(mask=self._mask)

    def enumerate_support(self, expand=True):
        """Forwarded to the base distribution."""
        return self.base_dist.enumerate_support(expand=expand)

    @property
    def mean(self):
        """Forwarded from the base distribution."""
        return self.base_dist.mean

    @property
    def variance(self):
        """Forwarded from the base distribution."""
        return self.base_dist.variance


@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
    """
    KL divergence between two masked distributions: the KL of their base
    distributions passed through ``scale_and_mask`` with the combined mask.
    """
    # When both sides share the very same mask tensor, reuse it rather than
    # and-ing it with itself.
    if p._mask is q._mask:
        joint_mask = p._mask
    else:
        joint_mask = p._mask & q._mask
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    return scale_and_mask(base_kl, mask=joint_mask)