Easy Custom Guides

EasyGuide

class EasyGuide(model)

Bases: pyro.nn.module.PyroModule

Base class for “easy guides”, which are more flexible than AutoGuides but easier to write than raw Pyro guides.

Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:

  • group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
  • with self.plate(...): ... should be used instead of pyro.plate.
  • self.map_estimate(...) uses a Delta guide for a single site.

Derived classes may also override the init() method to provide custom initialization for model sites.

Parameters: model (callable) – A Pyro model.
model
guide(*args, **kargs)

Guide implementation, to be overridden by the user.

init(site)

Model initialization method; may be overridden by the user.

This should take a site as input and return a valid sample from that site. The default behavior is to draw a random sample:

return site["fn"]()

For other possible initialization functions see http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide.initialization

forward(*args, **kwargs)

Runs the guide. This is typically used by inference algorithms.

plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)

A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.

group(match='.*')

Select a Group of model sites for joint guidance.

Parameters: match (str) – A regex string matching names of model sample sites.
Returns: A group of model sites.
Return type: Group
map_estimate(name)

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Parameters: name (str) – The name of a model sample site.
Returns: A sampled value.
Return type: torch.Tensor

easy_guide

easy_guide(model)

Convenience decorator to create an EasyGuide. The following are equivalent:

# `model` and `my_guide` stand for your own model and guide logic.
from pyro.contrib.easyguide import EasyGuide, easy_guide

# Version 1. Decorate a function.
@easy_guide(model)
def guide(self, foo, bar):
    return my_guide(foo, bar)

# Version 2. Create and instantiate a subclass of EasyGuide.
class Guide(EasyGuide):
    def guide(self, foo, bar):
        return my_guide(foo, bar)
guide = Guide(model)

Note: @easy_guide wrappers cannot be pickled; to build a guide that can be pickled, subclass EasyGuide instead.

Parameters: model (callable) – A Pyro model.

Group

class Group(guide, sites)

Bases: object

An autoguide helper to match a group of model sites.

Variables:
  • event_shape (torch.Size) – The total flattened concatenated shape of all matching sample sites in the model.
  • prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
Parameters:
  • guide (EasyGuide) – An EasyGuide instance.
  • sites (list) – A list of model sites.
guide
sample(guide_name, fn, infer=None)

Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack it into multiple sample sites for model replay.

Parameters:
  • guide_name (str) – The name of the auxiliary guide site.
  • fn (callable) – A distribution whose event shape is self.event_shape.
  • infer (dict) – Optional inference configuration dict.
Returns: A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
Return type: tuple

map_estimate()

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Returns: A dict mapping model site name to sampled value.
Return type: dict