Poutine (Effect handlers)¶
Beneath the built-in inference algorithms, Pyro has a library of composable effect handlers for creating new inference algorithms and working with probabilistic programs. Pyro’s inference algorithms are all built by applying these handlers to stochastic functions. To get a general understanding of what effect handlers are and what problem they solve, read An Introduction to Algebraic Effects and Handlers by Matija Pretnar.
Handlers¶
Poutine is a library of composable effect handlers for recording and modifying the behavior of Pyro programs. These lower-level ingredients simplify the implementation of new inference algorithms and behavior.
Handlers can be used as higher-order functions, decorators, or context managers to modify the behavior of functions or blocks of code:
For example, consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can mark sample sites as observed using condition, which returns a callable with the same input and output signatures as model:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
We can also use handlers as decorators:
>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
Or as context managers:
>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(0., s))
...     y = z ** 2
Handlers compose freely:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)
Many inference algorithms or algorithmic components can be implemented in just a few lines of code:
guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
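To make the sketch above concrete, here is a minimal, self-contained version with a hypothetical one-variable model and guide (the names model, guide, and loc are illustrative stand-ins, not part of the library):
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(data):
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(z, 1.), obs=data)

def guide(data):
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("z", dist.Normal(loc, 1.))

data = torch.tensor(0.5)
# Record the guide's sample, replay it through the model, and compare log-joints.
guide_tr = poutine.trace(guide).get_trace(data)
model_tr = poutine.trace(poutine.replay(model, trace=guide_tr)).get_trace(data)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()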
- block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) pyro.poutine.block_messenger.BlockMessenger [source]¶
- block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of BlockMessenger.
This handler selectively hides Pyro primitive sites from the outside world. Default behavior: block everything.
A site is hidden if at least one of the following holds:
1. hide_fn(msg) is True or (not expose_fn(msg)) is True
2. msg["name"] in hide
3. msg["type"] in hide_types
4. msg["name"] not in expose and msg["type"] not in expose_types
5. hide, hide_types, and expose_types are all None
For example, suppose the stochastic function fn has two sample sites "a" and "b". Then any effect outside of BlockMessenger(fn, hide=["a"]) will not be applied to site "a" and will only see site "b":
>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
hide_fn – function that takes a site and returns True to hide the site or False/None to expose it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
expose_fn – function that takes a site and returns True to expose the site or False/None to hide it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
hide_all (bool) – hide all sites
expose_all (bool) – expose all sites normally
hide (list) – list of site names to hide
expose (list) – list of site names to be exposed while all others hidden
hide_types (list) – list of site types to be hidden
expose_types (list) – list of site types to be exposed while all others hidden
- Returns
stochastic function decorated with a
BlockMessenger
- broadcast(fn: None = None) pyro.poutine.broadcast_messenger.BroadcastMessenger [source]¶
- broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of BroadcastMessenger.
Automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.
Notice how model_automatic_broadcast below automates expanding of distribution batch shapes. This makes it easy to modularize a Pyro model, as the sub-components are agnostic of the wrapping plate contexts.
>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
- collapse(fn: None = None, *args: Any, **kwargs: Any) pyro.poutine.collapse_messenger.CollapseMessenger [source]¶
- collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of CollapseMessenger.
EXPERIMENTAL Collapses all sites in the context by lazily sampling and attempting to use conjugacy relations. If no conjugacy is known this will fail. Code using the results of sample sites must be written to accept Funsors rather than Tensors. This requires funsor to be installed.
Warning
This is not compatible with automatic guessing of max_plate_nesting. If any plates appear within the collapsed context, you should manually declare max_plate_nesting to your inference algorithm (e.g. Trace_ELBO(max_plate_nesting=1)).
- condition(data: Union[Dict[str, torch.Tensor], Trace]) pyro.poutine.condition_messenger.ConditionMessenger [source]¶
- condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of ConditionMessenger.
Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To observe a value for site z, we can write
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
This is equivalent to adding obs=value as a keyword argument to pyro.sample(“z”, …) in model.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
data – a dict or a
Trace
- Returns
stochastic function decorated with a
ConditionMessenger
- do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) pyro.poutine.do_messenger.DoMessenger [source]¶
- do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of DoMessenger.
Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.
Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To intervene with a value for site z, we can write
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
This is equivalent to replacing z = pyro.sample(“z”, …) with z = torch.tensor(1.) and introducing a fresh sample site pyro.sample(“z”, …) whose value is not used elsewhere.
References
- [1] Single World Intervention Graphs: A Primer,
Thomas Richardson, James Robins
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
data – a
dict
mapping sample site names to interventions
- Returns
stochastic function decorated with a
DoMessenger
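As a quick sanity check, here is a small sketch (restating the model above) showing that the intervened model's return value is computed from the pinned value of z, regardless of the input:
import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
# The return value z ** 2 uses the intervention value, not a fresh draw.
assert intervened_model(0.0).item() == 1.0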
- enum(fn: None = None, first_available_dim: Optional[int] = None) pyro.poutine.enum_messenger.EnumMessenger [source]¶
- enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of EnumMessenger.
Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.
- Parameters
first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions left may be used internally by Pyro. This should be a negative integer or None.
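A minimal sketch of what the handler does on its own (outside of TraceEnum_ELBO), assuming a single Categorical site marked for parallel enumeration; the model and site names are illustrative:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    # A discrete site marked for parallel enumeration.
    pyro.sample("z", dist.Categorical(torch.ones(3) / 3.),
                infer={"enumerate": "parallel"})

enum_model = poutine.enum(model, first_available_dim=-1)
tr = poutine.trace(enum_model).get_trace()
# The recorded value holds all 3 categories at once (exact shape may vary by version).
print(tr.nodes["z"]["value"].shape)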
- equalize(sites: Union[str, List[str]], type: Optional[str], keep_dist: Optional[bool]) pyro.poutine.equalize_messenger.EqualizeMessenger [source]¶
- equalize(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], sites: Union[str, List[str]], type: Optional[str], keep_dist: Optional[bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of EqualizeMessenger.
Given a stochastic function with some primitive statements and a list of names, force the primitive statements at those names to have the same value, with that value being the result of the first primitive statement matching those names.
Consider the following Pyro program:
>>> def per_category_model(category):
...     shift = pyro.param(f'{category}_shift', torch.randn(1))
...     mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
...     std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
...     return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))
Running the program for multiple categories can be done by
>>> def model(categories):
...     return {category: per_category_model(category) for category in categories}
To make the std sample sites have the same value, we can write
>>> equal_std_model = pyro.poutine.equalize(model, '.+_std')
If on top of the above we would like to make the ‘shift’ parameters identical, we can write
>>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
Alternatively, the equalize messenger can be used to condition a model on primitive statements having the same value by setting keep_dist to True. Consider the below model:
>>> def model():
...     x = pyro.sample('x', pyro.distributions.Normal(0, 1))
...     y = pyro.sample('y', pyro.distributions.Normal(5, 3))
...     return x, y
The model can be conditioned on ‘x’ and ‘y’ having the same value by
>>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)
Note that the conditioned model defined above calculates the correct unnormalized log-probability density, but in order to correctly sample from it one must use SVI or MCMC techniques.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
sites – a string or list of strings to match site names (the strings can be regular expressions)
type – a string specifying the site type (default is ‘sample’)
keep_dist (bool) – Whether to keep the distributions of the second and subsequent matching primitive statements. If kept this is equivalent to conditioning the model on all matching primitive statements having the same value, as opposed to having the second and subsequent matching primitive statements replaced by delta sampling functions. Defaults to False.
- Returns
stochastic function decorated with a
EqualizeMessenger
- escape(escape_fn: Callable[[Message], bool]) pyro.poutine.escape_messenger.EscapeMessenger [source]¶
- escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of EscapeMessenger.
Messenger that does a nonlocal exit by raising a util.NonlocalExit exception.
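A small sketch of the mechanism: execution exits with a NonlocalExit carrying the matched site as soon as escape_fn returns True. The model and site names here are illustrative, and the e.site access assumes the exception exposes the triggering message as documented in the Runtime section below:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.poutine.runtime import NonlocalExit

def model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    return pyro.sample("y", dist.Normal(x, 1.))

escaping_model = poutine.escape(model, escape_fn=lambda msg: msg["name"] == "y")
try:
    escaping_model()
except NonlocalExit as e:
    print(e.site["name"])  # expected: "y"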
- infer_config(config_fn: Callable[[Message], InferDict]) pyro.poutine.infer_config_messenger.InferConfigMessenger [source]¶
- infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of InferConfigMessenger.
Given a callable fn that contains Pyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
config_fn – a callable taking a site and returning an infer dict
- Returns
stochastic function decorated with
InferConfigMessenger
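For example, a config_fn can tag every enumerable sample site for parallel enumeration (compare config_enumerate below). This is a sketch with an illustrative model; the config_fn shown is an assumption, not the only sensible choice:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    z = pyro.sample("z", dist.Bernoulli(0.3))
    pyro.sample("x", dist.Normal(z, 1.))

def config_fn(site):
    # Tag enumerable sample sites for parallel enumeration; leave others alone.
    if site["type"] == "sample" and getattr(site["fn"], "has_enumerate_support", False):
        return {"enumerate": "parallel"}
    return {}

configured_model = poutine.infer_config(model, config_fn=config_fn)
tr = poutine.trace(configured_model).get_trace()
print(tr.nodes["z"]["infer"])  # expected to contain {'enumerate': 'parallel'}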
- lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) pyro.poutine.lift_messenger.LiftMessenger [source]¶
- lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of LiftMessenger.
Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = pyro.sample("s", dist.Exponential(0.3)):
>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
- Parameters
fn – function whose parameters will be lifted to random values
prior – prior function in the form of a Distribution or a dict of stochastic fns
- Returns
fn
decorated with aLiftMessenger
- markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger [source]¶
- markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger
- markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Markov dependency declaration.
This can be used in a variety of ways:
as a context manager
as a decorator for recursive functions
as an iterator for markov chains
- Parameters
history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.
keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).
dim (int) – An optional dimension to use for this independence index. Interface stub, behavior not yet implemented.
name (str) – An optional unique name to help inference algorithms match pyro.markov() sites between models and guides. Interface stub, behavior not yet implemented.
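For example, the iterator form declares a first-order dependency structure along a chain; this is a sketch of a hypothetical random-walk model, with illustrative names:
import torch
import pyro
import pyro.distributions as dist

def random_walk(num_steps=10):
    x = torch.tensor(0.)
    # With the default history=1, each x_t is declared to depend only on x_{t-1},
    # which lets enumeration-based inference recycle tensor dimensions along the chain.
    for t in pyro.markov(range(num_steps)):
        x = pyro.sample(f"x_{t}", dist.Normal(x, 1.))
    return x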
- mask(mask: Union[bool, torch.BoolTensor]) pyro.poutine.mask_messenger.MaskMessenger [source]¶
- mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of MaskMessenger.
Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
mask (torch.BoolTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)
- Returns
stochastic function decorated with a
MaskMessenger
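A small sketch of elementwise masking: the second of three batched observations is excluded from the log-joint. Site and variable names are illustrative:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    keep = torch.tensor([True, False, True])
    with pyro.plate("data", 3), poutine.mask(mask=keep):
        pyro.sample("obs", dist.Normal(0., 1.), obs=torch.zeros(3))

tr = poutine.trace(model).get_trace()
tr.compute_log_prob()
# The masked-out entry is expected to contribute zero to the site's log_prob.
print(tr.nodes["obs"]["log_prob"])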
- queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[source]¶
Used in sequential enumeration over discrete variables.
Given a stochastic function and a queue, return a return value from a complete trace in the queue.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
queue – a queue data structure like multiprocessing.Queue to hold partial traces
max_tries – maximum number of attempts to compute a single complete trace
extend_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a list of extended traces
escape_fn – function (possibly stochastic) that takes a partial trace and a site, and returns a boolean value to decide whether to exit
num_samples – optional number of extended traces for extend_fn to return
- Returns
stochastic function decorated with poutine logic
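A hedged sketch of sequential enumeration with the default extend_fn and escape_fn: seed the queue with an empty Trace, then call the wrapped model until the queue is exhausted. The model here is illustrative:
from queue import Queue

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    x = pyro.sample("x", dist.Bernoulli(0.5))
    y = pyro.sample("y", dist.Bernoulli(0.5))
    return x, y

q = Queue()
q.put(poutine.Trace())              # seed with an empty partial trace
enumerated_model = poutine.queue(model, queue=q)

returns = []
while not q.empty():
    returns.append(enumerated_model())   # one complete trace's return value per call
print(len(returns))                 # expected: 4 enumerated outcomes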
- reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamMessenger [source]¶
- reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
Convenient wrapper of ReparamMessenger.
Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].
To specify reparameterizers, pass a config dict or callable to the constructor. See the pyro.infer.reparam module for available reparameterizers.
Note some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using poutine.reparam as a decorator rather than as a context manager.
- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf
- Parameters
config (dict or callable) – Configuration, either a dict mapping site name to Reparameterizer, or a function mapping site to Reparam or None. See pyro.infer.reparam.strategies for built-in configuration strategies.
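For example, a non-centered parameterization of a location-scale site can be requested with the built-in LocScaleReparam. This is a sketch with an illustrative model; any reparameterizer from pyro.infer.reparam could be substituted:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.reparam import LocScaleReparam

def model():
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", 10):
        # Reparameterizing "x" introduces an auxiliary site and a deterministic transform.
        pyro.sample("x", dist.Normal(0., scale))

reparam_model = poutine.reparam(model, config={"x": LocScaleReparam()})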
- replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) pyro.poutine.replay_messenger.ReplayMessenger [source]¶
- replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of ReplayMessenger.
Given a callable that contains Pyro primitive calls, return a callable that runs the original, reusing the values at sites in trace at those sites in the new trace.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:
>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
trace – a Trace data structure to replay against
params – dict of names of param sites and constrained values in fn to replay against
- Returns
a stochastic function decorated with a
ReplayMessenger
- scale(scale: Union[float, torch.Tensor]) pyro.poutine.scale_messenger.ScaleMessenger [source]¶
- scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of ScaleMessenger.
Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale multiplicatively scales the log-probabilities of sample sites:
>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
scale – a positive scaling factor
- Returns
stochastic function decorated with a
ScaleMessenger
- seed(rng_seed: int) pyro.poutine.seed_messenger.SeedMessenger [source]¶
- seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of SeedMessenger.
Handler to set the random number generator to a pre-defined state by setting its seed. This is the same as calling pyro.set_rng_seed() before the call to fn. This handler has no additional effect on primitive statements on the standard Pyro backend, but it might intercept pyro.sample calls in other backends, e.g. the NumPy backend.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls).
rng_seed (int) – rng seed.
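A small sketch of reproducibility under seed, with an illustrative one-site model: each call to the wrapped function re-seeds the generator, so repeated calls draw identical values.
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model():
    return pyro.sample("z", dist.Normal(0., 1.))

seeded_model = poutine.seed(model, rng_seed=0)
# Both calls are seeded identically, so the draws match.
assert bool(seeded_model() == seeded_model())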
- substitute(data: Dict[str, torch.Tensor]) pyro.poutine.substitute_messenger.SubstituteMessenger [source]¶
- substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of SubstituteMessenger.
Given a stochastic function with param calls and a set of parameter values, create a stochastic function where all param calls are substituted with the fixed values. data should be a dict of names to values. Consider the following Pyro program:
>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
In this example, site a will now have value torch.tensor(0.3).
- Parameters
data – dictionary of values keyed by site names
- Returns
fn decorated with a SubstituteMessenger
- trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceMessenger [source]¶
- trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
Convenient wrapper of TraceMessenger.
Return a handler that records the inputs and outputs of primitive calls and their dependencies.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
graph_type – string that specifies the kind of graph to construct
param_only – if true, only records params and not samples
- Returns
stochastic function decorated with a
TraceMessenger
- uncondition(fn: None = None) pyro.poutine.uncondition_messenger.UnconditionMessenger [source]¶
- uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
Convenient wrapper of UnconditionMessenger.
Messenger to force the value of observed nodes to be sampled from their distribution, ignoring observations.
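A small sketch: wrapping a conditioned model in uncondition turns the observed site back into a latent one that is sampled afresh. The model here restates the one from the condition example and is illustrative:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

conditioned_model = poutine.condition(model, data={"z": torch.tensor(1.)})
unconditioned_model = poutine.uncondition(conditioned_model)
tr = poutine.trace(unconditioned_model).get_trace(0.0)
# "z" is expected to be recorded as a latent (non-observed) site with a fresh draw.
print(tr.nodes["z"]["is_observed"])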
- config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')[source]¶
Configures enumeration for all relevant sites in a guide. This is mainly used in conjunction with TraceEnum_ELBO.
When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose distribution satisfies .has_enumerate_support == True. When configuring for local parallel Monte Carlo sampling via default="parallel", num_samples=n, this configures all sample sites. This does not overwrite existing annotations infer={"enumerate": ...}.
This can be used as either a function:
guide = config_enumerate(guide)
or as a decorator:
@config_enumerate
def guide1(*args, **kwargs):
    ...

@config_enumerate(default="sequential", expand=True)
def guide2(*args, **kwargs):
    ...
- Parameters
guide (callable) – a pyro model that will be used as a guide in SVI.
default (str) – Which enumerate strategy to use, one of “sequential”, “parallel”, or None. Defaults to “parallel”.
expand (bool) – Whether to expand enumerated sample values. See enumerate_support() for details. This only applies to exhaustive enumeration, where num_samples=None. If num_samples is not None, then samples will always be expanded.
num_samples (int or None) – if not None, use local Monte Carlo sampling rather than exhaustive enumeration. This makes sense for both continuous and discrete distributions.
tmc (string or None) – “mixture” or “diagonal” strategies to use in Tensor Monte Carlo
- Returns
an annotated guide
- Return type
callable
Trace¶
- class Trace(graph_type: Literal['flat', 'dense'] = 'flat')[source]¶
Bases:
object
Graph data structure denoting the relationships amongst different pyro primitives in the execution trace.
An execution trace of a Pyro program is a record of every call to pyro.sample() and pyro.param() in a single execution of that program. Traces are directed graphs whose nodes represent primitive calls or input/output, and whose edges represent conditional dependence relationships between those primitive calls. They are created and populated by poutine.trace.
Each node (or site) in a trace contains the name, input and output value of the site, as well as additional metadata added by inference algorithms or user annotation. In the case of pyro.sample, the trace also includes the stochastic function at the site, and any observed data added by users.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can record its execution using pyro.poutine.trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
We can also inspect or manipulate individual nodes in the trace. trace.nodes contains a collections.OrderedDict of site names and metadata corresponding to x, s, z, and the return value:
>>> list(name for name in trace.nodes.keys())
["_INPUT", "s", "z", "_RETURN"]
Values of trace.nodes are dictionaries of node metadata:
>>> trace.nodes["z"]
{'type': 'sample', 'name': 'z', 'is_observed': False, 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {}, 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}
'infer' is a dictionary of user- or algorithm-specified metadata. 'args' and 'kwargs' are the arguments passed via pyro.sample to fn.__call__ or fn.log_prob. 'scale' is used to scale the log-probability of the site when computing the log-joint. 'cond_indep_stack' contains data structures corresponding to pyro.plate contexts appearing in the execution. 'done', 'stop', and 'continuation' are only used by Pyro’s internals.
- Parameters
graph_type (string) – string specifying the kind of trace graph to construct
- add_node(site_name: str, **kwargs: Any) None [source]¶
- Parameters
site_name (string) – the name of the site to be added
Adds a site to the trace.
Raises an error when attempting to add a duplicate node instead of silently overwriting.
- compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) None [source]¶
Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. Both computations are memoized.
- compute_score_parts() None [source]¶
Compute the batched local score parts at each site of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. All computations are memoized.
- copy() pyro.poutine.trace_struct.Trace [source]¶
Makes a shallow copy of self with nodes and edges preserved.
- format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) str [source]¶
Returns a string showing a table of the shapes of all sites in the trace.
- iter_stochastic_nodes() Iterator[Tuple[str, Message]] [source]¶
- Returns
an iterator over stochastic nodes in the trace.
- log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) Union[torch.Tensor, float] [source]¶
Compute the site-wise log probabilities of the trace. Each log_prob has shape equal to the corresponding batch_shape. Each log_prob_sum is a scalar. The computation of log_prob_sum is memoized.
- Returns
total log probability.
- Return type
torch.Tensor or float
- property nonreparam_stochastic_nodes: List[str]¶
- Returns
a list of names of sample sites whose stochastic functions are not reparameterizable primitive distributions
- pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) None [source]¶
Computes packed representations of tensors in the trace. This should be called after compute_log_prob() or compute_score_parts().
- property reparameterized_nodes: List[str]¶
- Returns
a list of names of sample sites whose stochastic functions are reparameterizable primitive distributions
Runtime¶
- class InferDict[source]¶
Bases:
typing_extensions.TypedDict
A dictionary that contains information about inference.
This can be used to configure per-site inference strategies, e.g.:
pyro.sample(
    "x",
    dist.Bernoulli(0.5),
    infer={"enumerate": "parallel"},
)
- Keys:
- enumerate (str):
If one of the strings “sequential” or “parallel”, enables enumeration. Parallel enumeration is generally faster but requires broadcasting-safe operations and static structure.
- expand (bool):
Whether to expand the distribution during enumeration. Defaults to False if missing.
- is_auxiliary (bool):
Whether the sample site is auxiliary, e.g. for use in guides that deterministically transform auxiliary variables. Defaults to False if missing.
- is_observed (bool):
Whether the sample site is observed (i.e. not latent). Defaults to False if missing.
- num_samples (int):
The number of samples to draw. Defaults to 1 if missing.
- obs (optional torch.Tensor):
The observed value, or None for latent variables. Defaults to None if missing.
- prior (optional torch.distributions.Distribution):
(internal) For use in GuideMessenger to store the model’s prior distribution (conditioned on upstream sites).
- tmc (str):
Whether to use the diagonal or mixture approximation for Tensor Monte Carlo in TraceTMC_ELBO.
- was_observed (bool):
(internal) Whether the sample site was originally observed, in the context of inference via Reweighted Wake Sleep or Compiled Sequential Importance Sampling.
- enumerate: typing_extensions.Literal[sequential, parallel]¶
- obs: Optional[torch.Tensor]¶
- prior: TorchDistributionMixin¶
- tmc: typing_extensions.Literal[diagonal, mixture]¶
- class Message[source]¶
Bases:
typing_extensions.TypedDict, Generic[pyro.poutine.runtime._P, pyro.poutine.runtime._T]
Pyro’s internal message type for effect handling.
Messages are stored in trace objects, e.g.:
trace.nodes["my_site_name"] # This is a Message.
- Keys:
- type (str):
The message type, typically one of the strings “sample”, “param”, “plate”, or “markov”, but possibly custom.
- name (str):
The site name, typically naming a sample or parameter.
- fn (callable):
The distribution or function used to generate the sample.
- is_observed (bool):
A flag to indicate whether the value is observed.
- args (tuple):
Positional arguments to the distribution or function.
- kwargs (dict):
Keyword arguments to the distribution or function.
- value (torch.Tensor):
The value of the sample (either observed or sampled).
- scale (torch.Tensor):
A scaling factor for the log probability.
- mask (bool torch.Tensor):
A bool or tensor to mask the log probability.
- cond_indep_stack (tuple):
The site’s local stack of conditional independence metadata. Immutable.
- done (bool):
A flag to indicate whether the message has been handled.
- stop (bool):
A flag to stop further processing of the message.
- continuation (callable):
A function to call after processing the message.
- infer (optional InferDict):
A dictionary of inference parameters.
- obs (torch.Tensor):
The observed value.
- log_prob (torch.Tensor):
The log probability of the sample.
- log_prob_sum (torch.Tensor):
The sum of the log probability.
- unscaled_log_prob (torch.Tensor):
The unscaled log probability.
- score_parts (pyro.distributions.ScoreParts):
A collection of score parts.
- packed (Message):
A packed message, used during enumeration.
- args: Tuple¶
- cond_indep_stack: Tuple[CondIndepStackFrame, ...]¶
- fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]¶
- infer: Optional[pyro.poutine.runtime.InferDict]¶
- kwargs: Dict¶
- log_prob: torch.Tensor¶
- log_prob_sum: torch.Tensor¶
- mask: Optional[Union[bool, torch.Tensor]]¶
- obs: Optional[torch.Tensor]¶
- scale: Union[torch.Tensor, float]¶
- score_parts: ScoreParts¶
- unscaled_log_prob: torch.Tensor¶
- value: Optional[pyro.poutine.runtime._T]¶
- exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)[source]¶
Bases:
Exception
Exception for exiting nonlocally from poutine execution.
Used by poutine.EscapeMessenger to return site information.
- am_i_wrapped() bool [source]¶
Checks whether the current computation is wrapped in a poutine.
- Returns
bool
- apply_stack(initial_msg: pyro.poutine.runtime.Message) None [source]¶
Execute the effect stack at a single site according to the following scheme:
1. For each Messenger in the stack from bottom to top, execute Messenger._process_message with the message; if the message field “stop” is True, stop; otherwise, continue
2. Apply default behavior (default_process_message) to finish remaining site execution
3. For each Messenger in the stack from top to bottom, execute _postprocess_message to update the message and internal messenger state with the site results
4. If the message field “continuation” is not None, call it with the message
- Parameters
initial_msg (dict) – the starting version of the trace site
- Returns
None
- default_process_message(msg: pyro.poutine.runtime.Message) None [source]¶
Default method for processing messages in inference.
- Parameters
msg – a message to be processed
- Returns
None
- effectful(fn: None = None, type: Optional[str] = None) Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]] [source]¶
- effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) Callable[[...], pyro.poutine.runtime._T]
- Parameters
fn – function or callable that performs an effectful computation
type (str) – the type label of the operation, e.g. “sample”
Wrapper for calling apply_stack() to apply any active effects.
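A hedged sketch of how a new primitive can be made interceptable by the effect stack; the operation name my_op, its default behavior, and the site name are hypothetical:
from pyro.poutine.runtime import effectful

def _my_op(x):
    # Default behavior when no handler modifies the message.
    return x

# Handlers now see messages of type "my_op"; the reserved "name" keyword labels the site.
my_op = effectful(_my_op, type="my_op")
value = my_op(1.0, name="my_site")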
- get_mask() Optional[Union[bool, torch.Tensor]] [source]¶
Records the effects of enclosing poutine.mask handlers.
This is useful for avoiding expensive pyro.factor() computations during prediction, when the log density need not be computed, e.g.:
def model():
    # ...
    if poutine.get_mask() is not False:
        log_density = my_expensive_computation()
        pyro.factor("foo", log_density)
    # ...
- Returns
The mask.
- Return type
None, bool, or torch.Tensor
- get_plates() Tuple[CondIndepStackFrame, ...] [source]¶
Records the effects of enclosing pyro.plate contexts.
- Returns
A tuple of pyro.poutine.indep_messenger.CondIndepStackFrame objects.
- Return type
tuple
Utilities¶
- all_escape(trace: Trace, msg: Message) bool [source]¶
- Parameters
trace – a partial trace
msg – the message at a Pyro primitive site
- Returns
boolean decision value
Utility function that checks if a site is not already in a trace.
Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for approximately integrating out variables for variance reduction.
- discrete_escape(trace: Trace, msg: Message) bool [source]¶
- Parameters
trace – a partial trace
msg – the message at a Pyro primitive site
- Returns
boolean decision value
Utility function that checks if a sample site is discrete and not already in a trace.
Used by EscapeMessenger to decide whether to do a nonlocal exit at a site. Subroutine for integrating out discrete variables for variance reduction.
- enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace] [source]¶
- Parameters
trace – a partial trace
msg – the message at a Pyro primitive site
num_samples – maximum number of extended traces to return.
- Returns
a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site whose values are enumerated from the support of the input site’s distribution.
Used for exact inference and integrating out discrete variables.
- mc_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace] [source]¶
- Parameters
trace – a partial trace
msg – the message at a Pyro primitive site
num_samples – maximum number of extended traces to return.
- Returns
a list of traces, copies of input trace with one extra site
Utility function to copy and extend a trace with sites based on the input site whose values are sampled from the input site’s function.
Used for Monte Carlo marginalization of individual sample sites.
- prune_subsample_sites(trace: Trace) Trace [source]¶
Copies and removes all subsample sites from a trace.
Messengers¶
Messenger objects contain the implementations of the effects exposed by handlers. Advanced users may modify the implementations of messengers behind existing handlers or write new messengers that implement new effects and compose correctly with the rest of the library.
Messenger¶
- class Messenger[source]¶
Bases:
object
Context manager class that modifies behavior and adds side effects to stochastic functions i.e. callables containing Pyro primitive statements.
This is the base Messenger class. It implements the default behavior for all Pyro primitives, so that the joint distribution induced by a stochastic function fn is identical to the joint distribution induced by Messenger()(fn).
Class of transformers for messages passed during inference. Most inference operations are implemented in subclasses of this.
- classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) Callable [source]¶
- Parameters
fn – function implementing operation
type (str) – name of the operation (also passed to effectful())
post (bool) – if True, use this operation as postprocess
Dynamically add operations to an effect. Useful for generating wrappers for libraries.
Example:
@SomeMessengerClass.register
def some_function(msg):
    ...do_something...
    return msg
- classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) Optional[Callable] [source]¶
- Parameters
fn – function implementing operation
type (str) – name of the operation (also passed to effectful())
Dynamically remove operations from an effect. Useful for removing wrappers from libraries.
Example:
SomeMessengerClass.unregister(some_function, "name")
- block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) Iterator[List[pyro.poutine.messenger.Messenger]] [source]¶
EXPERIMENTAL Context manager to temporarily remove matching messengers from the _PYRO_STACK. Note this does not call the .__exit__() and .__enter__() methods.
This is useful to selectively block enclosing handlers.
- Parameters
predicate (callable) – A predicate mapping messenger instance to boolean. This mutes all messengers m for which bool(predicate(m)) is True.
- Yields
A list of matched messengers that are blocked.
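A small sketch of the intended use: temporarily muting a particular kind of enclosing handler (here a ScaleMessenger) for part of a computation; the model and site names are illustrative:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.poutine.messenger import block_messengers
from pyro.poutine.scale_messenger import ScaleMessenger

def model():
    pyro.sample("a", dist.Normal(0., 1.))
    # Mute any enclosing ScaleMessenger while sampling "b".
    with block_messengers(lambda m: isinstance(m, ScaleMessenger)):
        pyro.sample("b", dist.Normal(0., 1.))

with poutine.scale(scale=0.5):
    tr = poutine.trace(model).get_trace()
# "a" is rescaled; "b" is expected to keep scale 1.0.
print(tr.nodes["a"]["scale"], tr.nodes["b"]["scale"])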
BlockMessenger¶
- class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)[source]¶
Bases:
pyro.poutine.messenger.Messenger
This handler selectively hides Pyro primitive sites from the outside world. Default behavior: block everything.
A site is hidden if at least one of the following holds:
1. hide_fn(msg) is True or (not expose_fn(msg)) is True
2. msg["name"] in hide
3. msg["type"] in hide_types
4. msg["name"] not in expose and msg["type"] not in expose_types
5. hide, hide_types, and expose_types are all None
For example, suppose the stochastic function fn has two sample sites "a" and "b". Then any effect outside of BlockMessenger(fn, hide=["a"]) will not be applied to site "a" and will only see site "b":
>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
hide_fn – function that takes a site and returns True to hide the site or False/None to expose it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
expose_fn – function that takes a site and returns True to expose the site or False/None to hide it. If specified, all other parameters are ignored. Only specify one of hide_fn or expose_fn, not both.
hide_all (bool) – hide all sites
expose_all (bool) – expose all sites normally
hide (list) – list of site names to hide
expose (list) – list of site names to be exposed while all others hidden
hide_types (list) – list of site types to be hidden
expose_types (list) – list of site types to be exposed while all others hidden
- Returns
stochastic function decorated with a
BlockMessenger
BroadcastMessenger¶
- class BroadcastMessenger[source]¶
Bases:
pyro.poutine.messenger.Messenger
Automatically broadcasts the batch shape of the stochastic function at a sample site when inside a single or nested plate context. The existing batch_shape must be broadcastable with the size of the plate contexts installed in the cond_indep_stack.
Notice how model_automatic_broadcast below automates expanding of distribution batch shapes. This makes it easy to modularize a Pyro model, as the sub-components are agnostic of the wrapping plate contexts.
>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
CollapseMessenger¶
- class CollapseMessenger(*args: Any, **kwargs: Any)[source]¶
Bases:
pyro.poutine.trace_messenger.TraceMessenger
EXPERIMENTAL Collapses all sites in the context by lazily sampling and attempting to use conjugacy relations. If no conjugacy is known this will fail. Code using the results of sample sites must be written to accept Funsors rather than Tensors. This requires funsor to be installed.
Warning
This is not compatible with automatic guessing of max_plate_nesting. If any plates appear within the collapsed context, you should manually declare max_plate_nesting to your inference algorithm (e.g. Trace_ELBO(max_plate_nesting=1)).
ConditionMessenger¶
- class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with some sample statements and a dictionary of observations at names, change the sample statements at those names into observes with those values.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To observe a value for site z, we can write
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
This is equivalent to adding obs=value as a keyword argument to pyro.sample(“z”, …) in model.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
data – a dict or a
Trace
- Returns
stochastic function decorated with a
ConditionMessenger
DoMessenger¶
- class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.
Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
To intervene with a value for site z, we can write
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
This is equivalent to replacing z = pyro.sample(“z”, …) with z = torch.tensor(1.) and introducing a fresh sample site pyro.sample(“z”, …) whose value is not used elsewhere.
References
- [1] Single World Intervention Graphs: A Primer,
Thomas Richardson, James Robins
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
data – a
dict
mapping sample site names to interventions
- Returns
stochastic function decorated with a
DoMessenger
EnumMessenger¶
- class EnumMessenger(first_available_dim: Optional[int] = None)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Enumerates in parallel over discrete sample sites marked infer={"enumerate": "parallel"}.
- Parameters
first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions left may be used internally by Pyro. This should be a negative integer or None.
- enumerate_site(msg: pyro.poutine.runtime.Message) torch.Tensor [source]¶
EqualizeMessenger¶
- class EqualizeMessenger(sites: Union[str, List[str]], type: Optional[str] = 'sample', keep_dist: Optional[bool] = False)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with some primitive statements and a list of names, force the primitive statements at those names to have the same value, with that value being the result of the first primitive statement matching those names.
Consider the following Pyro program:
>>> def per_category_model(category):
...     shift = pyro.param(f'{category}_shift', torch.randn(1))
...     mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
...     std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
...     return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))
Running the program for multiple categories can be done by
>>> def model(categories):
...     return {category: per_category_model(category) for category in categories}
To make the std sample sites have the same value, we can write
>>> equal_std_model = pyro.poutine.equalize(model, '.+_std')
If on top of the above we would like to make the ‘shift’ parameters identical, we can write
>>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
Alternatively, the equalize messenger can be used to condition a model on primitive statements having the same value by setting keep_dist to True. Consider the below model:
>>> def model():
...     x = pyro.sample('x', pyro.distributions.Normal(0, 1))
...     y = pyro.sample('y', pyro.distributions.Normal(5, 3))
...     return x, y
The model can be conditioned on ‘x’ and ‘y’ having the same value by
>>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)
Note that the conditioned model defined above calculates the correct unnormalized log-probability density, but in order to correctly sample from it one must use SVI or MCMC techniques.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
sites – a string or list of strings to match site names (the strings can be regular expressions)
type – a string specifying the site type (default is ‘sample’)
keep_dist (bool) – Whether to keep the distributions of the second and subsequent matching primitive statements. If kept this is equivalent to conditioning the model on all matching primitive statements having the same value, as opposed to having the second and subsequent matching primitive statements replaced by delta sampling functions. Defaults to False.
- Returns
stochastic function decorated with a
EqualizeMessenger
EscapeMessenger¶
- class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Messenger that does a nonlocal exit by raising a util.NonlocalExit exception
IndepMessenger¶
- class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)[source]¶
Bases:
pyro.poutine.messenger.Messenger
This messenger keeps track of the stack of independence information declared by nested plate contexts. This information is stored in a cond_indep_stack at each sample/observe site for consumption by TraceMessenger.
Example:
x_axis = IndepMessenger('outer', 320, dim=-1)
y_axis = IndepMessenger('inner', 200, dim=-2)
with x_axis:
    x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
with y_axis:
    y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
with x_axis, y_axis:
    xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
- property indices: torch.Tensor¶
InferConfigMessenger¶
- class InferConfigMessenger(config_fn: Callable[[Message], InferDict])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a callable fn that contains Pyro primitive calls and a callable config_fn taking a trace site and returning a dictionary, updates the value of the infer kwarg at a sample site to config_fn(site).
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
config_fn – a callable taking a site and returning an infer dict
- Returns
stochastic function decorated with
InferConfigMessenger
LiftMessenger¶
- class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with param calls and a prior distribution, create a stochastic function where all param calls are replaced by sampling from prior. Prior should be a callable or a dict of names to callables.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = pyro.sample("s", dist.Exponential(0.3)):
>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
- Parameters
fn – function whose parameters will be lifted to random values
prior – prior function in the form of a Distribution or a dict of stochastic fns
- Returns
fn
decorated with aLiftMessenger
MarkovMessenger¶
- class MarkovMessenger(history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None)[source]¶
Bases:
pyro.poutine.reentrant_messenger.ReentrantMessenger
Markov dependency declaration.
This is a statistical equivalent of a memory management arena.
- Parameters
history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.
keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).
dim (int) – An optional dimension to use for this independence index. Interface stub, behavior not yet implemented.
name (str) – An optional unique name to help inference algorithms match pyro.markov() sites between models and guides. Interface stub, behavior not yet implemented.
MaskMessenger¶
- class MaskMessenger(mask: Union[bool, torch.BoolTensor])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with some batched sample statements and masking tensor, mask out some of the sample statements elementwise.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
mask (torch.BoolTensor) – a {0,1}-valued masking tensor (1 includes a site, 0 excludes a site)
- Returns
stochastic function decorated with a
MaskMessenger
PlateMessenger¶
- class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
Bases:
pyro.poutine.subsample_messenger.SubsampleMessenger
Swiss army knife of broadcasting amazingness: combines shape inference, independence annotation, and subsampling
- block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) Iterator[None] [source]¶
EXPERIMENTAL Context manager to temporarily block a single enclosing plate.
This is useful for sampling auxiliary variables or lazily sampling global variables that are needed in a plated context. For example the following models are equivalent:
Example:
def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        with block_plate("data"):
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
        pyro.sample("x", dist.Normal(loc, scale))

def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    scale = pyro.sample("scale", dist.LogNormal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(loc, scale))
ReentrantMessenger¶
ReparamMessenger¶
- class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]¶
Bases:
Generic[pyro.poutine.reparam_messenger._P, pyro.poutine.reparam_messenger._T]
Reparameterization poutine.
- class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Reparametrizes each affected sample site into one or more auxiliary sample sites followed by a deterministic transformation [1].
To specify reparameterizers, pass a config dict or callable to the constructor. See the pyro.infer.reparam module for available reparameterizers.
Note: some reparameterizers can examine the *args, **kwargs inputs of functions they affect; these reparameterizers require using poutine.reparam as a decorator rather than as a context manager.
- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), “Automatic Reparameterisation of Probabilistic Programs” https://arxiv.org/pdf/1906.03028.pdf
- Parameters
config (dict or callable) – Configuration, either a dict mapping site name to Reparameterizer, or a function mapping site to Reparam or None. See pyro.infer.reparam.strategies for built-in configuration strategies.
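As a minimal sketch (assuming the LocScaleReparam reparameterizer from pyro.infer.reparam; the model is illustrative only), a config dict maps a site name to a reparameterizer instance:

import pyro
import pyro.distributions as dist
from pyro.infer.reparam import LocScaleReparam

def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", 10):
        # Site "x" will be rewritten into an auxiliary standardized sample
        # followed by a deterministic affine transform.
        pyro.sample("x", dist.Normal(loc, 1.0))

reparam_model = pyro.poutine.reparam(model, config={"x": LocScaleReparam()})
tr = pyro.poutine.trace(reparam_model).get_trace()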
ReplayMessenger¶
- class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a callable that contains Pyro primitive calls, return a callable that runs the original, reusing the values recorded in trace at the corresponding sites in the new trace.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
replay makes sample statements behave as if they had sampled the values at the corresponding sites in the trace:
>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
trace – a Trace data structure to replay against
params – dict of names of param sites and constrained values in fn to replay against
- Returns
a stochastic function decorated with a
ReplayMessenger
ScaleMessenger¶
- class ScaleMessenger(scale: Union[float, torch.Tensor])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with some sample statements and a positive scale factor, scale the score of all sample and observe sites in the function.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale multiplicatively scales the log-probabilities of sample sites:
>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
scale – a positive scaling factor
- Returns
stochastic function decorated with a
ScaleMessenger
SeedMessenger¶
- class SeedMessenger(rng_seed: int)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Handler to set the random number generator to a pre-defined state by setting its seed. This is the same as calling pyro.set_rng_seed() before the call to fn. This handler has no additional effect on primitive statements on the standard Pyro backend, but it might intercept pyro.sample calls in other backends, e.g. the NumPy backend.
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls).
rng_seed (int) – rng seed.
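A minimal sketch (not part of the reference) of wrapping a model with poutine.seed so that repeated runs produce identical draws:

import pyro
import pyro.distributions as dist

def model():
    return pyro.sample("z", dist.Normal(0.0, 1.0))

# The RNG is re-seeded before each call, so both draws are identical.
seeded_model = pyro.poutine.seed(model, rng_seed=101)
assert bool(seeded_model() == seeded_model())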
SubsampleMessenger¶
- class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
Bases:
pyro.poutine.indep_messenger.IndepMessenger
Extension of IndepMessenger that includes subsampling.
SubstituteMessenger¶
- class SubstituteMessenger(data: Dict[str, torch.Tensor])[source]¶
Bases:
pyro.poutine.messenger.Messenger
Given a stochastic function with param calls and a set of parameter values, create a stochastic function where all param calls are substituted with the fixed values. data should be a dict of names to values. Consider the following Pyro program:
>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
In this example, site a will now have value torch.tensor(0.3).
- Parameters
data – dictionary of values keyed by site names
- Returns
fn decorated with a SubstituteMessenger
TraceMessenger¶
- class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.trace_messenger._P], pyro.poutine.trace_messenger._T])[source]¶
Bases:
Generic[pyro.poutine.trace_messenger._P, pyro.poutine.trace_messenger._T]
Execution trace poutine.
A TraceHandler records the input and output to every Pyro primitive and stores them as a site in a Trace(). This should, in theory, be sufficient information for every inference algorithm (along with the implicit computational graph in the Variables).
We can also use this for visualization.
- get_trace(*args, **kwargs) pyro.poutine.trace_struct.Trace [source]¶
- Returns
data structure
- Return type
pyro.poutine.trace_struct.Trace
Helper method for a very common use case. Calls this poutine and returns its trace instead of the function’s return value.
- property trace: pyro.poutine.trace_struct.Trace¶
- class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)[source]¶
Bases:
pyro.poutine.messenger.Messenger
Return a handler that records the inputs and outputs of primitive calls and their dependencies.
Consider the following Pyro program:
>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
We can record its execution using trace and use the resulting data structure to compute the log-joint probability of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- Parameters
fn – a stochastic function (callable containing Pyro primitive calls)
graph_type – string that specifies the kind of graph to construct
param_only – if true, only records params and not samples
- Returns
stochastic function decorated with a
TraceMessenger
- get_trace() pyro.poutine.trace_struct.Trace [source]¶
- Returns
data structure
- Return type
pyro.poutine.trace_struct.Trace
Helper method for a very common use case. Returns a shallow copy of self.trace.
- identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) None [source]¶
Modifies a trace in-place by adding all edges based on the cond_indep_stack information stored at each site.
UnconditionMessenger¶
- class UnconditionMessenger[source]¶
Bases:
pyro.poutine.messenger.Messenger
Messenger to force the value of observed nodes to be sampled from their distribution, ignoring observations.
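As a short sketch (not part of the reference), poutine.uncondition turns an observed site back into a latent one, so its value is drawn from the prior rather than fixed to the observation:

import torch
import pyro
import pyro.distributions as dist

def model():
    return pyro.sample("x", dist.Normal(0.0, 1.0), obs=torch.tensor(0.5))

# The obs keyword is ignored, so "x" is resampled from Normal(0, 1).
unconditioned_model = pyro.poutine.uncondition(model)
tr = pyro.poutine.trace(unconditioned_model).get_trace()
assert not tr.nodes["x"]["is_observed"]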
GuideMessenger¶
- class GuideMessenger(model: Callable)[source]¶
Bases:
pyro.poutine.trace_messenger.TraceMessenger, abc.ABC
Abstract base class for effect-based guides.
Derived classes must implement the get_posterior() method.
- property model: Callable¶
- __call__(*args, **kwargs) Dict[str, torch.Tensor] [source]¶
Draws posterior samples from the guide and replays the model against those samples.
- Returns
A dict mapping sample site name to sample value. This includes latent, deterministic, and observed values.
- Return type
Dict[str, torch.Tensor]
- abstract get_posterior(name: str, prior: TorchDistributionMixin) Union[TorchDistributionMixin, torch.Tensor] [source]¶
Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution conditioned on upstream posterior samples.
Implementations may use pyro.param and pyro.sample inside this function, but pyro.sample statements should set infer={"is_auxiliary": True}.
Implementations may access further information for computations (a minimal sketch of a derived guide appears at the end of this section):
value = self.upstream_value(name) is the value of an upstream sample or deterministic site.
self.trace is a trace of upstream sites, and may be useful for other information such as self.trace.nodes["my_site"]["fn"] or self.trace.nodes["my_site"]["cond_indep_stack"].
args, kwargs = self.args_kwargs are the inputs to the model, and may be useful for amortization.
- Parameters
name (str) – The name of the sample site to sample.
prior (Distribution) – The prior distribution of this sample site (conditioned on upstream samples from the posterior).
- Returns
A posterior distribution or sample from the posterior distribution.
- Return type
Union[TorchDistributionMixin, torch.Tensor]
- upstream_value(name: str) Optional[torch.Tensor] [source]¶
For use in get_posterior().
- Returns
The value of an upstream sample or deterministic site
- Return type
Optional[torch.Tensor]
- get_traces() Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace] [source]¶
This can be called after running __call__() to extract a pair of traces.
In contrast to the trace-replay pattern of generating a pair of traces, GuideMessenger interleaves model and guide computations, so only a single guide(*args, **kwargs) call is needed to create both traces. This function merely extracts the relevant information from this guide’s .trace attribute.
- Returns
a pair (model_trace, guide_trace)
- Return type
Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace]
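To make get_posterior() concrete, here is a minimal mean-field sketch of a derived guide. Assumptions: GuideMessenger is importable from pyro.poutine.guide, and every latent site is approximated by an independent learnable Normal; this is illustrative, not a built-in guide:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.poutine.guide import GuideMessenger

class MeanFieldNormalGuide(GuideMessenger):
    """Hypothetical guide: one learnable Normal posterior per latent site."""

    def get_posterior(self, name, prior):
        # One loc/scale pair of learnable parameters per site, shaped like the prior.
        loc = pyro.param(f"{name}_loc", lambda: torch.zeros(prior.shape()))
        scale = pyro.param(f"{name}_scale", lambda: torch.ones(prior.shape()),
                           constraint=constraints.positive)
        return dist.Normal(loc, scale).to_event(prior.event_dim)

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

guide = MeanFieldNormalGuide(model)
samples = guide(torch.randn(5))          # dict of site name -> sampled value
model_trace, guide_trace = guide.get_traces()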