Source code for pyro.distributions.polya_gamma

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from torch.distributions import constraints

from pyro.distributions.torch import Exponential
from pyro.distributions.torch_distribution import TorchDistribution


class TruncatedPolyaGamma(TorchDistribution):
    """
    This is a PolyaGamma(1, 0) distribution truncated to have finite support in
    the interval (0, 2.5). See [1] for details. As a consequence of the
    truncation the `log_prob` method is only accurate to about six decimal
    places. In addition the provided sampler is a rough approximation that is
    only meant to be used in contexts where sample accuracy is not important
    (e.g. in initialization). Broadly, this implementation is only intended for
    usage in cases where good approximations of the `log_prob` are sufficient,
    as is the case e.g. in HMC.

    :param tensor prototype: A prototype tensor of arbitrary shape used to
        determine the `dtype` and `device` returned by `sample` and `log_prob`.

    References

    [1] 'Bayesian inference for logistic models using Polya-Gamma latent variables'
        Nicholas G. Polson, James G. Scott, Jesse Windle.
    """

    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)
    has_rsample = False

    def __init__(self, prototype, validate_args=None):
        self.prototype = prototype
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape=(), event_shape=(), validate_args=validate_args
        )
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TruncatedPolyaGamma, _instance)
        super(TruncatedPolyaGamma, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self.__dict__.get("_validate_args")
        new.prototype = self.prototype
        return new
    def sample(self, sample_shape=()):
        # PolyaGamma(1, 0) admits the representation
        #   omega = (1 / (2 * pi^2)) * sum_{k >= 1} Exp_k / (k - 1/2)^2
        # with i.i.d. Exp_k ~ Exponential(1). This sampler truncates the sum to
        # `num_gamma_variates` terms and clamps the result to the support,
        # which is why it is only a rough approximation.
        denom = torch.arange(
            0.5, self.num_gamma_variates, device=self.prototype.device
        ).pow(2.0)
        ones = self.prototype.new_ones(self.num_gamma_variates)
        x = Exponential(ones).sample(self.batch_shape + sample_shape)
        x = (x / denom).sum(-1)
        return torch.clamp(x * (0.5 / math.pi**2), max=self.truncation_point)
    def log_prob(self, value):
        # The PolyaGamma(1, 0) density is the alternating series
        #   f(x) = sum_{n >= 0} (-1)^n (2n + 1) / sqrt(2 * pi * x^3) * exp(-(2n + 1)^2 / (8 x)).
        # We keep the first `num_log_prob_terms` terms, grouping even and odd
        # terms and summing each group in log space for numerical stability.
        value = value.unsqueeze(-1)
        two_n_plus_one = (
            2.0
            * torch.arange(0, self.num_log_prob_terms, device=self.prototype.device)
            + 1.0
        )
        log_terms = (
            two_n_plus_one.log()
            - 1.5 * value.log()
            - 0.125 * two_n_plus_one.pow(2.0) / value
        )
        even_terms = log_terms[..., ::2]
        odd_terms = log_terms[..., 1::2]
        sum_even = torch.logsumexp(even_terms, dim=-1).exp()
        sum_odd = torch.logsumexp(odd_terms, dim=-1).exp()
        return (sum_even - sum_odd).log() - 0.5 * math.log(2.0 * math.pi)
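

# ---------------------------------------------------------------------------
# Usage sketch (not part of the Pyro source above): a minimal example of how
# the class might be exercised. The prototype tensor below is an arbitrary
# choice used only to fix the dtype and device of the returned samples and
# log-probabilities; the variable names are illustrative.
if __name__ == "__main__":
    prototype = torch.zeros(())  # determines dtype/device only
    dist = TruncatedPolyaGamma(prototype)

    # Draw a small batch of (approximate) samples and score them with the
    # truncated-series log density.
    samples = dist.sample(sample_shape=(5,))
    log_probs = dist.log_prob(samples)
    print(samples, log_probs)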