Source code for pyro.distributions.mixture

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


class MaskedConstraint(constraints.Constraint):
    """
    Combines two constraints interleaved elementwise by a mask.

    :param torch.Tensor mask: boolean mask tensor (of dtype ``torch.bool``)
    :param torch.distributions.constraints.Constraint constraint0: constraint
        that holds wherever ``mask == 0``
    :param torch.distributions.constraints.Constraint constraint1: constraint
        that holds wherever ``mask == 1``
    """

    def __init__(self, mask, constraint0, constraint1):
        self.mask = mask
        self.constraint0 = constraint0
        self.constraint1 = constraint1

    def check(self, value):
        result = self.constraint0.check(value)
        mask = (
            self.mask.expand(result.shape)
            if result.shape != self.mask.shape
            else self.mask
        )
        result[mask] = self.constraint1.check(value)[mask]
        return result
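
A quick usage sketch (illustrative values, not part of the module): ``check``
evaluates ``constraint0`` everywhere, then overwrites the masked entries with
the result of ``constraint1``::

    import torch
    from torch.distributions import constraints

    from pyro.distributions.mixture import MaskedConstraint

    mask = torch.tensor([False, True, True])
    combined = MaskedConstraint(mask, constraints.positive, constraints.real)
    # -2.0 fails ``positive`` but sits at a ``real`` position, so all pass.
    print(combined.check(torch.tensor([1.0, -2.0, 3.0])))  # tensor([True, True, True])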


class MaskedMixture(TorchDistribution):
    """
    A masked deterministic mixture of two distributions.

    This is useful when the mask is sampled from another distribution,
    possibly correlated across the batch. Often the mask can be marginalized
    out via enumeration.

    Example::

        change_point = pyro.sample("change_point",
                                   dist.Categorical(torch.ones(len(data) + 1)),
                                   infer={'enumerate': 'parallel'})
        mask = torch.arange(len(data), dtype=torch.long) >= change_point
        with pyro.plate("data", len(data)):
            pyro.sample("obs", MaskedMixture(mask, dist1, dist2), obs=data)

    :param torch.Tensor mask: A boolean tensor toggling between ``component0``
        and ``component1``.
    :param pyro.distributions.TorchDistribution component0: a distribution
        for batch elements where ``mask == False``.
    :param pyro.distributions.TorchDistribution component1: a distribution
        for batch elements where ``mask == True``.
    """

    arg_constraints = {}  # nothing can be constrained

    def __init__(self, mask, component0, component1, validate_args=None):
        if not torch.is_tensor(mask) or mask.dtype != torch.bool:
            raise ValueError(
                "Expected mask to be a BoolTensor but got {}".format(type(mask))
            )
        if component0.event_shape != component1.event_shape:
            raise ValueError(
                "components event_shape disagree: {} vs {}".format(
                    component0.event_shape, component1.event_shape
                )
            )
        batch_shape = broadcast_shape(
            mask.shape, component0.batch_shape, component1.batch_shape
        )
        if mask.shape != batch_shape:
            mask = mask.expand(batch_shape)
        if component0.batch_shape != batch_shape:
            component0 = component0.expand(batch_shape)
        if component1.batch_shape != batch_shape:
            component1 = component1.expand(batch_shape)

        self.mask = mask
        self.component0 = component0
        self.component1 = component1
        super().__init__(batch_shape, component0.event_shape, validate_args)

        # We need to disable _validate_sample on each component since samples
        # are only valid on the component from which they are drawn. Instead
        # we perform validation using a MaskedConstraint.
        self.component0._validate_args = False
        self.component1._validate_args = False

    @property
    def has_rsample(self):
        return self.component0.has_rsample and self.component1.has_rsample

    @constraints.dependent_property
    def support(self):
        if self.component0.support is self.component1.support:
            return self.component0.support
        return MaskedConstraint(
            self.mask, self.component0.support, self.component1.support
        )

    def expand(self, batch_shape):
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            mask = self.mask.expand(batch_shape)
            component0 = self.component0.expand(batch_shape)
            component1 = self.component1.expand(batch_shape)
            return type(self)(mask, component0, component1)

    def sample(self, sample_shape=torch.Size()):
        mask = self.mask.reshape(self.mask.shape + (1,) * self.event_dim)
        mask = mask.expand(sample_shape + self.shape())
        result = torch.where(
            mask,
            self.component1.sample(sample_shape),
            self.component0.sample(sample_shape),
        )
        return result

    def rsample(self, sample_shape=torch.Size()):
        mask = self.mask.reshape(self.mask.shape + (1,) * self.event_dim)
        mask = mask.expand(sample_shape + self.shape())
        result = torch.where(
            mask,
            self.component1.rsample(sample_shape),
            self.component0.rsample(sample_shape),
        )
        return result

    def log_prob(self, value):
        value_shape = broadcast_shape(value.shape, self.batch_shape + self.event_shape)
        if value.shape != value_shape:
            value = value.expand(value_shape)
        if self._validate_args:
            self._validate_sample(value)
        mask_shape = value_shape[: len(value_shape) - len(self.event_shape)]
        mask = self.mask
        if mask.shape != mask_shape:
            mask = mask.expand(mask_shape)
        result = torch.where(
            mask, self.component1.log_prob(value), self.component0.log_prob(value)
        )
        return result

    @lazy_property
    def mean(self):
        result = self.component0.mean.clone()
        result[self.mask] = self.component1.mean[self.mask]
        return result

    @lazy_property
    def variance(self):
        result = self.component0.variance.clone()
        result[self.mask] = self.component1.variance[self.mask]
        return result
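
An end-to-end sketch of sampling and moments (illustrative parameters)::

    import torch
    import pyro.distributions as dist

    mask = torch.tensor([False, False, True])
    m = dist.MaskedMixture(mask, dist.Normal(0., 1.), dist.Normal(100., 1.))
    print(m.sample())  # first two draws near 0, the last near 100
    print(m.mean)      # tensor([  0.,   0., 100.])
    print(m.variance)  # tensor([1., 1., 1.])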