Inference utilities
Model inspection
- get_dependencies(model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None) -> Dict[str, object]
Infers the dependency structure of a conditioned model.
This returns a nested dictionary with structure like:
```
{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}
```
where
- `prior_dependencies` is a dict mapping downstream latent and observed variables to dictionaries that map the upstream latent variables they depend on to sets of plates inducing full dependencies. That is, included plates introduce quadratically many dependencies, as in complete bipartite graphs, whereas excluded plates introduce only linearly many dependencies, as in independent sets of parallel edges. Prior dependencies follow the original model order.
- `posterior_dependencies` is a similar dict, but it maps latent variables to the latent or observed sites on which they depend in the posterior. Posterior dependencies are reversed from the model order.
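As a rough illustration of the quadratic-versus-linear distinction, the following stdlib-only sketch counts the scalar dependency edges between two sites. It assumes both sites live in exactly the plates passed in `plate_sizes`; the helper `count_scalar_edges` is hypothetical and not part of Pyro.

```python
from math import prod


def count_scalar_edges(plate_sizes: dict[str, int], full_plates: set[str]) -> int:
    """Count scalar dependency edges between two sites sharing these plates.

    plate_sizes: hypothetical sizes of the plates both sites live in.
    full_plates: the plate set stored in the dependency dict for this pair.
    """
    # A listed plate couples every element to every element (complete
    # bipartite): size * size edges along that dimension. An unlisted plate
    # couples element i only to element i (parallel edges): size edges.
    return prod(
        size * size if name in full_plates else size
        for name, size in plate_sizes.items()
    )


print(count_scalar_edges({"p": 5}, {"p"}))  # 25
print(count_scalar_edges({"p": 5}, set()))  # 5
```

With no shared plates the product is empty and the count is 1, a single scalar edge.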
Dependencies elide `pyro.deterministic` sites and `pyro.sample(..., Delta(...))` sites.
Examples
Here is a simple example with no plates. We see every node depends on itself, and only the latent variables appear in the posterior:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_dependencies

def model_1():
    a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}
```
Here is an example where two variables `a` and `b` start out conditionally independent in the prior, but become conditionally dependent in the posterior due to the so-called collider variable `c` on which they both depend. This is called “moralization” in the graphical model literature:
```python
def model_2():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.))

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}
```
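To make the moralization step concrete, here is a minimal pure-Python sketch of the textbook operation on a DAG given as a parent map: keep every edge undirected and "marry" every pair of co-parents. This is only an illustration of the graph-theoretic idea applied to the structure of `model_2`; it is not how `get_dependencies` is implemented, and `moralize` is a hypothetical helper.

```python
from itertools import combinations


def moralize(parents: dict[str, set[str]]) -> set[frozenset[str]]:
    """Return the undirected edges of the moral graph of a DAG.

    parents maps each node to the set of its direct parents.
    """
    edges: set[frozenset[str]] = set()
    for child, pas in parents.items():
        # Keep every original parent-child edge, now undirected.
        for p in pas:
            edges.add(frozenset((p, child)))
        # Connect ("marry") every pair of co-parents of the same child.
        for p, q in combinations(sorted(pas), 2):
            edges.add(frozenset((p, q)))
    return edges


# Parent structure of model_2: c depends on a and b, d depends on c.
model_2_parents = {"a": set(), "b": set(), "c": {"a", "b"}, "d": {"c"}}
moral_edges = moralize(model_2_parents)
print(frozenset({"a", "b"}) in moral_edges)  # True
```

The new `a`–`b` edge is the one that shows up as `"b"` in the posterior dependencies of `"a"` above.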
Dependencies can be more complex in the presence of plates. So far all the dict values have been empty sets of plates, but in the following posterior we see that `a` depends on itself across the plate `p`. This means that, among the elements of `a`, e.g., `a[0]` depends on `a[1]` (this is why we explicitly allow variables to depend on themselves):
```python
def model_3():
    with pyro.plate("p", 5):
        a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}
```
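The plate self-dependency can be checked by hand for `model_3`. With independent priors a[i] ~ Normal(0, 1) and the observation b ~ Normal(sum(a), 1), the conditional posterior of `a[0]` given the other elements is Normal((b - s) / 2, 1/2), where s is the sum of the other elements. The sketch below (plain Python, hypothetical helper name) evaluates that closed form and shows the mean of `a[0]` shifting when `a[1]` changes.

```python
def conditional_a0(b_obs: float, others: list[float]) -> tuple[float, float]:
    """Mean and variance of p(a[0] | a[1:], b) in model_3."""
    s = sum(others)
    # The Normal(0, 1) prior contributes precision 1; the likelihood
    # Normal(b; a[0] + s, 1) contributes precision 1 centred at b - s.
    precision = 1.0 + 1.0
    mean = (b_obs - s) / precision
    return mean, 1.0 / precision


print(conditional_a0(0.0, [0.0, 0.0, 0.0, 0.0]))  # (0.0, 0.5)
print(conditional_a0(0.0, [1.0, 0.0, 0.0, 0.0]))  # (-0.5, 0.5)
```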
- render_model(model: Callable, model_args: Optional[Union[tuple, List[tuple]]] = None, model_kwargs: Optional[Union[dict, List[dict]]] = None, filename: Optional[str] = None, render_distributions: bool = False, render_params: bool = False) -> graphviz.graphs.Digraph
Renders a model using graphviz.
If `filename` is provided, this saves an image; otherwise this draws the graph. For example usage, see the model rendering tutorial.
- Parameters
model – Model to render.
model_args – Tuple of positional arguments to pass to the model, or list of tuples for semisupervised models.
model_kwargs – Dict of keyword arguments to pass to the model, or list of dicts for semisupervised models.
filename (str) – Name of file or path to file to save rendered model in.
render_distributions (bool) – Whether to include RV distribution annotations (and param constraints) in the plot.
render_params (bool) – Whether to show params in the plot.
- Returns
A model graph.
- Return type
graphviz.graphs.Digraph
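The returned `Digraph` is ultimately Graphviz DOT source. As a hedged illustration of what such a graph encodes, this stdlib-only sketch turns a parent map into DOT text, shading observed sites. The helper `to_dot` and its styling choices are assumptions for illustration; they are not Pyro's rendering code and need not match its output.

```python
def to_dot(parents: dict[str, set[str]], observed: set[str]) -> str:
    """Emit Graphviz DOT source for a model's dependency DAG (illustrative)."""
    lines = ["digraph {"]
    for node in parents:
        # Shade observed sites; leave latent sites unfilled.
        style = " style=filled fillcolor=gray" if node in observed else ""
        lines.append(f'  "{node}" [shape=ellipse{style}]')
    for child, pas in parents.items():
        # One directed edge per parent -> child dependency.
        for p in sorted(pas):
            lines.append(f'  "{p}" -> "{child}"')
    lines.append("}")
    return "\n".join(lines)


print(to_dot({"a": set(), "b": set(), "c": {"a", "b"}, "d": {"c"}}, {"d"}))
```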
Interactive prior tuning
- class Resampler(guide: Callable, simulator: Optional[Callable] = None, *, num_guide_samples: int, max_plate_nesting: Optional[int] = None)
Resampler for interactive tuning of generative models, typically when performing prior predictive checks as an early step of a Bayesian workflow.
This is intended as a computational cache to speed up the interactive tuning of the parameters of prior distributions based on samples from a downstream simulation. The idea is that the simulation can be expensive, but when one slightly tweaks the parameters of the prior distribution, most of the previous samples can be reused via importance resampling.
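The caching idea can be sketched without Pyro. Below, a pool of samples is drawn once from a diffuse guide, assumed here to be Normal(0, 3), and each tweak of a Normal prior only reweights and resamples that pool by the ratio p_prior(x) / q_guide(x). The Normal choices, pool size, and the `resample` helper are assumptions for illustration, not the `Resampler` implementation.

```python
import math
import random


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


rng = random.Random(0)

# Draw the (possibly expensive) pool once from a diffuse guide.
guide_samples = [rng.gauss(0.0, 3.0) for _ in range(20_000)]


def resample(prior_loc: float, prior_scale: float, num_samples: int) -> list[float]:
    """Importance-resample the cached pool toward a tweaked Normal prior."""
    # Log importance weights: log p_prior(x) - log q_guide(x).
    log_w = [
        normal_logpdf(x, prior_loc, prior_scale) - normal_logpdf(x, 0.0, 3.0)
        for x in guide_samples
    ]
    m = max(log_w)
    weights = [math.exp(lw - m) for lw in log_w]  # shift for numerical stability
    # Multinomial resampling reuses the pool; the guide is never rerun.
    return rng.choices(guide_samples, weights=weights, k=num_samples)


samples = resample(1.0, 1.0, 2_000)
print(sum(samples) / len(samples))  # approximately 1.0
```

Each call to `resample` costs only a reweighting pass over the cached pool, which is what makes interactive slider-style tuning cheap.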
- Parameters
guide (callable) – A Pyro model that takes no arguments. The guide should be diffuse, covering more space than the subsequent `model` passed to `sample()`. Must be vectorizable via `pyro.plate`.
simulator (callable) – An optional larger Pyro model with a superset of the guide’s latent variables. Must be vectorizable via `pyro.plate`.
num_guide_samples (int) – Number of initial samples to draw from the guide. This should be much larger than the `num_samples` requested in subsequent calls to `sample()`.
max_plate_nesting (int) – The maximum plate nesting in the model. If absent, this will be guessed by running the guide.
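Why the guide should be diffuse and `num_guide_samples` much larger than `num_samples` can be quantified with the standard Kish effective sample size of the importance weights, (Σw)² / Σw². If the effective sample size for a tweaked model is not much larger than `num_samples`, resampled draws will contain many duplicates. This is a generic diagnostic sketched in plain Python; we do not assume `Resampler` computes or exposes it.

```python
import math


def effective_sample_size(log_weights: list[float]) -> float:
    """Kish effective sample size of self-normalized importance weights."""
    m = max(log_weights)
    # Shift by the max log weight to avoid overflow; ESS is scale-invariant.
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return total * total / sum(x * x for x in w)


# Equal weights keep the full pool; one dominant weight collapses it.
print(effective_sample_size([0.0] * 100))               # 100.0
print(effective_sample_size([0.0, -1000.0, -1000.0]))   # 1.0
```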
- sample(model: Callable, num_samples: int, stable: bool = True) -> Dict[str, torch.Tensor]
Draws a set of at most `num_samples` model samples, optionally extended by the `simulator`.
Internally this importance resamples the samples generated by the `guide` in `.__init__()`, and does not rerun the `guide` or `simulator`. If the original guide samples poorly cover the model distribution, the resampled set will show low diversity.
- Parameters
model (callable) – A model with the same latent variables as the original `guide`. Must be vectorizable via `pyro.plate`.
num_samples (int) – The number of samples to draw.
stable (bool) – Whether to use piecewise-constant multinomial sampling. Set to True for visualization, False for Monte Carlo integration. Defaults to True.
- Returns
A dictionary of stacked samples.
- Return type
Dict[str, torch.Tensor]
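One way to obtain piecewise-constant resampling is to freeze the uniform random numbers once and reuse them for every weight vector, so the selected indices only jump when a weight change moves a cumulative-sum boundary past a frozen uniform. The sketch below is an assumption about the technique, with a hypothetical `StableResampler` class; it is not a description of how `Resampler.sample(..., stable=True)` is implemented.

```python
import bisect
import itertools
import random


class StableResampler:
    """Hypothetical helper: multinomial resampling with frozen uniforms."""

    def __init__(self, num_samples: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Draw the uniforms once; reusing them makes the selected indices
        # a piecewise-constant function of the weights.
        self.uniforms = sorted(rng.random() for _ in range(num_samples))

    def indices(self, weights: list[float]) -> list[int]:
        cdf = list(itertools.accumulate(weights))
        total = cdf[-1]
        last = len(cdf) - 1
        # Inverse-CDF lookup of each frozen uniform (clamped against rounding).
        return [min(bisect.bisect_right(cdf, u * total), last) for u in self.uniforms]


r = StableResampler(num_samples=10)
print(r.indices([1.0, 1.0, 1.0, 1.0]))
```

Fresh uniforms on every call would give unbiased but jittery draws, which suits Monte Carlo integration; frozen uniforms trade that for visual stability while tuning.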