Easy Custom Guides


class EasyGuide(model)[source]

Bases: object

Base class for “easy guides”.

Derived classes should define a guide() method. This guide() method can combine ordinary guide statements (e.g. pyro.sample and pyro.param) with the following special statements:

  • group = self.group(...) selects multiple pyro.sample sites in the model. See Group for subsequent methods.
  • with self.plate(...): ... should be used instead of pyro.plate.
  • self.map_estimate(...) uses a Delta guide for a single site.

Derived classes may also override the init() method to provide custom initialization for model sites.

Parameters: model (callable) – A Pyro model.
guide(*args, **kargs)[source]

Guide implementation, to be overridden by user.


init(site)[source]

Model initialization method, may be overridden by user.

This should input a site and output a valid sample from that site. The default behavior is to draw a random sample:

return site["fn"]()

For other possible initialization functions see http://docs.pyro.ai/en/stable/contrib.autoguide.html#module-pyro.contrib.autoguide.initialization

__call__(*args, **kwargs)[source]

Runs the guide. This is typically used by inference algorithms.

plate(name, size=None, subsample_size=None, subsample=None, *args, **kwargs)[source]

A wrapper around pyro.plate to allow EasyGuide to automatically construct plates. You should use this rather than pyro.plate inside your guide() implementation.


group(match='.*')[source]

Select a Group of model sites for joint guidance.

Parameters: match (str) – A regex string matching names of model sample sites.
Returns: A group of model sites.
Return type: Group

map_estimate(name)[source]

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Parameters: name (str) – The name of a model sample site.
Returns: A sampled value.
Return type: torch.Tensor



easy_guide(model)[source]

Convenience decorator to create an EasyGuide. The following are equivalent:

# Version 1. Decorate a function.
@easy_guide(model)
def guide(self, foo, bar):
    return my_guide(foo, bar)

# Version 2. Create and instantiate a subclass of EasyGuide.
class Guide(EasyGuide):
    def guide(self, foo, bar):
        return my_guide(foo, bar)
guide = Guide(model)

Parameters: model (callable) – A Pyro model.


class Group(guide, sites)[source]

Bases: object

An autoguide helper to match a group of model sites.

Variables:
  • event_shape (torch.Size) – The total flattened concatenated shape of all matching sample sites in the model.
  • prototype_sites (list) – A list of all matching sample sites in a prototype trace of the model.
Parameters:
  • guide (EasyGuide) – An easyguide instance.
  • sites (list) – A list of model sites.
sample(guide_name, fn, infer=None)[source]

Wrapper around pyro.sample() to create a single auxiliary sample site and then unpack to multiple sample sites for model replay.

Parameters:
  • guide_name (str) – The name of the auxiliary guide site.
  • fn (callable) – A distribution with shape self.event_shape.
  • infer (dict) – Optional inference configuration dict.

Returns: A pair (guide_z, model_zs) where guide_z is the single concatenated blob and model_zs is a dict mapping site name to constrained model sample.
Return type: tuple



map_estimate()[source]

Construct a maximum a posteriori (MAP) guide using Delta distributions.

Returns: A dict mapping model site name to sampled value.
Return type: dict