Zuko in Pyro
This file contains helpers for using Zuko-based normalizing flows within Pyro pipelines.
Accompanying tutorials can be found at tutorial/svi_flow_guide.ipynb and tutorial/vae_flow_prior.ipynb.
class ZukoToPyro(dist: torch.distributions.distribution.Distribution)
Wraps a Zuko distribution as a Pyro distribution.
If dist has an rsample_and_log_prob method, as Zuko's flows do, that method is used when sampling instead of rsample. The returned log density is cached for later scoring.

Parameters:
dist (torch.distributions.Distribution) – A distribution instance.
Example:

    import pyro
    import zuko
    from pyro.contrib.zuko import ZukoToPyro

    flow = zuko.flows.MAF(features=5)

    # flow() is a torch.distributions.Distribution
    dist = flow()
    x = dist.sample((2, 3))
    log_p = dist.log_prob(x)

    # ZukoToPyro(flow()) is a pyro.distributions.Distribution
    dist = ZukoToPyro(flow())
    x = dist((2, 3))  # calling a Pyro distribution draws a sample of shape (2, 3, 5)
    log_p = dist.log_prob(x)

    with pyro.plate("data", 42):
        z = pyro.sample("z", dist)
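The sample-and-cache behaviour described for ZukoToPyro can be illustrated without Pyro or Zuko installed. The following is a minimal pure-Python sketch under stated assumptions: ToyNormal and CachingWrapper are hypothetical names that belong to neither library, and keying the cache on the sample value is a choice made for this sketch, not a description of how ZukoToPyro stores its cache. It shows that when the wrapped distribution offers rsample_and_log_prob, a later log_prob call on the same sample is answered from the cache instead of re-running the density computation.

```python
import math
import random


class ToyNormal:
    """Hypothetical stand-in for a flow: a 1-D standard normal."""

    def log_prob(self, x: float) -> float:
        return -0.5 * x * x - 0.5 * math.log(2 * math.pi)

    def sample(self) -> float:
        return random.gauss(0.0, 1.0)

    def rsample_and_log_prob(self):
        # Sample and score in one pass, mimicking what Zuko flows can do.
        x = self.sample()
        return x, self.log_prob(x)


class CachingWrapper:
    """Hypothetical sketch of the caching idea; not Pyro's ZukoToPyro."""

    def __init__(self, dist):
        self.dist = dist
        self.cache = {}           # sample value -> log density (assumption)
        self.log_prob_calls = 0   # counts fallbacks to dist.log_prob

    def __call__(self):
        if hasattr(self.dist, "rsample_and_log_prob"):
            x, log_p = self.dist.rsample_and_log_prob()
            self.cache[x] = log_p  # remember the density for later scoring
        else:
            x = self.dist.sample()
        return x

    def log_prob(self, x):
        if x in self.cache:
            return self.cache[x]  # cache hit: no second density evaluation
        self.log_prob_calls += 1
        return self.dist.log_prob(x)


random.seed(0)
wrapped = CachingWrapper(ToyNormal())
x = wrapped()
print(wrapped.log_prob(x) == ToyNormal().log_prob(x))  # True
print(wrapped.log_prob_calls)  # 0, served from the cache
```

Because log_prob(x) depends only on the value of x, a value-keyed cache cannot return a wrong density here. In this sketch the cache is never cleared, which is fine for a demo but would grow without bound in a long training loop.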