# Source code for pyro.distributions.torch_distribution

```
from __future__ import absolute_import, division, print_function
import numbers
import torch
from torch.distributions import constraints
from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import broadcast_shape, sum_rightmost
class TorchDistributionMixin(Distribution):
    """
    Mixin to provide Pyro compatibility for PyTorch distributions.

    You should instead use `TorchDistribution` for new distribution classes.

    This is mainly useful for wrapping existing PyTorch distributions for
    use in Pyro. Derived classes must first inherit from
    :class:`torch.distributions.distribution.Distribution` and then inherit
    from :class:`TorchDistributionMixin`.
    """

    def __call__(self, sample_shape=torch.Size()):
        """
        Samples a random value.

        This is reparameterized whenever possible, calling
        :meth:`~torch.distributions.distribution.Distribution.rsample` for
        reparameterized distributions and
        :meth:`~torch.distributions.distribution.Distribution.sample` for
        non-reparameterized distributions.

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution.
        :type sample_shape: torch.Size
        :return: A random value or batch of random values (if parameters are
            batched). The shape of the result should be `self.shape()`.
        :rtype: torch.Tensor
        """
        return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)

    @property
    def event_dim(self):
        """
        :return: Number of dimensions of individual events.
        :rtype: int
        """
        return len(self.event_shape)

    def shape(self, sample_shape=torch.Size()):
        """
        The tensor shape of samples from this distribution.

        Samples are of shape::

            d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution.
        :type sample_shape: torch.Size
        :return: Tensor shape of samples.
        :rtype: torch.Size
        """
        return sample_shape + self.batch_shape + self.event_shape

    def expand(self, batch_shape):
        """
        Expands a distribution to a desired
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

        Note that this is more general than :meth:`expand_by` because
        ``d.expand_by(sample_shape)`` can be reduced to
        ``d.expand(sample_shape + d.batch_shape)``.

        :param torch.Size batch_shape: The target ``batch_shape``. This must
            be compatible with ``self.batch_shape`` similar to the
            requirements of :func:`torch.Tensor.expand`: the target
            ``batch_shape`` must be at least as long as ``self.batch_shape``,
            and for each non-singleton dim of ``self.batch_shape``,
            ``batch_shape`` must either agree or be set to ``-1``. New
            leading dims (those not present in ``self.batch_shape``) must be
            given explicitly; ``-1`` is rejected there because there is no
            existing size to copy.
        :return: An expanded version of this distribution.
        :rtype: :class:`ReshapedDistribution`
        :raises ValueError: If ``batch_shape`` is shorter than
            ``self.batch_shape``, disagrees with a non-singleton dim, or has
            a negative size in a new leading dim.
        :raises NotImplementedError: If a singleton dim of
            ``self.batch_shape`` would have to be expanded to a larger size.
        """
        batch_shape = list(batch_shape)
        if len(batch_shape) < len(self.batch_shape):
            raise ValueError("Expected len(batch_shape) >= len(self.batch_shape), "
                             "actual {} vs {}".format(len(batch_shape), len(self.batch_shape)))
        # check sizes of existing dims, walking from the rightmost dim leftwards
        for dim in range(-1, -1 - len(self.batch_shape), -1):
            if batch_shape[dim] == -1:
                # -1 means "keep the current size of this dim"
                batch_shape[dim] = self.batch_shape[dim]
            elif batch_shape[dim] != self.batch_shape[dim]:
                if self.batch_shape[dim] != 1:
                    raise ValueError("Cannot broadcast dim {} of size {} to size {}".format(
                        dim, self.batch_shape[dim], batch_shape[dim]))
                else:
                    raise NotImplementedError("https://github.com/uber/pyro/issues/1119")
        # whatever remains on the left is brand new and is treated as a sample shape
        sample_shape = batch_shape[:len(batch_shape) - len(self.batch_shape)]
        # A -1 here has no existing dim to take its size from; passing it on
        # would silently build a distribution with an invalid shape.
        if any(size < 0 for size in sample_shape):
            raise ValueError("Expected non-negative sizes for new leading dims of batch_shape, "
                             "actual {}".format(sample_shape))
        return self.expand_by(sample_shape)

    def expand_by(self, sample_shape):
        """
        Expands a distribution by adding ``sample_shape`` to the left side of
        its :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

        To expand internal dims of ``self.batch_shape`` from 1 to something
        larger, use :meth:`expand` instead.

        :param torch.Size sample_shape: The size of the iid batch to be drawn
            from the distribution.
        :return: An expanded version of this distribution.
        :rtype: :class:`ReshapedDistribution`
        """
        return ReshapedDistribution(self, sample_shape=sample_shape)

    def reshape(self, sample_shape=None, extra_event_dims=None):
        """
        Removed API kept only to point callers at its replacement.

        :raises Exception: Always; the message names the replacement call
            chain ``.expand_by(...).independent(...)``.
        """
        raise Exception(
            '\n'
            '.reshape(sample_shape=s, extra_event_dims=n) was renamed and split into\n'
            '.expand_by(sample_shape=s).independent(reinterpreted_batch_ndims=n).')

    def independent(self, reinterpreted_batch_ndims=None):
        """
        Reinterprets the ``n`` rightmost dimensions of this distribution's
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`
        as event dims, adding them to the left side of
        :attr:`~torch.distributions.distribution.Distribution.event_shape`.

        Example:

        .. doctest::
           :hide:

           >>> import torch
           >>> import pyro.distributions as dist

        .. doctest::

           >>> d0 = dist.Normal(torch.zeros(2, 3, 4, 5), torch.ones(2, 3, 4, 5))
           >>> [d0.batch_shape, d0.event_shape]
           [torch.Size([2, 3, 4, 5]), torch.Size([])]
           >>> d1 = d0.independent(2)
           >>> [d1.batch_shape, d1.event_shape]
           [torch.Size([2, 3]), torch.Size([4, 5])]
           >>> d2 = d1.independent(1)
           >>> [d2.batch_shape, d2.event_shape]
           [torch.Size([2]), torch.Size([3, 4, 5])]
           >>> d3 = d1.independent(2)
           >>> [d3.batch_shape, d3.event_shape]
           [torch.Size([]), torch.Size([2, 3, 4, 5])]

        :param int reinterpreted_batch_ndims: The number of batch dimensions
            to reinterpret as event dimensions. Defaults to all of them
            (``len(self.batch_shape)``) when ``None``.
        :return: A reshaped version of this distribution.
        :rtype: :class:`ReshapedDistribution`
        """
        if reinterpreted_batch_ndims is None:
            reinterpreted_batch_ndims = len(self.batch_shape)
        # TODO return pyro.distributions.torch.Independent(self, reinterpreted_batch_ndims)
        return ReshapedDistribution(self, reinterpreted_batch_ndims=reinterpreted_batch_ndims)

    def mask(self, mask):
        """
        Masks a distribution by a zero-one tensor that is broadcastable to the
        distribution's
        :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

        :param torch.Tensor mask: A zero-one valued float tensor.
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        """
        return MaskedDistribution(self, mask)
class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin):
    """
    Base class for PyTorch-compatible distributions with Pyro support.

    Almost every new Pyro distribution should derive from this class.

    .. note::

        Parameters and data should be of type :class:`~torch.Tensor`
        and all methods return type :class:`~torch.Tensor` unless
        otherwise noted.

    **Tensor Shapes**:

    TorchDistributions expose a ``.shape()`` method giving the tensor shape
    of samples::

        x = d.sample(sample_shape)
        assert x.shape == d.shape(sample_shape)

    Pyro uses the same distribution shape semantics as PyTorch. A sample's
    shape is made up of three parts, each with its own role:

    - *sample shape* is the shape of the iid samples drawn from the
      distribution. It is passed as an argument to the distribution's
      `sample` method.
    - *batch shape* corresponds to non-identical (independent)
      parameterizations of the distribution, inferred from the shapes of the
      distribution's parameters. It is fixed for a distribution instance.
    - *event shape* corresponds to the event dimensions of the distribution
      and is fixed for a distribution class. These dimensions are collapsed
      when a sample is scored via `d.log_prob(x)`.

    The three parts are related by::

        assert d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

    Distributions provide a vectorized
    :meth:`~torch.distributions.distribution.Distribution.log_prob` method
    that evaluates the log probability density of each event in a batch
    independently, returning a tensor of shape
    ``sample_shape + d.batch_shape``::

        x = d.sample(sample_shape)
        assert x.shape == d.shape(sample_shape)
        log_p = d.log_prob(x)
        assert log_p.shape == sample_shape + d.batch_shape

    **Implementing New Distributions**:

    Derived classes must implement the methods
    :meth:`~torch.distributions.distribution.Distribution.sample`
    (or :meth:`~torch.distributions.distribution.Distribution.rsample` if
    ``.has_rsample == True``) and
    :meth:`~torch.distributions.distribution.Distribution.log_prob`, and must
    implement the properties
    :attr:`~torch.distributions.distribution.Distribution.batch_shape`,
    and :attr:`~torch.distributions.distribution.Distribution.event_shape`.
    Discrete classes may also implement the
    :meth:`~torch.distributions.distribution.Distribution.enumerate_support`
    method to improve gradient estimates and set
    ``.has_enumerate_support = True``.
    """
class ReshapedDistribution(TorchDistribution):
    """
    Wraps a base distribution, prepending ``sample_shape`` to its total shape
    and moving ``reinterpreted_batch_ndims`` rightmost batch dims into its
    :attr:`~torch.distributions.distribution.Distribution.event_shape`.

    :param base_dist: The distribution being wrapped.
    :param torch.Size sample_shape: The size of the iid batch to be drawn from
        the distribution.
    :param int reinterpreted_batch_ndims: The number of extra event dimensions
        that will be considered dependent.
    :raises ValueError: If ``reinterpreted_batch_ndims`` exceeds the number of
        dims in ``sample_shape + base_dist.batch_shape``.
    """
    arg_constraints = {}

    def __init__(self, base_dist, sample_shape=torch.Size(), reinterpreted_batch_ndims=0):
        sample_shape = torch.Size(sample_shape)
        # Every dim that could be a batch dim before any is reinterpreted.
        outer_shape = sample_shape + base_dist.batch_shape
        if reinterpreted_batch_ndims > len(outer_shape):
            raise ValueError('Expected reinterpreted_batch_ndims <= len(sample_shape + base_dist.batch_shape), '
                             'actual {} vs {}'.format(reinterpreted_batch_ndims,
                                                      len(outer_shape)))
        self.base_dist = base_dist
        self.sample_shape = sample_shape
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        full_shape = outer_shape + base_dist.event_shape
        # Dims left of the split stay batch dims; the rest are event dims.
        split = len(outer_shape) - reinterpreted_batch_ndims
        super(ReshapedDistribution, self).__init__(full_shape[:split], full_shape[split:])

    def expand_by(self, sample_shape):
        """
        Return a new wrapper around the same base distribution with
        ``sample_shape`` prepended to the current ``self.sample_shape``.
        """
        return ReshapedDistribution(
            self.base_dist,
            torch.Size(sample_shape) + self.sample_shape,
            self.reinterpreted_batch_ndims,
        )

    def independent(self, reinterpreted_batch_ndims=None):
        """
        Return a new wrapper around the same base distribution with
        ``reinterpreted_batch_ndims`` additional batch dims treated as event
        dims (all remaining batch dims when ``None``).
        """
        extra_ndims = reinterpreted_batch_ndims
        if extra_ndims is None:
            extra_ndims = len(self.batch_shape)
        return ReshapedDistribution(
            self.base_dist,
            self.sample_shape,
            self.reinterpreted_batch_ndims + extra_ndims,
        )

    @property
    def has_rsample(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.support

    def sample(self, sample_shape=torch.Size()):
        """Draw from the base distribution with ``self.sample_shape`` appended to ``sample_shape``."""
        return self.base_dist.sample(sample_shape + self.sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized counterpart of :meth:`sample`."""
        return self.base_dist.rsample(sample_shape + self.sample_shape)

    def log_prob(self, value):
        """
        Score ``value`` under the base distribution, sum out the reinterpreted
        batch dims, and expand the result to the broadcast batch shape.
        """
        value_batch_shape = value.shape[:value.dim() - self.event_dim]
        target_shape = broadcast_shape(self.batch_shape, value_batch_shape)
        base_log_prob = self.base_dist.log_prob(value)
        return sum_rightmost(base_log_prob, self.reinterpreted_batch_ndims).expand(target_shape)

    def score_parts(self, value):
        """
        Compute the base distribution's score parts and reduce each tensor
        term the same way :meth:`log_prob` does. Plain numeric terms are
        passed through untouched.
        """
        value_batch_shape = value.shape[:value.dim() - self.event_dim]
        target_shape = broadcast_shape(self.batch_shape, value_batch_shape)
        log_prob, score_function, entropy_term = self.base_dist.score_parts(value)

        def collapse(term):
            # Sum out reinterpreted dims, then broadcast to the target batch shape.
            return sum_rightmost(term, self.reinterpreted_batch_ndims).expand(target_shape)

        log_prob = collapse(log_prob)
        if not isinstance(score_function, numbers.Number):
            score_function = collapse(score_function)
        if not isinstance(entropy_term, numbers.Number):
            entropy_term = collapse(entropy_term)
        return ScoreParts(log_prob, score_function, entropy_term)

    def enumerate_support(self):
        """
        Enumerate the base distribution's support, inserting the extra
        ``sample_shape`` dims right after the leading enumeration dim.

        :raises NotImplementedError: If any batch dims have been reinterpreted
            as event dims.
        """
        if self.reinterpreted_batch_ndims:
            raise NotImplementedError("Pyro does not enumerate over cartesian products")

        support_values = self.base_dist.enumerate_support()
        if self.sample_shape:
            # Shift enumeration dim to correct location.
            enum_shape = support_values.shape[:1]
            base_shape = support_values.shape[1:]
            singleton_dims = (1,) * len(self.sample_shape)
            support_values = support_values.reshape(enum_shape + singleton_dims + base_shape)
            support_values = support_values.expand(enum_shape + self.sample_shape + base_shape)
        return support_values

    @property
    def mean(self):
        """Base distribution's mean expanded to ``batch_shape + event_shape``."""
        full_shape = self.batch_shape + self.event_shape
        return self.base_dist.mean.expand(full_shape)

    @property
    def variance(self):
        """Base distribution's variance expanded to ``batch_shape + event_shape``."""
        full_shape = self.batch_shape + self.event_shape
        return self.base_dist.variance.expand(full_shape)
