Inference utilities

enable_validation(is_validate)
is_validation_enabled()
validation_enabled(is_validate=True)
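The three entries above have no descriptions here. Their names suggest a global on/off flag for validation checks, a getter for that flag, and a context manager that overrides it temporarily. The following is a minimal plain-Python sketch of that pattern, assuming exactly those semantics. It is not Pyro's source, and the module-level default of False is an assumption.

```python
from contextlib import contextmanager

# Hypothetical module-level flag; the default of False is an assumption.
_VALIDATION_ENABLED = False


def enable_validation(is_validate):
    """Globally switch validation checks on (True) or off (False)."""
    global _VALIDATION_ENABLED
    _VALIDATION_ENABLED = bool(is_validate)


def is_validation_enabled():
    """Return the current value of the global validation flag."""
    return _VALIDATION_ENABLED


@contextmanager
def validation_enabled(is_validate=True):
    """Set the flag for the duration of a with-block, then restore it."""
    old = is_validation_enabled()
    try:
        enable_validation(is_validate)
        yield
    finally:
        # Restore the previous setting even if the block raised.
        enable_validation(old)


enable_validation(False)
with validation_enabled():
    assert is_validation_enabled()
assert not is_validation_enabled()
```

Restoring the old value in a `finally` clause means a failing inference step inside the `with` block cannot leave the global flag in an unexpected state.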

Model inspection

get_dependencies(model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None) → Dict[str, object]

Infers the dependency structure of a conditioned model.

This returns a nested dictionary with structure like:

{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}

where

  • prior_dependencies is a dict mapping each downstream variable (latent or observed) to a dict that maps each upstream latent variable it depends on to the set of plates inducing full dependencies. That is, included plates introduce quadratically many dependencies, as in complete bipartite graphs, whereas excluded plates introduce only linearly many dependencies, as in independent sets of parallel edges. Prior dependencies follow the original model order.

  • posterior_dependencies is a similar dict, but it maps each latent variable to the latent or observed sites on which it depends in the posterior. Posterior dependencies are reversed from the model order.

Dependencies elide pyro.deterministic sites and pyro.sample(..., Delta(...)) sites.
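To make the plate semantics of prior_dependencies concrete, here is a hypothetical helper, element_edges (not part of Pyro), that expands a single compact dependency into element-level edges. It assumes both variables sit in one plate of a given size; full=True models a plate that is listed in the dependency's plate set, and full=False a plate that is omitted from it.

```python
from itertools import product


def element_edges(downstream, upstream, size, full):
    """Expand one compact dependency into element-level edges.

    Hypothetical helper: assumes both variables live in a single plate of
    length ``size``. ``full`` says whether that plate is in the dependency's
    plate set (as in ``{"a": {"p"}}``) or not (as in ``{"a": set()}``).
    Each edge is returned as ``(upstream_element, downstream_element)``.
    """
    if full:
        # Plate included: every downstream element depends on every upstream
        # element, a complete bipartite graph with size**2 edges.
        pairs = product(range(size), range(size))
    else:
        # Plate excluded: element i depends only on element i, i.e. size
        # parallel edges.
        pairs = ((i, i) for i in range(size))
    return [(f"{upstream}[{j}]", f"{downstream}[{i}]") for i, j in pairs]


print(len(element_edges("b", "a", 5, full=True)))   # 25
print(len(element_edges("b", "a", 5, full=False)))  # 5
```

The quadratic versus linear edge counts are why plate sets are stored per dependency rather than expanded: the compact form stays small regardless of plate size.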

Examples

Here is a simple example with no plates. We see that every node depends on itself, and only the latent variables appear as keys of the posterior dependencies:

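import torch
import pyro
import pyro.distributions as dist
from pyro.infer.inspect import get_dependencies

def model_1():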
    a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}

Here is an example where two variables a and b start out conditionally independent in the prior, but become conditionally dependent in the posterior due to the so-called collider variable c on which they both depend. This is called “moralization” in the graphical model literature:

def model_2():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.))

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}
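The posterior dependencies in the two plate-free examples above can be reproduced by a small sketch of the reverse-then-moralize idea: each prior edge is reversed, and latent co-parents of a common site are connected, with the earlier site depending on the later one. The function posterior_from_prior is hypothetical, ignores plates entirely, and is not Pyro's implementation.

```python
def posterior_from_prior(prior_dependencies, observed):
    """Plate-free sketch: invert prior dependencies and moralize.

    ``prior_dependencies`` maps each site (in model order) to a dict whose
    keys are the latent sites it depends on, itself included; plate sets are
    ignored. ``observed`` is the set of observed site names.
    """
    order = {name: i for i, name in enumerate(prior_dependencies)}
    posterior = {name: set() for name in prior_dependencies if name not in observed}
    for site, upstreams in prior_dependencies.items():
        latents = [u for u in upstreams if u not in observed]
        # Reverse every prior edge u -> site into a posterior edge site -> u.
        for u in latents:
            posterior[u].add(site)
        # Moralize: latent co-parents of the same site become dependent.
        # Orient each added edge so the earlier site depends on the later
        # one, matching the reversed (posterior) order.
        for u1 in latents:
            for u2 in latents:
                if order[u1] <= order[u2]:
                    posterior[u1].add(u2)
    return posterior


prior_2 = {
    "a": {"a": set()},
    "b": {"b": set()},
    "c": {"a": set(), "b": set(), "c": set()},
    "d": {"c": set(), "d": set()},
}
# Same keys and values as posterior_dependencies for model_2, minus plate sets.
assert posterior_from_prior(prior_2, observed={"d"}) == {
    "a": {"a", "b", "c"},
    "b": {"b", "c"},
    "c": {"c", "d"},
}
```

Handling plates would additionally require deciding which plates each reversed or moralized edge fully couples, which this sketch deliberately leaves out.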

Dependencies can be more complex in the presence of plates. So far all the dict values have been empty sets of plates, but in the following posterior we see that a depends on itself across the plate p. This means that, among the elements of a, e.g. a[0] depends on a[1] (this is why we explicitly allow variables to depend on themselves):

def model_3():
    with pyro.plate("p", 5):
        a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}
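One possible use of the returned metadata: a non-empty plate set in a site's self-dependency, as for a above, indicates that the elements of that site are coupled in the posterior, so a guide that factorizes over that plate would not represent the coupling. The helper self_coupled below is hypothetical and simply scans posterior_dependencies for such sites.

```python
def self_coupled(posterior_dependencies):
    """Return latent sites that depend on themselves across some plate.

    Hypothetical helper: for each latent site, look up its self-dependency
    and keep it only if the plate set is non-empty.
    """
    coupled = {}
    for name, deps in posterior_dependencies.items():
        plates = deps.get(name, set())
        if plates:
            coupled[name] = set(plates)
    return coupled


# The posterior_dependencies entry from model_3 above.
posterior_3 = {"a": {"a": {"p"}, "b": set()}}
print(self_coupled(posterior_3))  # {'a': {'p'}}
```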

References

[1] S. Webb, A. Goliński, R. Zinkov, N. Siddharth, T. Rainforth, Y. W. Teh, F. Wood (2018). “Faithful inversion of generative models for effective amortized inference.” https://dl.acm.org/doi/10.5555/3327144.3327229

Parameters
  • model (callable) – A model.

  • model_args (tuple) – Optional tuple of model args.

  • model_kwargs (dict) – Optional dict of model kwargs.

Returns

A dictionary of metadata (see above).

Return type

dict

render_model(model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, filename: Optional[str] = None, render_distributions: bool = False, render_params: bool = False) → graphviz.graphs.Digraph

Renders a model using graphviz.

If filename is provided, this saves an image; otherwise this draws the graph. For example usage, see the model rendering tutorial.
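A model graph is essentially a set of sites connected by dependency edges. As a rough, hypothetical illustration (not how render_model works internally, and ignoring plates and distributions), the helper dependencies_to_dot below emits Graphviz DOT source from a prior_dependencies dict such as the one get_dependencies returns. Shading observed sites is a styling choice made here, not a claim about render_model's output.

```python
def dependencies_to_dot(prior_dependencies, observed=()):
    """Emit Graphviz DOT source with one edge per non-trivial prior dependency.

    Hypothetical helper, not render_model's implementation: plates and
    distributions are ignored, and observed sites are simply shaded.
    """
    lines = ["digraph model {"]
    for site in prior_dependencies:
        attrs = " [style=filled, fillcolor=gray]" if site in observed else ""
        lines.append(f'    "{site}"{attrs};')
    for downstream, upstreams in prior_dependencies.items():
        for upstream in upstreams:
            if upstream != downstream:  # skip the trivial self-dependency
                lines.append(f'    "{upstream}" -> "{downstream}";')
    lines.append("}")
    return "\n".join(lines)


# prior_dependencies for model_2 from the get_dependencies examples.
prior_2 = {
    "a": {"a": set()},
    "b": {"b": set()},
    "c": {"a": set(), "b": set(), "c": set()},
    "d": {"c": set(), "d": set()},
}
print(dependencies_to_dot(prior_2, observed={"d"}))
```

The printed string can be saved to a file such as model.dot and rendered with the Graphviz command-line tool, e.g. `dot -Tpng model.dot -o model.png`.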

Parameters
  • model – Model to render.

  • model_args – Positional arguments to pass to the model.

  • model_kwargs – Keyword arguments to pass to the model.

  • filename (str) – File to save rendered model in.

  • render_distributions (bool) – Whether to include RV distribution annotations (and param constraints) in the plot.

  • render_params (bool) – Whether to show params in the plot.

Returns

A model graph.

Return type

graphviz.Digraph