# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
This file contains helpers to use `Zuko <https://zuko.readthedocs.io/>`_-based
normalizing flows within Pyro pipelines.
Accompanying tutorials can be found at `tutorial/svi_flow_guide.ipynb` and
`tutorial/vae_flow_prior.ipynb`.
"""
import torch
from torch import Size, Tensor
import pyro
class ZukoToPyro(pyro.distributions.TorchDistribution):
    r"""Wraps a Zuko distribution as a Pyro distribution.

    If ``dist`` has an ``rsample_and_log_prob`` method, like Zuko's flows, it will be
    used when sampling instead of ``rsample``. The returned log density will be cached
    for later scoring.

    The cache is a plain dictionary keyed by the sampled tensor object. Entries are
    never evicted, so a long-lived wrapper that is sampled repeatedly keeps every
    sample and its log density alive; creating a fresh wrapper per use (as in the
    example below) avoids that growth.

    :param dist: A distribution instance.
    :type dist: torch.distributions.Distribution

    .. code-block:: python

        flow = zuko.flows.MAF(features=5)

        # flow() is a torch.distributions.Distribution

        dist = flow()
        x = dist.sample((2, 3))
        log_p = dist.log_prob(x)

        # ZukoToPyro(flow()) is a pyro.distributions.Distribution

        dist = ZukoToPyro(flow())
        x = dist((2, 3))
        log_p = dist.log_prob(x)

        with pyro.plate("data", 42):
            z = pyro.sample("z", dist)
    """

    def __init__(self, dist: torch.distributions.Distribution):
        # NOTE(review): the base-class initializer is not invoked here; shape and
        # reparameterization information is forwarded through the properties
        # below instead -- confirm this is intentional before adding a super() call.
        self.dist = dist
        # Maps a sampled tensor to the log density returned alongside it.
        self.cache = {}

    @property
    def has_rsample(self) -> bool:
        """Whether the wrapped distribution reports reparameterized sampling."""
        return self.dist.has_rsample

    @property
    def event_shape(self) -> Size:
        """Event shape of the wrapped distribution."""
        return self.dist.event_shape

    @property
    def batch_shape(self) -> Size:
        """Batch shape of the wrapped distribution."""
        return self.dist.batch_shape

    def __call__(self, shape: Size = ()) -> Tensor:
        """Draw a sample of the given sample shape from the wrapped distribution.

        Prefers ``rsample_and_log_prob`` when the wrapped distribution provides it,
        then ``rsample`` when ``has_rsample`` is true, and falls back to ``sample``.

        :param shape: Sample shape forwarded to the wrapped distribution.
        :return: The sampled tensor.
        """
        if hasattr(self.dist, "rsample_and_log_prob"):  # fast sampling + scoring
            x, log_p = self.dist.rsample_and_log_prob(shape)
            # Remember the density so a later log_prob(x) on this exact tensor
            # does not have to evaluate the wrapped distribution again.
            self.cache[x] = log_p
        elif self.has_rsample:
            x = self.dist.rsample(shape)
        else:
            x = self.dist.sample(shape)

        return x

    def log_prob(self, x: Tensor) -> Tensor:
        """Return the log density of ``x`` under the wrapped distribution.

        If ``x`` is a key of the cache filled by :meth:`__call__`, the stored value
        is returned; otherwise the wrapped distribution's ``log_prob`` is evaluated.

        :param x: The value to score.
        :return: The log density of ``x``.
        """
        if x in self.cache:
            return self.cache[x]
        else:
            return self.dist.log_prob(x)

    def expand(self, *args, **kwargs) -> "ZukoToPyro":
        """Return a new wrapper around the expanded wrapped distribution.

        All arguments are forwarded to ``self.dist.expand``. The returned wrapper
        starts with an empty cache.
        """
        return ZukoToPyro(self.dist.expand(*args, **kwargs))