Easy Custom Guides
EasyGuide
- class EasyGuide(model)
Bases: pyro.nn.module.PyroModule

Base class for “easy guides”, which are more flexible than AutoGuides but easier to write than raw Pyro guides. Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:

- group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
- with self.plate(...): ... should be used instead of pyro.plate.
- self.map_estimate(...) uses a Delta guide for a single site.

Derived classes may also override the init() method to provide custom initialization for model sites.

- Parameters
model (callable) – A Pyro model.
- property model
- init(site)
Model initialization method; may be overridden by the user.
This should take a site as input and return a valid sample from that site. The default behavior is to draw a random sample:
return site["fn"]()
For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization
- forward(*args, **kwargs)
Runs the guide. This is typically used by inference algorithms.
Note
This method is used internally by Module. Users should instead use __call__().
- plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)
A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.
easy_guide
- easy_guide(model)
Convenience decorator to create an EasyGuide. The following are equivalent:

    # Version 1. Decorate a function.
    @easy_guide(model)
    def guide(self, foo, bar):
        return my_guide(foo, bar)

    # Version 2. Create and instantiate a subclass of EasyGuide.
    class Guide(EasyGuide):
        def guide(self, foo, bar):
            return my_guide(foo, bar)

    guide = Guide(model)
Note
@easy_guide wrappers cannot be pickled; to build a guide that can be pickled, subclass EasyGuide instead.

- Parameters
model (callable) – A Pyro model.
Group
- class Group(guide, sites)
Bases: object

An autoguide helper to match a group of model sites.
- Variables
event_shape (torch.Size) – The total flattened concatenated shape of all matching sample sites in the model.
prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
- Parameters
- property guide
- sample(guide_name, fn, infer=None)
Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack it into multiple sample sites for model replay.

- Parameters
- Returns
A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
- Return type