Zuko in Pyro

This file contains helpers to use Zuko-based normalizing flows within Pyro pipelines.

Accompanying tutorials can be found at tutorial/svi_flow_guide.ipynb and tutorial/vae_flow_prior.ipynb.

class ZukoToPyro(dist: torch.distributions.distribution.Distribution)

Wraps a Zuko distribution as a Pyro distribution.

If dist has an rsample_and_log_prob method, as Zuko’s flows do, that method is used for sampling instead of rsample. The log density it returns is cached, so a later log_prob call on the same sample can reuse it.

Parameters

dist (torch.distributions.Distribution) – A distribution instance.

import pyro
import zuko
from pyro.contrib.zuko import ZukoToPyro

flow = zuko.flows.MAF(features=5)

# flow() is a torch.distributions.Distribution

dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)

# ZukoToPyro(flow()) is a pyro.distributions.Distribution

dist = ZukoToPyro(flow())
x = dist((2, 3))  # calling a Pyro distribution draws a sample of that shape
log_p = dist.log_prob(x)

with pyro.plate("data", 42):
    z = pyro.sample("z", dist)
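
The caching behaviour described above can be illustrated without Pyro or Zuko. The following is a minimal sketch in plain PyTorch, not the actual ZukoToPyro implementation: ToyFlow and CachingWrapper are hypothetical names introduced here, and ToyFlow is just a diagonal Gaussian standing in for a flow that exposes rsample_and_log_prob. It assumes the cache is keyed on the sampled tensor object itself.

```python
import torch
from torch.distributions import Independent, Normal


class ToyFlow(Independent):
    """Hypothetical stand-in for a Zuko flow: returns a sample and its
    log density from a single call."""

    def rsample_and_log_prob(self, shape=()):
        x = self.rsample(torch.Size(shape))
        return x, self.log_prob(x)


class CachingWrapper:
    """Illustrative wrapper (an assumption, not the Pyro class) that
    remembers the log density of the samples it produced."""

    def __init__(self, dist):
        self.dist = dist
        self.cache = {}

    def __call__(self, shape=()):
        if hasattr(self.dist, "rsample_and_log_prob"):
            x, log_p = self.dist.rsample_and_log_prob(shape)
            # Tensors hash by identity, so only this exact object hits the cache.
            self.cache[x] = log_p
            return x
        if self.dist.has_rsample:
            return self.dist.rsample(torch.Size(shape))
        return self.dist.sample(torch.Size(shape))

    def log_prob(self, x):
        if x in self.cache:
            return self.cache[x]  # reuse the density computed while sampling
        return self.dist.log_prob(x)  # fall back to a fresh evaluation


base = ToyFlow(Normal(torch.zeros(5), torch.ones(5)), 1)
dist = CachingWrapper(base)

x = dist((2, 3))                  # sample of shape (2, 3, 5), density cached
log_p = dist.log_prob(x)          # cache hit: no second density evaluation
log_q = dist.log_prob(x.clone())  # different tensor object: recomputed
```

For a real flow, sampling and scoring each require a pass through the transformation, so reusing the density obtained during sampling avoids a second pass when the same sample is scored afterwards. A copy of the sample, such as x.clone(), misses the cache in this sketch and is scored from scratch.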