# Source code for pyro.distributions.distribution

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
import inspect
from abc import ABCMeta, abstractmethod

import torch

from pyro.distributions.score_parts import ScoreParts


class DistributionMeta(ABCMeta):
    """
    Metaclass used by :class:`Distribution`.

    It publishes each class's constructor signature as ``__signature__`` and
    lets the callables in ``COERCIONS`` intercept instantiation.
    """

    def __init__(cls, *args, **kwargs):
        # Pre-bind a placeholder ``self`` so the published signature lists
        # only the arguments a caller actually passes to the constructor.
        bound_init = functools.partial(cls.__init__, None)
        cls.__signature__ = inspect.signature(bound_init)
        super().__init__(*args, **kwargs)

    def __call__(cls, *args, **kwargs):
        # Each coercion receives (cls, args, kwargs); the first one to return
        # something other than None supplies the constructed object.
        for coercion in COERCIONS:
            coerced = coercion(cls, args, kwargs)
            if coerced is not None:
                return coerced
        # No coercion claimed the call: build the instance the normal way.
        return super().__call__(*args, **kwargs)

class Distribution(metaclass=DistributionMeta):
    """
    Base class for parameterized probability distributions.

    Distributions in Pyro are stochastic function objects with :meth:`sample`
    and :meth:`log_prob` methods. Distributions are stochastic functions with
    fixed parameters::

        d = dist.Bernoulli(param)
        x = d()                                # Draws a random sample.
        p = d.log_prob(x)                      # Evaluates log probability of x.

    **Implementing New Distributions**:

    Derived classes must implement the methods: :meth:`sample`,
    :meth:`log_prob`.

    **Examples**:

    Take a look at the `examples <http://pyro.ai/examples/>`_ to see how they
    interact with inference algorithms.
    """

    # Class-level defaults. Subclasses may override these; ``has_rsample`` can
    # also be overridden per instance via :meth:`has_rsample_`.
    has_rsample = False
    has_enumerate_support = False

    def __call__(self, *args, **kwargs):
        """
        Samples a random value (just an alias for ``.sample(*args, **kwargs)``).

        For tensor distributions, the returned tensor should have the same
        ``.shape`` as the parameters.

        :return: A random value.
        :rtype: torch.Tensor
        """
        return self.sample(*args, **kwargs)

    @abstractmethod
    def sample(self, *args, **kwargs):
        """
        Samples a random value.

        For tensor distributions, the returned tensor should have the same
        ``.shape`` as the parameters, unless otherwise noted.

        :param sample_shape: the size of the iid batch to be drawn from the
            distribution.
        :type sample_shape: torch.Size
        :return: A random value or batch of random values (if parameters are
            batched). The shape of the result should be ``self.shape()``.
        :rtype: torch.Tensor
        """
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, x, *args, **kwargs):
        """
        Evaluates log probability densities for each of a batch of samples.

        :param torch.Tensor x: A single value or a batch of values batched
            along axis 0.
        :return: log probability densities as a one-dimensional
            :class:`~torch.Tensor` with same batch size as value and params.
            The shape of the result should be ``self.batch_shape``.
        :rtype: torch.Tensor
        """
        raise NotImplementedError

    def score_parts(self, x, *args, **kwargs):
        """
        Computes ingredients for stochastic gradient estimators of ELBO.

        The default implementation is correct both for non-reparameterized and
        for fully reparameterized distributions. Partially reparameterized
        distributions should override this method to compute correct
        `.score_function` and `.entropy_term` parts.

        Setting ``.has_rsample`` on a distribution instance will determine
        whether inference engines like :class:`~pyro.infer.svi.SVI` use
        reparameterized samplers or the score function estimator.

        :param torch.Tensor x: A single value or batch of values.
        :return: A `ScoreParts` object containing parts of the ELBO estimator.
        :rtype: ScoreParts
        """
        log_prob = self.log_prob(x, *args, **kwargs)
        if self.has_rsample:
            # Reparameterized: the log-density is used as the entropy term and
            # no score-function term is contributed.
            return ScoreParts(
                log_prob=log_prob, score_function=0, entropy_term=log_prob
            )
        else:
            # Not reparameterized: the log-density becomes the score-function
            # term and the entropy term is dropped.
            # XXX should the user be able to control inclusion of the entropy term?
            # See Roeder, Wu, Duvenaud (2017) "Sticking the Landing"
            return ScoreParts(
                log_prob=log_prob, score_function=log_prob, entropy_term=0
            )

    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
        """
        Returns a representation of the parametrized distribution's support,
        along the first dimension. This is implemented only by discrete
        distributions.

        Note that this returns support values of all the batched RVs in
        lock-step, rather than the full cartesian product.

        :param bool expand: whether to expand the result to a tensor of shape
            ``(n,) + batch_shape + event_shape``. If false, the return value
            has unexpanded shape ``(n,) + (1,)*len(batch_shape) + event_shape``
            which can be broadcasted to the full shape.
        :return: A tensor enumerating the distribution's discrete support along
            its first dimension.
        :rtype: torch.Tensor
        :raises NotImplementedError: always, in this base class; subclasses
            that support enumeration must override this method.
        """
        raise NotImplementedError(
            "Support not implemented for {}".format(type(self).__name__)
        )

    def conjugate_update(self, other):
        """
        EXPERIMENTAL Creates an updated distribution fusing information from
        another compatible distribution. This is supported by only a few
        conjugate distributions.

        This should satisfy the equation::

            fg, log_normalizer = f.conjugate_update(g)
            assert f.log_prob(x) + g.log_prob(x) == fg.log_prob(x) + log_normalizer

        Note this is equivalent to :obj:`funsor.ops.add` on
        :class:`~funsor.terms.Funsor` distributions, but we return a lazy sum
        ``(updated, log_normalizer)`` because PyTorch distributions must be
        normalized. Thus :meth:`conjugate_update` should commute with
        :func:`~funsor.pyro.convert.dist_to_funsor` and
        :func:`~funsor.pyro.convert.tensor_to_funsor` ::

            dist_to_funsor(f) + dist_to_funsor(g)
              == dist_to_funsor(fg) + tensor_to_funsor(log_normalizer)

        :param other: A distribution representing ``p(data|latent)`` but
            normalized over ``latent`` rather than ``data``. Here ``latent``
            is a candidate sample from ``self`` and ``data`` is a ground
            observation of unrelated type.
        :return: a pair ``(updated,log_normalizer)`` where ``updated`` is an
            updated distribution of type ``type(self)``, and ``log_normalizer``
            is a :class:`~torch.Tensor` representing the normalization factor.
        :raises NotImplementedError: always, in this base class; conjugate
            subclasses must override this method.
        """
        raise NotImplementedError(
            "{} does not support .conjugate_update()".format(type(self).__name__)
        )

    def has_rsample_(self, value):
        """
        Force reparameterized or detached sampling on a single distribution
        instance. This sets the ``.has_rsample`` attribute in-place.

        This is useful to instruct inference algorithms to avoid
        reparameterized gradients for variables that discontinuously determine
        downstream control flow.

        :param bool value: Whether samples will be pathwise differentiable.
        :return: self
        :rtype: Distribution
        :raises ValueError: if ``value`` is not exactly ``True`` or ``False``.
        """
        # Identity checks deliberately reject truthy/falsy non-bool values.
        if not (value is True or value is False):
            raise ValueError("Expected value in [False,True], actual {}".format(value))
        self.has_rsample = value
        return self

    @property
    def rv(self):
        """
        EXPERIMENTAL Switch to the Random Variable DSL for applying
        transformations to random variables. Supports either chaining
        operations or arithmetic operator overloading.

        Example usage::

            # This should be equivalent to an Exponential distribution.
            Uniform(0, 1).rv.log().neg().dist

            # These two distributions Y1, Y2 should be the same
            X = Uniform(0, 1).rv
            Y1 = X.mul(4).pow(0.5).sub(1).abs().neg().dist
            Y2 = (-abs((4*X)**(0.5) - 1)).dist

        :return: A
            :class:`~pyro.contrib.randomvariable.random_variable.RandomVariable`
            object wrapping this distribution.
        :rtype: ~pyro.contrib.randomvariable.random_variable.RandomVariable
        """
        # Imported lazily; presumably to avoid an import cycle with
        # pyro.contrib at module load time (TODO confirm).
        from pyro.contrib.randomvariable import RandomVariable

        return RandomVariable(self)