import warnings
import torch
from torch.distributions import constraints
from torch.distributions.kl import kl_divergence, register_kl
import pyro.distributions.torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.util import broadcast_shape, scale_and_mask
class TorchDistributionMixin(Distribution):
"""
    Mixin to provide Pyro compatibility for PyTorch distributions.

    You should instead use `TorchDistribution` for new distribution classes.

    This is mainly useful for wrapping existing PyTorch distributions for
use in Pyro. Derived classes must first inherit from
:class:`torch.distributions.distribution.Distribution` and then inherit
from :class:`TorchDistributionMixin`.
"""
    def __call__(self, sample_shape=torch.Size()):
"""
Samples a random value.
This is reparameterized whenever possible, calling
:meth:`~torch.distributions.distribution.Distribution.rsample` for
reparameterized distributions and
:meth:`~torch.distributions.distribution.Distribution.sample` for
non-reparameterized distributions.
:param sample_shape: the size of the iid batch to be drawn from the
distribution.
:type sample_shape: torch.Size
:return: A random value or batch of random values (if parameters are
batched). The shape of the result should be `self.shape()`.
:rtype: torch.Tensor
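
        Example (an illustrative doctest; ``dist`` is assumed to be
        ``pyro.distributions``, as in the :meth:`to_event` example)::

            >>> d = dist.Normal(torch.zeros(3), torch.ones(3))
            >>> d().shape  # calls d.rsample() since Normal.has_rsample is True
            torch.Size([3])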
"""
return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)
@property
def event_dim(self):
"""
:return: Number of dimensions of individual events.
:rtype: int
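
        Example (illustrative; ``dist`` is assumed to be ``pyro.distributions``)::

            >>> d = dist.Normal(torch.zeros(3, 4), torch.ones(3, 4)).to_event(1)
            >>> d.event_dim
            1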
"""
return len(self.event_shape)
    def shape(self, sample_shape=torch.Size()):
"""
The tensor shape of samples from this distribution.
        Samples are of shape::

            d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

:param sample_shape: the size of the iid batch to be drawn from the
distribution.
:type sample_shape: torch.Size
:return: Tensor shape of samples.
:rtype: torch.Size
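
        Example (illustrative; ``dist`` is assumed to be ``pyro.distributions``)::

            >>> d = dist.Normal(torch.zeros(2, 3), torch.ones(2, 3)).to_event(1)
            >>> d.shape(torch.Size([4]))
            torch.Size([4, 2, 3])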
"""
return sample_shape + self.batch_shape + self.event_shape
    def expand_by(self, sample_shape):
"""
Expands a distribution by adding ``sample_shape`` to the left side of
its :attr:`~torch.distributions.distribution.Distribution.batch_shape`.
To expand internal dims of ``self.batch_shape`` from 1 to something
larger, use :meth:`expand` instead.
:param torch.Size sample_shape: The size of the iid batch to be drawn
from the distribution.
        :return: An expanded version of this distribution.
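
        Example (illustrative; ``dist`` is assumed to be ``pyro.distributions``)::

            >>> d = dist.Normal(torch.zeros(3), torch.ones(3))
            >>> d.expand_by(torch.Size([2])).batch_shape
            torch.Size([2, 3])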
"""
return self.expand(torch.Size(sample_shape) + self.batch_shape)
    def reshape(self, sample_shape=None, extra_event_dims=None):
raise Exception('''
.reshape(sample_shape=s, extra_event_dims=n) was renamed and split into
.expand_by(sample_shape=s).to_event(reinterpreted_batch_ndims=n).''')
    def to_event(self, reinterpreted_batch_ndims=None):
"""
        Reinterprets the ``n`` rightmost dimensions of this distribution's
:attr:`~torch.distributions.distribution.Distribution.batch_shape`
as event dims, adding them to the left side of
:attr:`~torch.distributions.distribution.Distribution.event_shape`.
        Example:

        .. doctest::
           :hide:

           >>> d0 = dist.Normal(torch.zeros(2, 3, 4, 5), torch.ones(2, 3, 4, 5))
           >>> [d0.batch_shape, d0.event_shape]
           [torch.Size([2, 3, 4, 5]), torch.Size([])]
           >>> d1 = d0.to_event(2)
           >>> [d1.batch_shape, d1.event_shape]
           [torch.Size([2, 3]), torch.Size([4, 5])]
           >>> d2 = d1.to_event(1)
           >>> [d2.batch_shape, d2.event_shape]
           [torch.Size([2]), torch.Size([3, 4, 5])]
           >>> d3 = d1.to_event(2)
           >>> [d3.batch_shape, d3.event_shape]
           [torch.Size([]), torch.Size([2, 3, 4, 5])]

:param int reinterpreted_batch_ndims: The number of batch dimensions
to reinterpret as event dimensions.
:return: A reshaped version of this distribution.
:rtype: :class:`pyro.distributions.torch.Independent`
"""
if reinterpreted_batch_ndims is None:
reinterpreted_batch_ndims = len(self.batch_shape)
return pyro.distributions.torch.Independent(self, reinterpreted_batch_ndims)
    def independent(self, reinterpreted_batch_ndims=None):
warnings.warn("independent is deprecated; use to_event instead", DeprecationWarning)
return self.to_event(reinterpreted_batch_ndims=reinterpreted_batch_ndims)
    def mask(self, mask):
"""
        Masks a distribution by a boolean or boolean-valued tensor that is
        broadcastable to the distribution's
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`.
:param mask: A boolean or boolean valued tensor.
:type mask: bool or torch.Tensor
:return: A masked copy of this distribution.
:rtype: :class:`MaskedDistribution`
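
        Example (illustrative; masked-out sites contribute zero log density,
        and ``dist`` is assumed to be ``pyro.distributions``)::

            >>> d = dist.Normal(torch.zeros(3), torch.ones(3))
            >>> m = d.mask(torch.tensor([True, False, True]))
            >>> m.log_prob(torch.zeros(3))[1].item()
            0.0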
"""
return MaskedDistribution(self, mask)
class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin):
"""
Base class for PyTorch-compatible distributions with Pyro support.
This should be the base class for almost all new Pyro distributions.
.. note::
Parameters and data should be of type :class:`~torch.Tensor`
and all methods return type :class:`~torch.Tensor` unless
otherwise noted.
    **Tensor Shapes**:

    TorchDistributions provide a method ``.shape()`` for the tensor shape of samples::

        x = d.sample(sample_shape)
        assert x.shape == d.shape(sample_shape)

    Pyro follows the same distribution shape semantics as PyTorch. It distinguishes
    between three different roles for tensor shapes of samples:

    - *sample shape* corresponds to the shape of the iid samples drawn from the
      distribution. This is taken as an argument by the distribution's
      `sample` method.
    - *batch shape* corresponds to non-identical (independent) parameterizations of
      the distribution, inferred from the distribution's parameter shapes. This is
      fixed for a distribution instance.
    - *event shape* corresponds to the event dimensions of the distribution, which
      are fixed for a distribution class. These are collapsed when we try to score
      a sample from the distribution via `d.log_prob(x)`.

    These shapes are related by the equation::

        assert d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

    Distributions provide a vectorized
    :meth:`~torch.distributions.distribution.Distribution.log_prob` method that
    evaluates the log probability density of each event in a batch
    independently, returning a tensor of shape
    ``sample_shape + d.batch_shape``::

        x = d.sample(sample_shape)
        assert x.shape == d.shape(sample_shape)
        log_p = d.log_prob(x)
        assert log_p.shape == sample_shape + d.batch_shape

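    For example (an illustrative concrete case; ``dist`` is assumed to be
    ``pyro.distributions``)::

        d = dist.Normal(torch.zeros(3), torch.ones(3)).to_event(1)
        assert d.batch_shape == torch.Size([])
        assert d.event_shape == torch.Size([3])
        x = d.sample(torch.Size([2]))
        assert x.shape == torch.Size([2, 3])            # sample_shape + event_shape
        assert d.log_prob(x).shape == torch.Size([2])   # sample_shape + batch_shape
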
    **Implementing New Distributions**:

    Derived classes must implement the methods
    :meth:`~torch.distributions.distribution.Distribution.sample`
    (or :meth:`~torch.distributions.distribution.Distribution.rsample` if
    ``.has_rsample == True``) and
    :meth:`~torch.distributions.distribution.Distribution.log_prob`, and must
    implement the properties
    :attr:`~torch.distributions.distribution.Distribution.batch_shape`
    and :attr:`~torch.distributions.distribution.Distribution.event_shape`.
    Discrete classes may also implement the
    :meth:`~torch.distributions.distribution.Distribution.enumerate_support`
    method to improve gradient estimates and set
    ``.has_enumerate_support = True``.
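
    A minimal sketch of a new distribution (illustrative only; ``PointMass`` is
    a hypothetical name for a degenerate point-mass distribution, similar in
    spirit to Pyro's :class:`~pyro.distributions.Delta`)::

        class PointMass(TorchDistribution):
            # Hypothetical example class -- not part of Pyro.
            arg_constraints = {"v": constraints.real}
            support = constraints.real
            has_rsample = True

            def __init__(self, v):
                self.v = v
                super().__init__(batch_shape=v.shape)

            def rsample(self, sample_shape=torch.Size()):
                # Expanding preserves gradients w.r.t. self.v.
                return self.v.expand(torch.Size(sample_shape) + self.batch_shape)

            def log_prob(self, x):
                # log(1) = 0 where x matches v, -inf elsewhere.
                return torch.log((x == self.v).float())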
"""
pass
class MaskedDistribution(TorchDistribution):
"""
Masks a distribution by a boolean tensor that is broadcastable to the
distribution's :attr:`~torch.distributions.distribution.Distribution.batch_shape`.
    In the special case ``mask is False``, computation of :meth:`log_prob`,
    :meth:`score_parts`, and ``kl_divergence()`` is skipped, and constant zero
    values are returned instead.
:param mask: A boolean or boolean-valued tensor.
:type mask: torch.Tensor or bool
"""
arg_constraints = {}
def __init__(self, base_dist, mask):
self.base_dist = base_dist
if isinstance(mask, bool):
self._mask = mask
else:
if broadcast_shape(mask.shape, base_dist.batch_shape) != base_dist.batch_shape:
raise ValueError("Expected mask.shape to be broadcastable to base_dist.batch_shape, "
"actual {} vs {}".format(mask.shape, base_dist.batch_shape))
self._mask = mask.bool()
super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(MaskedDistribution, _instance)
batch_shape = torch.Size(batch_shape)
new.base_dist = self.base_dist.expand(batch_shape)
new._mask = self._mask
if isinstance(new._mask, torch.Tensor):
new._mask = new._mask.expand(batch_shape)
super(MaskedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
@property
def has_rsample(self):
return self.base_dist.has_rsample
@property
def has_enumerate_support(self):
return self.base_dist.has_enumerate_support
@constraints.dependent_property
def support(self):
return self.base_dist.support
def sample(self, sample_shape=torch.Size()):
return self.base_dist.sample(sample_shape)
def rsample(self, sample_shape=torch.Size()):
return self.base_dist.rsample(sample_shape)
def log_prob(self, value):
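        # When the mask is the literal ``False``, skip evaluating
        # base_dist.log_prob entirely and return zeros broadcast to the
        # batch shape implied by the value and the base distribution.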
if self._mask is False:
shape = broadcast_shape(self.base_dist.batch_shape,
value.shape[:value.dim() - self.event_dim])
return torch.zeros((), device=value.device).expand(shape)
if self._mask is True:
return self.base_dist.log_prob(value)
return scale_and_mask(self.base_dist.log_prob(value), mask=self._mask)
def score_parts(self, value):
if isinstance(self._mask, bool):
return super().score_parts(value) # calls self.log_prob(value)
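        # For a tensor mask, mask each component of the ScoreParts tuple.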
return self.base_dist.score_parts(value).scale_and_mask(mask=self._mask)
def enumerate_support(self, expand=True):
return self.base_dist.enumerate_support(expand=expand)
@property
def mean(self):
return self.base_dist.mean
@property
def variance(self):
return self.base_dist.variance
@register_kl(MaskedDistribution, MaskedDistribution)
def _kl_masked_masked(p, q):
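    # Combine the two masks; a KL term is counted only where both p and q
    # are unmasked.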
if p._mask is False or q._mask is False:
mask = False
elif p._mask is True:
mask = q._mask
elif q._mask is True:
mask = p._mask
elif p._mask is q._mask:
mask = p._mask
else:
mask = p._mask & q._mask
if mask is False:
return 0. # Return a float, since we cannot determine device.
if mask is True:
return kl_divergence(p.base_dist, q.base_dist)
kl = kl_divergence(p.base_dist, q.base_dist)
return scale_and_mask(kl, mask=mask)