# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import warnings
import weakref
import torch
import pyro
import pyro.poutine as poutine
from pyro.util import ignore_jit_warnings, optional, timed
def _hash(value, allow_id):
try:
hash(value)
return value
except TypeError as e:
if isinstance(value, list):
return tuple(_hash(x, allow_id) for x in value)
elif isinstance(value, dict):
return tuple(
sorted(
(_hash(x, allow_id), _hash(y, allow_id)) for x, y in value.items()
)
)
elif isinstance(value, set):
return frozenset(_hash(x, allow_id) for x in value)
elif isinstance(value, argparse.Namespace):
return str(value)
elif allow_id:
return id(value)
raise e
def _hashable_args_kwargs(args, kwargs):
items = sorted(kwargs.items())
hashable_kwargs = tuple((key, _hash(value, False)) for key, value in items)
try:
hash(hashable_kwargs)
except TypeError:
warnings.warn("Failed to hash kwargs; attempting to hash by id.")
hashable_kwargs = tuple((key, _hash(value, True)) for key, value in items)
return len(args), hashable_kwargs
class CompiledFunction:
    """
    Output type of :func:`pyro.ops.jit.trace`.

    Wrapper around the output of :func:`torch.jit.trace`
    that handles parameter plumbing.

    The actual PyTorch compilation artifact is stored in :attr:`compiled`.
    Call diagnostic methods on this attribute.

    :param callable fn: The function to be traced.
    :param bool ignore_warnings: Whether to wrap the traced function with
        ``ignore_jit_warnings()``.
    :param dict jit_options: Optional keyword arguments forwarded to
        :func:`torch.jit.trace`. The extra key ``"time_compilation"`` is
        handled here and not forwarded: when true, the elapsed time of the
        most recent compilation is stored in :attr:`compile_time`. The dict
        is copied; the caller's object is never modified.
    """

    def __init__(self, fn, ignore_warnings=False, jit_options=None):
        self.fn = fn
        self.compiled = {}  # hashable call key -> traced callable
        self.ignore_warnings = ignore_warnings
        # Copy the options so that the default added below (and any later
        # bookkeeping) cannot leak into a dict the caller may be sharing
        # between several traced functions.
        self.jit_options = {} if jit_options is None else dict(jit_options)
        self.jit_options.setdefault("check_trace", False)
        self.compile_time = None
        self._param_names = None

    def __call__(self, *args, **kwargs):
        """
        Run the traced function, compiling it first if this call signature
        (number of positional args plus keyword args) has not been seen yet.

        :returns: Whatever the traced function returns.
        :raises NotImplementedError: If executing the traced function records
            a param site whose name was not captured during compilation.
        """
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture: run ``fn`` once, recording only param sites.
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names
            )
            params_and_args = unconstrained_params + args
            # The traced closure reaches ``self`` through a weak reference
            # rather than capturing it directly.
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                # Inputs are laid out as (params..., args...); split them.
                unconstrained_params = params_and_args[: len(self._param_names)]
                args = params_and_args[len(self._param_names) :]
                constrained_params = {}
                for name, unconstrained_param in zip(
                    self._param_names, unconstrained_params
                ):
                    constrained_param = pyro.param(
                        name
                    )  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn, params=constrained_params)(
                    *args, **kwargs
                )

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)

            # Pop the Pyro-only flag from a per-compilation copy so that
            # ``self.jit_options`` stays intact and every compilation (not
            # just the first) honours ``time_compilation``.
            trace_options = dict(self.jit_options)
            time_compilation = trace_options.pop("time_compilation", False)
            with pyro.validation_enabled(False):
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **trace_options
                    )
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                pyro.param(name).unconstrained() for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        # Execute the traced function while recording param sites, so that
        # params that were not present at compilation time can be detected.
        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
                ret = self.compiled[key](*params_and_args)

        for name in param_capture.trace.nodes.keys():
            if name not in self._param_names:
                raise NotImplementedError(
                    "pyro.ops.jit.trace assumes all params are created on "
                    "first invocation, but found new param: {}".format(name)
                )

        return ret
def trace(fn=None, ignore_warnings=False, jit_options=None):
    """
    Lazy replacement for :func:`torch.jit.trace` that works with
    Pyro functions that call :func:`pyro.param`.

    The actual compilation artifact is stored in the ``compiled`` attribute of
    the output. Call diagnostic methods on this attribute.

    Can be used directly (``trace(fn)`` / ``@trace``) or as a decorator
    factory (``@trace(ignore_warnings=True)``).

    Example::

        def model(x):
            scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
            return pyro.sample("y", dist.Normal(x, scale))

        @pyro.ops.jit.trace
        def model_log_prob_fn(x, y):
            cond_model = pyro.condition(model, data={"y": y})
            tr = pyro.poutine.trace(cond_model).get_trace(x)
            return tr.log_prob_sum()

    :param callable fn: The function to be traced.
    :param bool ignore_warnings: Whether to ignore jit warnings.
    :param dict jit_options: Optional dict of options to pass to
        :func:`torch.jit.trace` , e.g. ``{"check_trace": True}``.
    :returns: A ``CompiledFunction`` wrapping ``fn``; or, when ``fn`` is
        ``None``, a decorator that produces one with the given options.
    """
    if fn is None:
        # Called with options only: return a decorator awaiting the function.
        return lambda fn: trace(
            fn, ignore_warnings=ignore_warnings, jit_options=jit_options
        )
    return CompiledFunction(
        fn, ignore_warnings=ignore_warnings, jit_options=jit_options
    )