class MaskedDistribution(TorchDistribution):
    """
    Masks a distribution by a zero-one tensor that is broadcastable to the
    distribution's :attr:`~torch.distributions.distribution.Distribution.batch_shape`.

    Sampling, support, mean and variance are forwarded unchanged to the base
    distribution; only the scoring methods apply the mask.

    :param base_dist: The distribution being masked.
    :param torch.Tensor mask: A zero-one valued float tensor.
    :raises ValueError: If ``mask.shape`` cannot be broadcast to
        ``base_dist.batch_shape``.
    """
    arg_constraints = {}

    def __init__(self, base_dist, mask):
        if broadcast_shape(mask.shape, base_dist.batch_shape) != base_dist.batch_shape:
            raise ValueError("Expected mask.shape to be broadcastable to base_dist.batch_shape, "
                             "actual {} vs {}".format(mask.shape, base_dist.batch_shape))
        self.base_dist = base_dist
        self._mask = mask
        super(MaskedDistribution, self).__init__(base_dist.batch_shape, base_dist.event_shape)

    @property
    def has_rsample(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.support

    def sample(self, sample_shape=torch.Size()):
        """Sample from the base distribution; the mask does not affect sampling."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sample from the base distribution; the mask does not affect sampling."""
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        """
        Log-probability of ``value`` under the base distribution multiplied by
        the mask, with masked-out entries forced to exactly zero.

        :param torch.Tensor value: The value to score.
        :return: The masked log-probability.
        :rtype: torch.Tensor
        """
        masked = self.base_dist.log_prob(value) * self._mask
        # A plain product yields nan wherever the base log-probability is
        # -inf and the mask is 0 (-inf * 0 == nan). Select an explicit zero
        # for masked-out entries so they contribute nothing.
        keep = (self._mask != 0).expand_as(masked)
        return torch.where(keep, masked, torch.zeros_like(masked))

    def score_parts(self, value):
        """
        Score parts of the base distribution multiplied by the mask.

        NOTE(review): the multiplication is delegated to whatever ``*`` means
        for the object returned by ``base_dist.score_parts`` -- confirm it
        handles non-finite terms at masked-out positions as intended.
        """
        return self.base_dist.score_parts(value) * self._mask

    def enumerate_support(self):
        """Forwarded from the base distribution."""
        return self.base_dist.enumerate_support()

    @property
    def mean(self):
        """Forwarded from the base distribution (the mask is not applied)."""
        return self.base_dist.mean

    @property
    def variance(self):
        """Forwarded from the base distribution (the mask is not applied)."""
        return self.base_dist.variance
```