import opt_einsum
from pyro.util import ignore_jit_warnings
# Module-level memo of contraction expressions built by
# ``contract_expression``; keys combine the equation, the operand shapes and
# the remaining keyword arguments. Entries are never evicted in this module.
_PATH_CACHE = {}
def contract_expression(equation, *shapes, **kwargs):
    """
    Wrapper around :func:`opt_einsum.contract_expression` that optionally
    caches the expression it builds.

    :param str equation: the einsum equation, forwarded unchanged.
    :param shapes: positional arguments forwarded unchanged to
        :func:`opt_einsum.contract_expression` (normally one shape tuple per
        operand).
    :param bool cache_path: whether to cache the contraction path.
        Defaults to True.
    :param kwargs: any other keyword arguments are forwarded unchanged to
        :func:`opt_einsum.contract_expression` and take part in the cache key.
    :return: the expression object produced by
        :func:`opt_einsum.contract_expression`; with caching enabled, repeated
        calls with an equal key return the same object.
    """
    cache_path = kwargs.pop('cache_path', True)

    key = None
    if cache_path:
        # Sort by keyword name so that the order in which keyword arguments
        # were written does not produce distinct cache entries. Names are
        # unique, so the (possibly incomparable) values are never compared.
        key = equation, shapes, tuple(sorted(kwargs.items()))
        try:
            return _PATH_CACHE[key]
        except KeyError:
            # Not cached yet: build the expression below and store it.
            pass
        except TypeError:
            # Some argument is unhashable (e.g. a list passed as
            # ``constants`` or as an explicit ``optimize`` path), so this call
            # cannot be memoized. Build the expression without caching rather
            # than failing.
            key = None

    expr = opt_einsum.contract_expression(equation, *shapes, **kwargs)
    if key is not None:
        _PATH_CACHE[key] = expr
    return expr
def contract(equation, *operands, **kwargs):
    """
    Wrapper around :func:`opt_einsum.contract` that builds its contraction
    expression through :func:`contract_expression`, and so optionally caches
    contraction paths.

    :param str equation: the einsum equation.
    :param operands: the tensors to contract; each must expose a ``.shape``.
    :param str backend: backend name passed to the contraction expression.
        Defaults to ``'numpy'``.
    :param out: optional output argument passed to the contraction
        expression. Defaults to None.
    :param bool cache_path: whether to cache the contraction path.
        Defaults to True.
    :param kwargs: any other keyword arguments are forwarded to
        :func:`contract_expression`.
    :return: the result of evaluating the contraction expression on
        ``operands``.
    """
    backend = kwargs.pop('backend', 'numpy')
    out = kwargs.pop('out', None)
    # Plain tuples make the shapes usable as part of a cache key.
    shapes = [tuple(t.shape) for t in operands]
    with ignore_jit_warnings():
        # Forward the remaining options (notably ``cache_path``) instead of
        # silently discarding them, so the documented flag takes effect.
        expr = contract_expression(equation, *shapes, **kwargs)
    return expr(*operands, backend=backend, out=out)
# Names exported by ``from <this module> import *``.
__all__ = ['contract', 'contract_expression']