Source code for pyro.distributions.conditional
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch
import torch.nn

from .torch import TransformedDistribution
class ConditionalDistribution(ABC):
    """
    Abstract interface for a distribution that depends on a conditioning
    ``context``.

    Concrete subclasses must override :meth:`condition`, which maps a context
    to an ordinary distribution object.
    """

    @abstractmethod
    def condition(self, context):
        """
        Produce the distribution associated with ``context``.

        :param context: the value to condition on.
        :rtype: torch.distributions.Distribution
        """
        raise NotImplementedError()
class ConditionalTransform(ABC):
    """
    Abstract interface for a transform that depends on a conditioning
    ``context``.

    Concrete subclasses must override :meth:`condition`, which maps a context
    to an ordinary transform object.
    """

    @abstractmethod
    def condition(self, context):
        """
        Produce the transform associated with ``context``.

        :param context: the value to condition on.
        :rtype: torch.distributions.Transform
        """
        raise NotImplementedError()
class ConditionalTransformModule(ConditionalTransform, torch.nn.Module):
    """
    Conditional transforms with learnable parameters such as normalizing flows should inherit from this class rather
    than :class:`~pyro.distributions.conditional.ConditionalTransform` so they are also a subclass of
    :class:`~torch.nn.Module` and inherit all the useful methods of that class.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation: neither ConditionalTransform nor ABC
        # defines ``__init__``, so the call continues along the MRO and
        # reaches ``torch.nn.Module.__init__``.
        super().__init__(*args, **kwargs)

    def __hash__(self):
        # Explicitly delegate hashing up the MRO. Presumably kept so instances
        # stay hashable (e.g. usable as dict keys / in sets) even if a base
        # class ever defines ``__eq__`` — NOTE(review): confirm intent.
        return super().__hash__()
class ConstantConditionalDistribution(ConditionalDistribution):
    """
    Adapter exposing a fixed distribution through the
    :class:`ConditionalDistribution` interface.

    :param base_dist: the distribution handed back by every call to
        :meth:`condition`.
    :type base_dist: torch.distributions.Distribution
    """

    def __init__(self, base_dist):
        assert isinstance(base_dist, torch.distributions.Distribution)
        self.base_dist = base_dist

    def condition(self, context):
        """
        Return the wrapped distribution; ``context`` is accepted for interface
        compatibility but ignored.

        :rtype: torch.distributions.Distribution
        """
        fixed_dist = self.base_dist
        return fixed_dist
class ConstantConditionalTransform(ConditionalTransform):
    """
    Adapter exposing a fixed transform through the
    :class:`ConditionalTransform` interface.

    :param transform: the transform handed back by every call to
        :meth:`condition`.
    :type transform: torch.distributions.Transform
    """

    def __init__(self, transform):
        assert isinstance(transform, torch.distributions.Transform)
        self.transform = transform

    def condition(self, context):
        """
        Return the wrapped transform; ``context`` is accepted for interface
        compatibility but ignored.

        :rtype: torch.distributions.Transform
        """
        return self.transform

    def clear_cache(self):
        """
        Clear the wrapped transform's cache when it provides a ``clear_cache``
        method.

        ``__init__`` only guarantees a :class:`torch.distributions.Transform`,
        and ``clear_cache`` is not guaranteed to exist on every such transform.
        Previously this raised :class:`AttributeError` for those transforms; now
        it is a no-op when there is no ``clear_cache`` to delegate to.
        """
        clear = getattr(self.transform, "clear_cache", None)
        if clear is not None:
            clear()
class ConditionalTransformedDistribution(ConditionalDistribution):
    """
    Conditional counterpart of :class:`TransformedDistribution`.

    Combines a base distribution with a sequence of transforms. Each component
    may be conditional (depends on a context) or plain. Plain components are
    wrapped in constant adapters that ignore the context.

    :param base_dist: either a :class:`ConditionalDistribution` or a
        :class:`torch.distributions.Distribution`.
    :param transforms: an iterable whose elements are each either a
        :class:`ConditionalTransform` or a
        :class:`torch.distributions.Transform`. A single, non-iterable transform
        of either kind is also accepted and treated as a one-element list.
    """

    def __init__(self, base_dist, transforms):
        self.base_dist = (
            base_dist
            if isinstance(base_dist, ConditionalDistribution)
            else ConstantConditionalDistribution(base_dist)
        )
        # Accept a lone transform, mirroring torch's TransformedDistribution
        # (which takes a Transform or a list). Previously iterating it failed
        # with "object is not iterable". Iterable inputs (lists, tuples,
        # iterable transform containers) keep the original element-wise path.
        if isinstance(
            transforms, (ConditionalTransform, torch.distributions.Transform)
        ) and not isinstance(transforms, Iterable):
            transforms = [transforms]
        self.transforms = [
            t
            if isinstance(t, ConditionalTransform)
            else ConstantConditionalTransform(t)
            for t in transforms
        ]

    def condition(self, context):
        """
        Condition every component on ``context``.

        :param context: value forwarded to each component's ``condition``.
        :return: a :class:`TransformedDistribution` built from the conditioned
            base distribution and the conditioned transforms, in order.
        """
        base_dist = self.base_dist.condition(context)
        transforms = [t.condition(context) for t in self.transforms]
        return TransformedDistribution(base_dist, transforms)