Funsor-based Pyro

Primitives

clear_param_store()[source]

Clears the global ParamStoreDict.

This is especially useful if you’re working in a REPL. We recommend calling this before each training loop (to avoid leaking parameters from past models), and before each unit test (to avoid leaking parameters across tests).

condition(fn=None, *args, **kwargs)

Convenient wrapper of ConditionMessenger

Given a stochastic function with some sample statements and a dictionary of observations keyed by site name, change the sample statements at those names into observations with those values.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To observe a value for site z, we can write

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

This is equivalent to adding obs=value as a keyword argument to pyro.sample("z", ...) in model.

Parameters
  • fn – a stochastic function (callable containing Pyro primitive calls)

  • data – a dict or a Trace

Returns

stochastic function decorated with a ConditionMessenger

deterministic(name, value, event_dim=None)[source]

Deterministic statement to add a Delta site with name name and value value to the trace. This is useful when we want to record values which are completely determined by their parents. For example:

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)

Note

The site does not affect the model density. This currently converts to a sample() statement, but may change in the future.

Parameters
  • name (str) – Name of the site.

  • value (torch.Tensor) – Value of the site.

  • event_dim (int) – Optional event dimension, defaults to value.ndim.

do(fn=None, *args, **kwargs)

Convenient wrapper of DoMessenger

Given a stochastic function with some sample statements and a dictionary of values at names, set the return values of those sites equal to the values as if they were hard-coded to those values and introduce fresh sample sites with the same names whose values do not propagate.

Composes freely with condition() to represent counterfactual distributions over potential outcomes. See Single World Intervention Graphs [1] for additional details and theory.

Consider the following Pyro program:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

To intervene with a value for site z, we can write

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

This is equivalent to replacing z = pyro.sample("z", ...) with z = torch.tensor(1.) and introducing a fresh sample site pyro.sample("z", ...) whose value is not used elsewhere.

References

[1] Thomas Richardson, James Robins. Single World Intervention Graphs: A Primer.

Parameters
  • fn – a stochastic function (callable containing Pyro primitive calls)

  • data – a dict mapping sample site names to interventions

Returns

stochastic function decorated with a DoMessenger

enable_validation(is_validate=True)[source]

Enable or disable validation checks in Pyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution arguments and support values, detecting incorrect use of ELBO and MCMC. Since some of these checks may be expensive, you may want to disable validation of mature models to speed up inference.

The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Equivalently, the default behavior depends on Python’s global __debug__ value via pyro.enable_validation(__debug__).

Validation is temporarily disabled during jit compilation, for all inference algorithms that support the PyTorch jit. We recommend developing models with non-jitted inference algorithms to ease debugging, then optionally moving to jitted inference once a model is correct.

Parameters

is_validate (bool) – (optional; defaults to True) whether to enable validation checks.
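
The link to Python's optimized mode can be checked directly. This sketch does not call Pyro; it starts two fresh interpreters and reads `__debug__` in each. The helper name `debug_flag` is ours.

```python
import subprocess
import sys


def debug_flag(*interpreter_flags):
    """Return the value of __debug__ in a fresh interpreter started with the given flags."""
    result = subprocess.run(
        [sys.executable, *interpreter_flags, "-c", "print(__debug__)"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip() == "True"


normal_mode = debug_flag()         # plain `python`
optimized_mode = debug_flag("-O")  # `python -O`
```

Under the default described above, validation would be on in the first case and off in the second.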

factor(name, log_factor, *, has_rsample=None)[source]

Factor statement to add an arbitrary log probability factor to a probabilistic model.

Warning

When using factor statements in guides, you’ll need to specify whether the factor statement originated from fully reparametrized sampling (e.g. the Jacobian determinant of a transformation of a reparametrized variable) or from non-reparametrized sampling (e.g. discrete samples). For the fully reparametrized case, set has_rsample=True; for the non-reparametrized case, set has_rsample=False. This is needed only in guides, not in models.

Parameters
  • name (str) – Name of the trivial sample

  • log_factor (torch.Tensor) – A possibly batched log probability factor.

  • has_rsample (bool) – Whether the log_factor arose from a fully reparametrized distribution. Defaults to False when used in models, but must be specified for use in guides.

get_param_store()[source]

Returns the global ParamStoreDict.

markov(fn=None, *args, **kwargs)

Convenient wrapper of MarkovMessenger

Handler for converting to/from funsors consistent with Pyro’s positional batch dimensions.

Parameters
  • history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.

  • keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).

module(name, nn_module, update_module_params=False)[source]

Registers all parameters of a torch.nn.Module with Pyro’s param_store. In conjunction with the ParamStoreDict save() and load() functionality, this allows the user to save and load modules.

Note

Consider instead using PyroModule, a newer alternative to pyro.module() that has better support for jitting, serving in C++, and converting parameters to random variables. For details see the Modules Tutorial.

Parameters
  • name (str) – name of module

  • nn_module (torch.nn.Module) – the module to be registered with Pyro

  • update_module_params – determines whether Parameters in the PyTorch module get overridden with the values found in the ParamStore (if any). Defaults to False

Returns

torch.nn.Module

param(name, init_tensor=None, constraint=Real(), event_dim=None)[source]

Saves the variable as a parameter in the param store. To interact with the param store or write to disk, see Parameters.

Parameters
  • name (str) – name of parameter

  • init_tensor (torch.Tensor or callable) – initial tensor or lazy callable that returns a tensor. For large tensors, it may be cheaper to write e.g. lambda: torch.randn(100000), which will only be evaluated on the initial statement.

  • constraint (torch.distributions.constraints.Constraint) – torch constraint, defaults to constraints.real.

  • event_dim (int) – (optional) number of rightmost dimensions unrelated to batching. Dimensions to the left of this will be considered batch dimensions; if the param statement is inside a subsampled plate, then the corresponding batch dimensions of the parameter will be subsampled accordingly. If unspecified, all dimensions will be considered event dims and no subsampling will be performed.

Returns

A constrained parameter. The underlying unconstrained parameter is accessible via pyro.param(...).unconstrained(), where .unconstrained is a weakref attribute.

Return type

torch.Tensor
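
The lazy-initializer behaviour described for init_tensor can be illustrated with a toy store. This is a pure-Python sketch, not the ParamStoreDict: `_STORE`, `toy_param`, and `toy_clear` are hypothetical names, and a plain list stands in for a large tensor.

```python
# Toy parameter store; every name here is a hypothetical stand-in.
_STORE = {}


def toy_param(name, init=None):
    """Return the stored value for `name`, initializing it on first use only."""
    if name not in _STORE:
        if init is None:
            raise KeyError(f"parameter {name!r} has no initial value")
        # A callable initializer is evaluated lazily, only on this first call.
        _STORE[name] = init() if callable(init) else init
    return _STORE[name]


def toy_clear():
    """Forget every stored parameter, as one would before a new training loop."""
    _STORE.clear()


calls = []


def expensive_init():
    calls.append(1)          # record each evaluation
    return [0.0] * 100_000   # stand-in for a large tensor


first = toy_param("w", expensive_init)
second = toy_param("w", expensive_init)  # initializer is not evaluated again
```

After these two calls `expensive_init` has run once. Calling `toy_clear()` and then `toy_param("w", expensive_init)` would run it a second time, which is why clearing the store between training runs avoids reusing stale parameters.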

random_module(name, nn_module, prior, *args, **kwargs)[source]

Warning

The random_module primitive is deprecated, and will be removed in a future release. Use PyroModule instead to create Bayesian modules from torch.nn.Module instances. See the Bayesian Regression tutorial for an example.

DEPRECATED Places a prior over the parameters of the module nn_module. Returns a distribution (callable) over nn.Modules, which upon calling returns a sampled nn.Module.

Parameters
  • name (str) – name of pyro module

  • nn_module (torch.nn.Module) – the module to be registered with pyro

  • prior – pyro distribution, stochastic function, or python dict with parameter names as keys and respective distributions/stochastic functions as values.

Returns

a callable which returns a sampled module

sample(name, fn, *args, **kwargs)[source]

Calls the stochastic function fn with additional side-effects depending on name and the enclosing context (e.g. an inference algorithm). See Introduction to Pyro for a discussion.

Parameters
  • name – name of sample

  • fn – distribution class or function

  • obs – observed datum (optional; should only be used in context of inference) optionally specified in kwargs

  • obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

  • infer (dict) – Optional dictionary of inference parameters specified in kwargs. See inference documentation for details.

Returns

sample
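
The obs_mask behaviour described above amounts to an elementwise choice between observed and imputed values. The following pure-Python toy shows only that selection; `select_with_mask` and the list-based data are hypothetical stand-ins for tensors, and the sketch does not create real sample sites.

```python
import random


def select_with_mask(obs, obs_mask, draw):
    """Keep obs[i] where obs_mask[i] is True, otherwise use an imputed draw."""
    # Stand-in for the values of the latent "<name>_unobserved" site.
    unobserved = [draw() for _ in obs]
    return [o if m else u for o, m, u in zip(obs, obs_mask, unobserved)]


random.seed(0)
obs = [1.5, 0.0, -0.3]
obs_mask = [True, False, True]
value = select_with_mask(obs, obs_mask, lambda: random.gauss(0.0, 1.0))
```

Entries 0 and 2 of `value` equal the observed data, while entry 1 is a random draw that a guide would have to account for.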

set_rng_seed(rng_seed)[source]

Sets seeds of torch and torch.cuda (if available).

Parameters

rng_seed (int) – The seed value.

subsample(data, event_dim)[source]

Subsampling statement to subsample data tensors based on enclosing plates.

This is typically called on arguments to model() when subsampling is performed automatically by plates by passing either the subsample or subsample_size kwarg. For example, the following are equivalent:

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...

Parameters
  • data (Tensor) – A tensor of batched data.

  • event_dim (int) – The event dimension of the data tensor. Dimensions to the left are considered batch dimensions.

Returns

A subsampled version of data

Return type

Tensor

to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]

to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]

vectorized_markov(fn=None, *args, **kwargs)

Convenient wrapper of VectorizedMarkovMessenger

Construct for Markov chain of variables designed for efficient elimination of Markov dimensions using the parallel-scan algorithm. Whenever permissible, vectorized_markov is interchangeable with markov.

The for loop generates both int and 1-dimensional torch.Tensor indices: (0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size)). int indices are used to initiate the Markov chain and torch.Tensor indices are used to construct vectorized transition probabilities for efficient elimination by the parallel-scan algorithm.

When history==0, vectorized_markov behaves similarly to plate.

After the for loop is run, Markov variables are identified and then the step information is constructed and added to the trace. step informs inference algorithms which variables belong to a Markov chain.

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]))

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

Warning

This is only valid if there is only one Markov dimension per branch.

Parameters
  • name (str) – A unique name of a Markov dimension to help inference algorithms eliminate variables in the Markov chain.

  • size (int) – Length (size) of the Markov chain.

  • dim (int) – An optional dimension to use for this Markov dimension. If specified, dim should be negative, i.e. should index from the right. If not specified, dim is set to the rightmost dim that is left of all enclosing plate contexts.

  • history (int) – Memory (order) of the Markov chain. Also the number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to plate.

Returns

Returns both int and 1-dimensional torch.Tensor indices: (0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size)).
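
The arithmetic behind the documented index tuple can be sketched in plain Python. This toy only reproduces the index ranges, with `range` objects standing in for `torch.arange` tensors; the function name `markov_indices` is hypothetical and the sketch says nothing about how the handler itself iterates.

```python
def markov_indices(size, history=1):
    """Toy version of the documented index tuple, using range() for tensors."""
    # Integer indices that initiate the chain: 0, ..., history - 1.
    ints = list(range(history))
    # history + 1 aligned windows, each of length size - history:
    # range(0, size - history), ..., range(history, size).
    windows = [range(k, size - history + k) for k in range(history + 1)]
    return ints, windows


ints, windows = markov_indices(size=3, history=1)

# Reading the windows column-wise gives (previous, current) index pairs.
transitions = list(zip(*windows))
```

For size=3 and history=1 the windows are `range(0, 2)` and `range(1, 3)`, which matches the `slice(0, 2, None)` and `slice(1, 3, None)` site names in the trace above.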

Effect handlers

enum(fn=None, *args, **kwargs)

Convenient wrapper of EnumMessenger

This version of EnumMessenger uses to_data() to allocate a fresh enumeration dim for each discrete sample site.

markov(fn=None, *args, **kwargs)

Convenient wrapper of MarkovMessenger

Handler for converting to/from funsors consistent with Pyro’s positional batch dimensions.

Parameters
  • history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.

  • keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).

named(fn=None, *args, **kwargs)

Convenient wrapper of NamedMessenger

Base effect handler class for the to_funsor() and to_data() primitives. Any effect handlers that invoke these primitives internally or wrap code that does should inherit from NamedMessenger.

This design ensures that the global name-dim mapping is reset upon handler exit rather than potentially persisting until the entire program terminates.

plate(fn=None, *args, **kwargs)

Convenient wrapper of PlateMessenger

Combines new IndepMessenger implementation with existing pyro.poutine.BroadcastMessenger. Should eventually be a drop-in replacement for pyro.plate.

replay(fn=None, *args, **kwargs)

Convenient wrapper of ReplayMessenger

This version of ReplayMessenger is almost identical to the original version, except that it calls to_data() on the replayed funsor values. This may result in different unpacked shapes, but should produce correct allocations.

trace(fn=None, *args, **kwargs)

Convenient wrapper of TraceMessenger

Setting pack_online=True packs online instead of after the fact, converting all distributions and values to Funsors as soon as they are available.

Setting pack_online=False computes information necessary to do packing after execution. Each sample site is annotated with a dim_to_name dictionary, which can be passed directly to to_funsor().

vectorized_markov(fn=None, *args, **kwargs)

Convenient wrapper of VectorizedMarkovMessenger

Construct for Markov chain of variables designed for efficient elimination of Markov dimensions using the parallel-scan algorithm. Whenever permissible, vectorized_markov is interchangeable with markov.

The for loop generates both int and 1-dimensional torch.Tensor indices: (0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size)). int indices are used to initiate the Markov chain and torch.Tensor indices are used to construct vectorized transition probabilities for efficient elimination by the parallel-scan algorithm.

When history==0, vectorized_markov behaves similarly to plate.

After the for loop is run, Markov variables are identified and then the step information is constructed and added to the trace. step informs inference algorithms which variables belong to a Markov chain.

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]))

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

Warning

This is only valid if there is only one Markov dimension per branch.

Parameters
  • name (str) – A unique name of a Markov dimension to help inference algorithms eliminate variables in the Markov chain.

  • size (int) – Length (size) of the Markov chain.

  • dim (int) – An optional dimension to use for this Markov dimension. If specified, dim should be negative, i.e. should index from the right. If not specified, dim is set to the rightmost dim that is left of all enclosing plate contexts.

  • history (int) – Memory (order) of the Markov chain. Also the number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to plate.

Returns

Returns both int and 1-dimensional torch.Tensor indices: (0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size)).

class NamedMessenger(first_available_dim=None)[source]

Bases: pyro.poutine.reentrant_messenger.ReentrantMessenger

Base effect handler class for the to_funsor() and to_data() primitives. Any effect handlers that invoke these primitives internally or wrap code that does should inherit from NamedMessenger.

This design ensures that the global name-dim mapping is reset upon handler exit rather than potentially persisting until the entire program terminates.

class MarkovMessenger(history=1, keep=False)[source]

Bases: pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

Handler for converting to/from funsors consistent with Pyro’s positional batch dimensions.

Parameters
  • history (int) – The number of previous contexts visible from the current context. Defaults to 1. If zero, this is similar to pyro.plate.

  • keep (bool) – If true, frames are replayable. This is important when branching: if keep=True, neighboring branches at the same level can depend on each other; if keep=False, neighboring branches are independent (conditioned on their shared ancestors).

class GlobalNamedMessenger(first_available_dim=None)[source]

Bases: pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

Base class for any new effect handlers that use the to_funsor() and to_data() primitives to allocate DimType.GLOBAL or DimType.VISIBLE dimensions.

Serves as a manual “scope” for dimensions that should not be recycled by MarkovMessenger: global dimensions will be considered active until the innermost GlobalNamedMessenger under which they were initially allocated exits.

to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]

to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]

class StackFrame(name_to_dim, dim_to_name, history=1, keep=False)[source]

Bases: object

Consistent bidirectional mapping between integer positional dimensions and names. Can be queried like a dictionary (value = frame[key], frame[key] = value).
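
A two-way mapping of this kind can be sketched as follows. The class `BiMap` is a hypothetical illustration, not the StackFrame implementation; it assumes names are strings and dims are integers, and it ignores the history and keep arguments.

```python
class BiMap:
    """Toy two-way mapping between integer dims and string names."""

    def __init__(self):
        self.name_to_dim = {}
        self.dim_to_name = {}

    def __setitem__(self, key, value):
        # Accept either frame[name] = dim or frame[dim] = name.
        name, dim = (key, value) if isinstance(key, str) else (value, key)
        # Drop stale entries so both directions stay consistent.
        old_dim = self.name_to_dim.pop(name, None)
        if old_dim is not None:
            self.dim_to_name.pop(old_dim, None)
        old_name = self.dim_to_name.pop(dim, None)
        if old_name is not None:
            self.name_to_dim.pop(old_name, None)
        self.name_to_dim[name] = dim
        self.dim_to_name[dim] = name

    def __getitem__(self, key):
        # String keys look up a dim; integer keys look up a name.
        table = self.name_to_dim if isinstance(key, str) else self.dim_to_name
        return table[key]


frame = BiMap()
frame["time"] = -1    # name -> dim
frame[-2] = "batch"   # dim -> name
```

Either direction can then be queried with the same indexing syntax, for example `frame[-1]` or `frame["batch"]`.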

class DimType(value)[source]

Bases: enum.Enum

Enumerates the possible types of dimensions to allocate.

LOCAL = 0
GLOBAL = 1
VISIBLE = 2

class DimRequest(value, dim_type)

Bases: tuple

property dim_type

Alias for field number 1

property value

Alias for field number 0
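
The listed fields suggest a namedtuple-style record. The sketch below rebuilds the two types from the signatures on this page alone, so it is an assumption about their shape rather than a copy of the library source. In particular, it defines no default field values.

```python
from collections import namedtuple
from enum import Enum


class DimType(Enum):
    """Local re-definition mirroring the three members listed above."""
    LOCAL = 0
    GLOBAL = 1
    VISIBLE = 2


# Field 0 is `value`, field 1 is `dim_type`, as in the property list above.
DimRequest = namedtuple("DimRequest", ["value", "dim_type"])

request = DimRequest(value=-2, dim_type=DimType.GLOBAL)
value, dim_type = request  # unpacks like any tuple
```

Because the record is a tuple, `request[0]` and `request.value` refer to the same field, as do `request[1]` and `request.dim_type`.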

class DimStack[source]

Bases: object

Single piece of global state to keep track of the mapping between names and dimensions.

Replaces the plate _DimAllocator, the enum _EnumAllocator, the stack in MarkovMessenger, _param_dims and _value_dims in EnumMessenger, and dim_to_symbol in msg['infer'].

MAX_DIM = -25
DEFAULT_FIRST_DIM = -5
set_first_available_dim(dim)[source]
push_global(frame)[source]
pop_global()[source]
push_iter(frame)[source]
pop_iter()[source]
push_local(frame)[source]
pop_local()[source]
property global_frame
property local_frame
property current_write_env
property current_read_env

Collect all frames necessary to compute the full name <–> dim mapping and interpret Funsor inputs or batch shapes at any point in a computation.

allocate(key_to_value_request)[source]
names_from_batch_shape(batch_shape, dim_type=DimType.LOCAL)[source]

Inference algorithms

class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

loss(model, guide, *args, **kwargs)[source]

See pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss()

loss_and_grads(model, guide, *args, **kwargs)[source]

See pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss_and_grads()

class Jit_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

See pyro.infer.trace_elbo.Trace_ELBO.differentiable_loss()

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.trace_elbo.Trace_ELBO

apply_optimizer(x)[source]

terms_from_trace(tr)[source]

Helper function to extract elbo components from execution traces.

class TraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

See pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

See pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceEnum_ELBO

class JitTraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceMarkovEnum_ELBO

class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

See pyro.infer.tracetmc_elbo.TraceTMC_ELBO.differentiable_loss()

class JitTraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=-1.0)[source]

Bases: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.tracetmc_elbo.TraceTMC_ELBO