Parameters

Parameters in Pyro are basically thin wrappers around PyTorch tensors that carry unique names. As such, parameters are the primary stateful objects in Pyro. Users typically interact with parameters via the Pyro primitive pyro.param. Parameters play a central role in stochastic variational inference, where they are used to represent point estimates for the parameters in parameterized families of models and guides.
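For example, registering and retrieving parameters with pyro.param (a minimal sketch; the names and shapes are illustrative):

import torch
import pyro
from torch.distributions import constraints

# The init tensor is only used the first time a given name is registered.
loc = pyro.param("loc", torch.zeros(2))
# Constrained parameters are declared with a torch constraint object.
scale = pyro.param("scale", torch.ones(2), constraint=constraints.positive)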

ParamStore

class StateDict[source]

Bases: typing_extensions.TypedDict

params: Dict[str, torch.Tensor]
constraints: Dict[str, torch.distributions.constraints.Constraint]
class ParamStoreDict[source]

Bases: object

Global store for parameters in Pyro. This is basically a key-value store. The typical user interacts with the ParamStore primarily through the primitive pyro.param.

See Introduction for further discussion and SVI Part I for some examples.

Some things to bear in mind when using parameters in Pyro:

  • parameters must be assigned unique names

  • the init_tensor argument to pyro.param is only used the first time that a given (named) parameter is registered with Pyro.

  • for this reason, a user may need to use the clear() method if working in a REPL in order to get the desired behavior. This method can also be invoked with pyro.clear_param_store().

  • the internal name of a parameter within a PyTorch nn.Module that has been registered with Pyro is prefixed with the Pyro name of the module. So nothing prevents the user from having two different modules, each of which contains a parameter named weight. By contrast, a user can only have one top-level parameter named weight (outside of any module).

  • parameters can be saved and loaded from disk using save and load.

  • in general parameters are associated with both constrained and unconstrained values. For example, under the hood a parameter that is constrained to be positive is represented as an unconstrained tensor in log space (see the sketch after this list).
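As a concrete illustration of the last point, here is a minimal sketch (the parameter name is illustrative) that registers a positive parameter and inspects its unconstrained representation:

import torch
import pyro
from torch.distributions import constraints

pyro.clear_param_store()
rate = pyro.param("rate", torch.tensor(2.0), constraint=constraints.positive)

# named_parameters() yields unconstrained values; for a positive
# constraint the stored tensor lives in log space, so this prints
# approximately ("rate", log(2.0)).
for name, unconstrained in pyro.get_param_store().named_parameters():
    print(name, unconstrained)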

clear() None[source]

Clear the ParamStore

items() Iterator[Tuple[str, torch.Tensor]][source]

Iterate over (name, constrained_param) pairs. Note that constrained_param is in the constrained (i.e. user-facing) space.

keys() KeysView[str][source]

Iterate over param names.

values() Iterator[torch.Tensor][source]

Iterate over constrained parameter values.

setdefault(name: str, init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]], constraint: torch.distributions.constraints.Constraint = Real()) torch.Tensor[source]

Retrieve a constrained parameter value from the ParamStoreDict if it exists, otherwise set the initial value. Note that this is a little fancier than dict.setdefault().

If the parameter already exists, init_constrained_value will be ignored. To avoid expensive creation of the initial value, you can wrap it in a lambda that will only be evaluated if the parameter does not already exist:

param_store.setdefault("foo", lambda: (0.001 * torch.randn(1000, 1000)).exp(),
                       constraint=constraints.positive)
Parameters
  • name (str) – parameter name

  • init_constrained_value (torch.Tensor or callable returning a torch.Tensor) – initial constrained value

  • constraint (Constraint) – torch constraint object

Returns

constrained parameter value

Return type

torch.Tensor

named_parameters() ItemsView[str, torch.Tensor][source]

Returns an iterator over (name, unconstrained_value) tuples for each parameter in the ParamStore. Note that if a parameter is constrained, unconstrained_value is in the unconstrained space implicitly used by the constraint.

get_all_param_names() KeysView[str][source]
replace_param(param_name: str, new_param: torch.Tensor, old_param: torch.Tensor) None[source]
get_param(name: str, init_tensor: Optional[torch.Tensor] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

Get parameter from its name. If it does not yet exist in the ParamStore, it will be created and stored. The Pyro primitive pyro.param dispatches to this method.

Parameters
  • name (str) – parameter name

  • init_tensor (torch.Tensor) – initial tensor

  • constraint (Constraint) – torch constraint object

  • event_dim (int) – (ignored)
Returns

parameter

Return type

torch.Tensor
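A minimal sketch of calling get_param() directly (most users go through pyro.param instead; the name is illustrative):

import torch
import pyro
from torch.distributions import constraints

param_store = pyro.get_param_store()
# Created and stored on first call; later calls ignore init_tensor.
scale = param_store.get_param("scale", torch.ones(3),
                              constraint=constraints.positive)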

match(name: str) Dict[str, torch.Tensor][source]

Get all parameters whose names match the given regular expression. The parameters must already exist.

Parameters

name (str) – regular expression

Returns

dict with key param name and value torch Tensor
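For example (a sketch assuming some parameters whose names end in weight have already been registered):

param_store = pyro.get_param_store()
# Collect every parameter whose name matches the regex.
weights = param_store.match(r".*weight$")
for name, value in weights.items():
    print(name, value.shape)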

param_name(p: torch.Tensor) Optional[str][source]

Get the parameter name from a parameter tensor.

Parameters

p – parameter

Returns

parameter name

get_state() pyro.params.param_store.StateDict[source]

Get the ParamStore state.

set_state(state: pyro.params.param_store.StateDict) None[source]

Set the ParamStore state using state from a previous get_state() call.
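A minimal snapshot-and-restore sketch using get_state() and set_state():

param_store = pyro.get_param_store()
snapshot = param_store.get_state()   # capture params and constraints
# ... mutate the store, e.g. while experimenting ...
param_store.set_state(snapshot)      # roll back to the snapshot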

save(filename: str) None[source]

Save parameters to file

Parameters

filename (str) – file name to save to
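For example, saving the store to a file that can later be restored with load() (the filename matches the load() example below):

pyro.get_param_store().save('saved_params.save')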

load(filename: str, map_location: Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] = None) None[source]

Load parameters from file

Note

If using pyro.module() on parameters loaded from disk, be sure to set the update_module_params flag:

pyro.get_param_store().load('saved_params.save')
pyro.module('module', nn, update_module_params=True)
Parameters
  • filename (str) – file name to load from

  • map_location (function, torch.device, string or a dict) – specifies how to remap storage locations

scope(state: Optional[pyro.params.param_store.StateDict] = None) Iterator[pyro.params.param_store.StateDict][source]

Context manager for using multiple parameter stores within the same process.

This is a thin wrapper around get_state(), clear(), and set_state(). For large models where memory space is limiting, you may want to instead manually use save(), clear(), and load().

Example usage:

param_store = pyro.get_param_store()

# Train multiple models, while avoiding param name conflicts.
with param_store.scope() as scope1:
    # ...train one (model, guide) pair...
with param_store.scope() as scope2:
    # ...train another (model, guide) pair...

# Now evaluate each, still avoiding name conflicts.
with param_store.scope(scope1):  # loads the first model's scope
    # ...evaluate the first model...
with param_store.scope(scope2):  # loads the second model's scope
    # ...evaluate the second model...
param_with_module_name(pyro_name: str, param_name: str) str[source]
module_from_param_with_module_name(param_name: str) str[source]
user_param_name(param_name: str) str[source]
normalize_param_name(name: str) str[source]
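These helpers manipulate the module-prefixed parameter names described above. A rough sketch of their behavior (this assumes Pyro's internal "$$$" module/parameter divider; treat the exact separator as an implementation detail):

full_name = param_with_module_name("mymodule", "weight")  # e.g. "mymodule$$$weight"
module_from_param_with_module_name(full_name)             # -> "mymodule"
user_param_name(full_name)                                # -> "weight"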