Easy Custom Guides
EasyGuide
- class EasyGuide(model)
Bases: pyro.nn.module.PyroModule
Base class for “easy guides”, which are more flexible than AutoGuides but are easier to write than raw Pyro guides. Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:
- group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
- with self.plate(...): ... should be used instead of pyro.plate.
- self.map_estimate(...) uses a Delta guide for a single site.
Derived classes may also override the init() method to provide custom initialization for model sites.
- Parameters
model (callable) – A Pyro model.
- property model
- init(site)
Model initialization method; may be overridden by the user.
This should take a site as input and return a valid sample from that site. The default behavior is to draw a random sample:
    return site["fn"]()
For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
- forward(*args, **kwargs)
Runs the guide. This is typically used by inference algorithms.
Note
This method is used internally by Module. Users should instead use __call__().
- plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)
A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.
easy_guide
- easy_guide(model)
Convenience decorator to create an EasyGuide. The following are equivalent:

    # Version 1. Decorate a function.
    @easy_guide(model)
    def guide(self, foo, bar):
        return my_guide(foo, bar)

    # Version 2. Create and instantiate a subclass of EasyGuide.
    class Guide(EasyGuide):
        def guide(self, foo, bar):
            return my_guide(foo, bar)

    guide = Guide(model)
Note
@easy_guide wrappers cannot be pickled; to build a guide that can be pickled, subclass EasyGuide instead.
- Parameters
model (callable) – a Pyro model.
Group
- class Group(guide, sites)
Bases: object
An autoguide helper to match a group of model sites.
- Variables
event_shape (torch.Size) – The total flattened concatenated shape of all matching sample sites in the model.
prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
- Parameters
- property guide
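The event_shape of a Group is the total length of its matched sites once each is flattened and they are concatenated, and sample() slices the single auxiliary blob back into per-site values. The pure-Python sketch below mimics only that bookkeeping with plain lists; it ignores batch dimensions, plates and the transform to each site's constrained space. The helpers flat_event_shape, reshape and unpack and the site shapes are hypothetical, not Pyro functions.

```python
from math import prod


def flat_event_shape(site_shapes):
    """Hypothetical helper: total length of all sites flattened and concatenated."""
    return (sum(prod(shape) for shape in site_shapes.values()),)


def reshape(flat, shape):
    """Hypothetical helper: rebuild a nested list of `shape` from a flat list (row-major)."""
    if not shape:
        return flat[0]  # scalar site
    step = prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unpack(blob, site_shapes):
    """Hypothetical helper: split one flat blob into per-site values, in site order."""
    out, pos = {}, 0
    for name, shape in site_shapes.items():
        size = prod(shape)
        out[name] = reshape(blob[pos:pos + size], shape)
        pos += size
    return out


sites = {"a": (), "b": (2,), "c": (2, 2)}  # hypothetical site event shapes
print(flat_event_shape(sites))  # (7,)
print(unpack(list(range(7)), sites))  # {'a': 0, 'b': [1, 2], 'c': [[3, 4], [5, 6]]}
```

A scalar site contributes one element, so the three sites above occupy 1 + 2 + 4 = 7 slots of the concatenated blob.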
- sample(guide_name, fn, infer=None)
Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack it to multiple sample sites for model replay.
- Parameters
- Returns
A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
- Return type