import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property

from pyro.distributions.torch_distribution import TorchDistribution

class MaskedConstraint(constraints.Constraint):
    """
    Combines two constraints interleaved elementwise by a mask.

    :param torch.Tensor mask: boolean mask tensor (of dtype ``torch.bool``)
    :param torch.constraints.Constraint constraint0: constraint that holds
        wherever ``mask == 0``
    :param torch.constraints.Constraint constraint1: constraint that holds
        wherever ``mask == 1``
    """

    def __init__(self, mask, constraint0, constraint1):
        super().__init__()
        self.mask = mask
        self.constraint0 = constraint0
        self.constraint1 = constraint1

    def check(self, value):
        """
        Check ``value`` against ``constraint1`` where the mask is set and
        against ``constraint0`` everywhere else.

        :param torch.Tensor value: the tensor to validate.
        :return: a boolean tensor, the elementwise selection between the two
            constraints' check results.
        :rtype: torch.Tensor
        """
        ok0 = self.constraint0.check(value)
        ok1 = self.constraint1.check(value)
        # torch.where broadcasts the mask against the check results, so a
        # value with extra leading (sample) dimensions needs no explicit
        # expand, and neither check result is mutated in place.
        return torch.where(self.mask, ok1, ok0)

class MaskedMixture(TorchDistribution):
    """
    A masked deterministic mixture of two distributions.

    This is useful when the mask is sampled from another distribution,
    possibly correlated across the batch. Often the mask can be
    marginalized out via enumeration.

    Example::

        change_point = pyro.sample("change_point",
                                   dist.Categorical(torch.ones(len(data) + 1)),
                                   infer={'enumerate': 'parallel'})
        mask = torch.arange(len(data), dtype=torch.long) >= change_point
        with pyro.plate("data", len(data)):
            pyro.sample("obs", MaskedMixture(mask, dist1, dist2), obs=data)

    :param torch.Tensor mask: A boolean tensor toggling between ``component0``
        and ``component1``.
    :param pyro.distributions.TorchDistribution component0: a distribution
        for batch elements ``mask == False``.
    :param pyro.distributions.TorchDistribution component1: a distribution
        for batch elements ``mask == True``.
    """

    arg_constraints = {}  # nothing can be constrained

    def __init__(self, mask, component0, component1, validate_args=None):
        """
        Validate the inputs and broadcast mask and components to a common
        batch shape.

        :raises ValueError: if ``mask`` is not a boolean tensor, if the
            components disagree on ``event_shape``, or if the batch shapes
            cannot be broadcast together.
        """
        if not torch.is_tensor(mask) or mask.dtype != torch.bool:
            raise ValueError(
                "Expected mask to be a BoolTensor but got {}".format(type(mask))
            )
        if component0.event_shape != component1.event_shape:
            raise ValueError(
                "components event_shape disagree: {} vs {}".format(
                    component0.event_shape, component1.event_shape
                )
            )
        try:
            batch_shape = torch.broadcast_shapes(
                mask.shape, component0.batch_shape, component1.batch_shape
            )
        except RuntimeError as e:
            # Keep the constructor's contract of reporting bad arguments
            # as ValueError.
            raise ValueError(
                "mask and component batch shapes are not broadcastable: "
                "{} vs {} vs {}".format(
                    mask.shape, component0.batch_shape, component1.batch_shape
                )
            ) from e
        if mask.shape != batch_shape:
            mask = mask.expand(batch_shape)
        if component0.batch_shape != batch_shape:
            component0 = component0.expand(batch_shape)
        if component1.batch_shape != batch_shape:
            component1 = component1.expand(batch_shape)

        self._mask = mask
        self.component0 = component0
        self.component1 = component1
        super().__init__(batch_shape, component0.event_shape, validate_args)

        # We need to disable _validate_sample on each component since samples are only valid on the
        # component from which they are drawn. Instead we perform validation using a MaskedConstraint.
        self.component0._validate_args = False
        self.component1._validate_args = False

    def _event_mask(self):
        """Return the mask with one trailing singleton dim per event dim."""
        return self._mask.reshape(
            tuple(self._mask.shape) + (1,) * len(self.event_shape)
        )

    @property
    def has_rsample(self):
        """Reparameterized sampling requires it of both components."""
        return self.component0.has_rsample and self.component1.has_rsample

    @constraints.dependent_property
    def support(self):
        """
        The shared support when both components use the very same
        constraint object, otherwise a mask-interleaved constraint.
        """
        if self.component0.support is self.component1.support:
            return self.component0.support
        return MaskedConstraint(
            self._mask, self.component0.support, self.component1.support
        )

    def expand(self, batch_shape):
        """
        Expand to ``batch_shape``; if the base class does not implement
        ``expand``, rebuild the mixture from expanded parts instead.
        """
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            mask = self._mask.expand(batch_shape)
            component0 = self.component0.expand(batch_shape)
            component1 = self.component1.expand(batch_shape)
            return type(self)(mask, component0, component1)

    def sample(self, sample_shape=torch.Size()):
        """Draw from both components and select elementwise by the mask."""
        # torch.where broadcasts the mask over the leading sample dims.
        result = torch.where(
            self._event_mask(),
            self.component1.sample(sample_shape),
            self.component0.sample(sample_shape),
        )
        return result

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized counterpart of :meth:`sample`."""
        result = torch.where(
            self._event_mask(),
            self.component1.rsample(sample_shape),
            self.component0.rsample(sample_shape),
        )
        return result

    def log_prob(self, value):
        """
        Score ``value`` under ``component1`` where the mask is set and
        under ``component0`` elsewhere.
        """
        value_shape = torch.broadcast_shapes(
            value.shape, self.batch_shape + self.event_shape
        )
        if value.shape != value_shape:
            value = value.expand(value_shape)
        if self._validate_args:
            self._validate_sample(value)
        # Slice by explicit length: a negative stop index would be wrong
        # when event_shape is empty (``[:-0]`` yields an empty shape).
        mask_shape = value_shape[: len(value_shape) - len(self.event_shape)]
        mask = self._mask
        if mask.shape != mask_shape:
            mask = mask.expand(mask_shape)
        result = torch.where(
            mask,
            self.component1.log_prob(value),
            self.component0.log_prob(value),
        )
        return result

[docs]    @lazy_property
def mean(self):
result = self.component0.mean.clone()