# Source code for pyro.distributions.affine_beta

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.transforms import AffineTransform

from .torch import Beta, TransformedDistribution
from .util import broadcast_shape


class AffineBeta(TransformedDistribution):
    r"""
    Beta distribution scaled by :attr:`scale` and shifted by :attr:`loc`::

        X ~ Beta(concentration1, concentration0)
        f(X) = loc + scale * X
        Y = f(X) ~ AffineBeta(concentration1, concentration0, loc, scale)

    :param concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :type concentration1: float or torch.Tensor
    :param concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :type concentration0: float or torch.Tensor
    :param loc: location parameter.
    :type loc: float or torch.Tensor
    :param scale: scale parameter.
    :type scale: float or torch.Tensor
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }

    def __init__(self, concentration1, concentration0, loc, scale, validate_args=None):
        base_dist = Beta(concentration1, concentration0, validate_args=validate_args)
        super().__init__(
            base_dist,
            AffineTransform(loc=loc, scale=scale),
            validate_args=validate_args,
        )

    @staticmethod
    def infer_shapes(concentration1, concentration0, loc, scale):
        # All four parameters broadcast together to form the batch shape;
        # a Beta variate is a scalar, so the event shape is empty.
        batch_shape = broadcast_shape(concentration1, concentration0, loc, scale)
        event_shape = torch.Size()
        return batch_shape, event_shape

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(AffineBeta, _instance)
        return super().expand(batch_shape, _instance=new)

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample from `Beta` distribution and applies `AffineTransform`.
        Additionally clamps the output in order to avoid `NaN` and `Inf` values
        in the gradients.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            # Clamp strictly inside (low, high): values exactly at the boundary
            # yield non-finite log_prob / gradients for the underlying Beta.
            eps = torch.finfo(x.dtype).eps * self.scale
            x = torch.min(torch.max(x, self.low + eps), self.high - eps)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample from `Beta` distribution and applies `AffineTransform`.
        Additionally clamps the output in order to avoid `NaN` and `Inf` values
        in the gradients.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        # Same boundary clamp as in sample(), but without no_grad so the
        # reparameterized gradient flows through the transform.
        eps = torch.finfo(x.dtype).eps * self.scale
        x = torch.min(torch.max(x, self.low + eps), self.high - eps)
        return x

    @constraints.dependent_property
    def support(self):
        return constraints.interval(self.low, self.high)

    @property
    def concentration1(self):
        return self.base_dist.concentration1

    @property
    def concentration0(self):
        return self.base_dist.concentration0

    @property
    def sample_size(self):
        # Conventional "pseudo-count" of a Beta: alpha + beta.
        return self.concentration1 + self.concentration0

    @property
    def loc(self):
        return torch.as_tensor(self.transforms[0].loc)

    @property
    def scale(self):
        return torch.as_tensor(self.transforms[0].scale)

    @property
    def low(self):
        # Lower endpoint of the support: f(0) = loc.
        return self.loc

    @property
    def high(self):
        # Upper endpoint of the support: f(1) = loc + scale.
        return self.loc + self.scale

    @property
    def mean(self):
        return self.loc + self.scale * self.base_dist.mean

    @property
    def variance(self):
        return self.scale.pow(2) * self.base_dist.variance