# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import sys
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
import opt_einsum
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import scale_and_mask
from pyro.ops.packed import pack
from pyro.poutine.util import is_validation_enabled
from pyro.util import warn_if_inf, warn_if_nan
if TYPE_CHECKING:
import torch
from pyro.distributions.distribution import Distribution
from pyro.poutine.runtime import Message
def allow_all_sites(name: str, site: "Message") -> bool:
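    """
    Default ``site_filter`` for :meth:`Trace.log_prob_sum` and
    :meth:`Trace.compute_log_prob`; accepts every site.
    """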
return True
class Trace:
"""
    Graph data structure denoting the relationships among different Pyro primitives
    in the execution trace.
An execution trace of a Pyro program is a record of every call
to ``pyro.sample()`` and ``pyro.param()`` in a single execution of that program.
Traces are directed graphs whose nodes represent primitive calls or input/output,
and whose edges represent conditional dependence relationships
between those primitive calls. They are created and populated by ``poutine.trace``.
Each node (or site) in a trace contains the name, input and output value of the site,
as well as additional metadata added by inference algorithms or user annotation.
In the case of ``pyro.sample``, the trace also includes the stochastic function
at the site, and any observed data added by users.
Consider the following Pyro program:
>>> def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
We can record its execution using ``pyro.poutine.trace``
and use the resulting data structure to compute the log-joint probability
of all of the sample sites in the execution or extract all parameters.
>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
We can also inspect or manipulate individual nodes in the trace.
``trace.nodes`` contains a ``collections.OrderedDict``
of site names and metadata corresponding to ``x``, ``s``, ``z``, and the return value:
>>> list(name for name in trace.nodes.keys()) # doctest: +SKIP
["_INPUT", "s", "z", "_RETURN"]
Values of ``trace.nodes`` are dictionaries of node metadata:
>>> trace.nodes["z"] # doctest: +SKIP
{'type': 'sample', 'name': 'z', 'is_observed': False,
'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {},
'infer': {}, 'scale': 1.0, 'cond_indep_stack': (),
'done': True, 'stop': False, 'continuation': None}
``'infer'`` is a dictionary of user- or algorithm-specified metadata.
``'args'`` and ``'kwargs'`` are the arguments passed via ``pyro.sample``
to ``fn.__call__`` or ``fn.log_prob``.
``'scale'`` is used to scale the log-probability of the site when computing the log-joint.
``'cond_indep_stack'`` contains data structures corresponding to ``pyro.plate`` contexts
appearing in the execution.
``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals.
    :param str graph_type: the kind of trace graph to construct ("flat" or "dense")
"""
def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> None:
assert graph_type in ("flat", "dense"), "{} not a valid graph type".format(
graph_type
)
self.graph_type = graph_type
self.nodes: OrderedDict[str, "Message"] = OrderedDict()
self._succ: OrderedDict[str, Set[str]] = OrderedDict()
self._pred: OrderedDict[str, Set[str]] = OrderedDict()
def __contains__(self, name: str) -> bool:
return name in self.nodes
def __iter__(self) -> Iterable[str]:
return iter(self.nodes.keys())
def __len__(self) -> int:
return len(self.nodes)
@property
def edges(self) -> Iterable[Tuple[str, str]]:
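        """
        Returns an iterator over ``(site, successor)`` edges in the trace graph.
        """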
for site, adj_nodes in self._succ.items():
for adj_node in adj_nodes:
yield site, adj_node
    def add_node(self, site_name: str, **kwargs: Any) -> None:
"""
        Adds a site to the trace.
        Raises an error when attempting to add a duplicate node
        instead of silently overwriting.
        :param str site_name: the name of the site to be added
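        For example, adding two sample sites with the same name raises an
        error (a minimal sketch; ``tr`` is a hypothetical fresh trace):
        >>> tr = Trace()
        >>> tr.add_node("z", type="sample")
        >>> tr.add_node("z", type="sample")  # doctest: +SKIP
        Traceback (most recent call last):
            ...
        RuntimeError: Multiple sample sites named 'z'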
"""
if site_name in self:
site = self.nodes[site_name]
if site["type"] != kwargs["type"]:
# Cannot sample or observe after a param statement.
raise RuntimeError(
"{} is already in the trace as a {}".format(site_name, site["type"])
)
elif kwargs["type"] != "param":
# Cannot sample after a previous sample statement.
raise RuntimeError(
"Multiple {} sites named '{}'".format(kwargs["type"], site_name)
)
        # XXX should copy in case site gets mutated, or don't bother?
self.nodes[site_name] = kwargs # type: ignore[assignment]
self._pred[site_name] = set()
self._succ[site_name] = set()
    def add_edge(self, site1: str, site2: str) -> None:
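        """
        Adds a directed edge from ``site1`` to ``site2`` in the trace graph,
        first adding either site to the trace if it is not already present.
        """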
for site in (site1, site2):
if site not in self.nodes:
self.add_node(site)
self._succ[site1].add(site2)
self._pred[site2].add(site1)
    def remove_node(self, site_name: str) -> None:
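        """
        Removes ``site_name`` and all of its incident edges from the trace graph.
        """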
self.nodes.pop(site_name)
for p in self._pred[site_name]:
self._succ[p].remove(site_name)
for s in self._succ[site_name]:
self._pred[s].remove(site_name)
self._pred.pop(site_name)
self._succ.pop(site_name)
    def predecessors(self, site_name: str) -> Set[str]:
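        """
        Returns the set of direct predecessors of ``site_name`` in the trace graph.
        """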
return self._pred[site_name]
    def successors(self, site_name: str) -> Set[str]:
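        """
        Returns the set of direct successors of ``site_name`` in the trace graph.
        """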
return self._succ[site_name]
    def copy(self) -> "Trace":
"""
Makes a shallow copy of self with nodes and edges preserved.
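        The copy shares the underlying site ``Message`` dicts: mutating a
        site's metadata affects both traces (a sketch, assuming a populated
        trace ``tr`` with a site named ``"z"``):
        >>> tr2 = tr.copy()  # doctest: +SKIP
        >>> tr2.nodes["z"] is tr.nodes["z"]  # doctest: +SKIP
        True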
"""
new_tr = Trace(graph_type=self.graph_type)
new_tr.nodes.update(self.nodes)
new_tr._succ.update(self._succ)
new_tr._pred.update(self._pred)
return new_tr
    def _dfs(self, site: str, visited: Set[str]) -> Iterable[str]:
        # Yield sites reachable from ``site`` in depth-first post-order,
        # skipping sites already in ``visited``.
        if site in visited:
            return
        for s in self._succ[site]:
            yield from self._dfs(s, visited)
        visited.add(site)
        yield site
    def topological_sort(self, reverse: bool = False) -> List[str]:
"""
Return a list of nodes (site names) in topologically sorted order.
        :param bool reverse: Return the list in reverse (successors-first) order.
:return: list of topologically sorted nodes (site names).
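        For example (a minimal sketch with hypothetical site names):
        >>> tr = Trace()
        >>> tr.add_edge("a", "b")
        >>> tr.add_edge("b", "c")
        >>> tr.topological_sort()  # doctest: +SKIP
        ['a', 'b', 'c']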
"""
visited: Set[str] = set()
top_sorted = []
for s in self._succ:
for node in self._dfs(s, visited):
top_sorted.append(node)
return top_sorted if reverse else list(reversed(top_sorted))
    def log_prob_sum(
self,
site_filter: Callable[[str, "Message"], bool] = allow_all_sites,
) -> Union["torch.Tensor", float]:
"""
Compute the site-wise log probabilities of the trace.
Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
Each ``log_prob_sum`` is a scalar.
The computation of ``log_prob_sum`` is memoized.
:returns: total log probability.
:rtype: torch.Tensor
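        For example, the sum can be restricted to observed sites via
        ``site_filter`` (a sketch, assuming a populated ``trace`` as in the
        class docstring):
        >>> obs_logp = trace.log_prob_sum(
        ...     lambda name, site: site["is_observed"]
        ... )  # doctest: +SKIP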
"""
result = 0.0
for name, site in self.nodes.items():
if site["type"] == "sample" and site_filter(name, site):
if TYPE_CHECKING:
assert isinstance(site["fn"], Distribution)
if "log_prob_sum" in site:
log_p = site["log_prob_sum"]
else:
try:
log_p = site["fn"].log_prob(
site["value"], *site["args"], **site["kwargs"]
)
except ValueError as e:
_, exc_value, traceback = sys.exc_info()
shapes = self.format_shapes(last_site=site["name"])
raise ValueError(
"Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format(
name, exc_value, shapes
)
).with_traceback(traceback) from e
log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
site["log_prob_sum"] = log_p
if is_validation_enabled():
warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
warn_if_inf(
log_p,
"log_prob_sum at site '{}'".format(name),
allow_neginf=True,
)
result = result + log_p # type: ignore[assignment]
return result
    def compute_log_prob(
self,
site_filter: Callable[[str, "Message"], bool] = allow_all_sites,
) -> None:
"""
Compute the site-wise log probabilities of the trace.
Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
Each ``log_prob_sum`` is a scalar.
Both computations are memoized.
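        After calling this method, each sample site carries ``"unscaled_log_prob"``,
        ``"log_prob"``, and ``"log_prob_sum"`` entries (a sketch, assuming a
        populated ``trace`` with a scalar sample site ``"z"``):
        >>> trace.compute_log_prob()  # doctest: +SKIP
        >>> trace.nodes["z"]["log_prob"].shape  # doctest: +SKIP
        torch.Size([])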
"""
for name, site in self.nodes.items():
if site["type"] == "sample" and site_filter(name, site):
if TYPE_CHECKING:
assert isinstance(site["fn"], Distribution)
if "log_prob" not in site:
try:
log_p = site["fn"].log_prob(
site["value"], *site["args"], **site["kwargs"]
)
except ValueError as e:
_, exc_value, traceback = sys.exc_info()
shapes = self.format_shapes(last_site=site["name"])
raise ValueError(
"Error while computing log_prob at site '{}':\n{}\n{}".format(
name, exc_value, shapes
)
).with_traceback(traceback) from e
site["unscaled_log_prob"] = log_p
log_p = scale_and_mask(log_p, site["scale"], site["mask"])
site["log_prob"] = log_p
site["log_prob_sum"] = log_p.sum()
if is_validation_enabled():
warn_if_nan(
site["log_prob_sum"],
"log_prob_sum at site '{}'".format(name),
)
warn_if_inf(
site["log_prob_sum"],
"log_prob_sum at site '{}'".format(name),
allow_neginf=True,
)
    def compute_score_parts(self) -> None:
"""
Compute the batched local score parts at each site of the trace.
Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
Each ``log_prob_sum`` is a scalar.
All computations are memoized.
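        After calling this method, each sample site carries a ``"score_parts"``
        entry holding a :class:`~pyro.distributions.score_parts.ScoreParts`
        triple (a sketch, assuming a populated ``trace`` with a sample site
        ``"z"``):
        >>> trace.compute_score_parts()  # doctest: +SKIP
        >>> log_prob, score_function, entropy_term = trace.nodes["z"]["score_parts"]  # doctest: +SKIP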
"""
for name, site in self.nodes.items():
if site["type"] == "sample" and "score_parts" not in site:
if TYPE_CHECKING:
assert isinstance(site["fn"], Distribution)
# Note that ScoreParts overloads the multiplication operator
# to correctly scale each of its three parts.
try:
value = site["fn"].score_parts(
site["value"], *site["args"], **site["kwargs"]
)
except ValueError as e:
_, exc_value, traceback = sys.exc_info()
shapes = self.format_shapes(last_site=site["name"])
raise ValueError(
"Error while computing score_parts at site '{}':\n{}\n{}".format(
name, exc_value, shapes
)
).with_traceback(traceback) from e
site["unscaled_log_prob"] = value.log_prob
value = value.scale_and_mask(site["scale"], site["mask"])
site["score_parts"] = value
site["log_prob"] = value.log_prob
site["log_prob_sum"] = value.log_prob.sum()
if is_validation_enabled():
warn_if_nan(
site["log_prob_sum"], "log_prob_sum at site '{}'".format(name)
)
warn_if_inf(
site["log_prob_sum"],
"log_prob_sum at site '{}'".format(name),
allow_neginf=True,
)
    def detach_(self) -> None:
"""
Detach values (in-place) at each sample site of the trace.
"""
for _, site in self.nodes.items():
if site["type"] == "sample":
assert site["value"] is not None
site["value"] = site["value"].detach()
@property
def observation_nodes(self) -> List[str]:
"""
:return: a list of names of observe sites
"""
return [
name
for name, node in self.nodes.items()
if node["type"] == "sample" and node["is_observed"]
]
@property
def param_nodes(self) -> List[str]:
"""
:return: a list of names of param sites
"""
return [name for name, node in self.nodes.items() if node["type"] == "param"]
@property
def stochastic_nodes(self) -> List[str]:
"""
:return: a list of names of sample sites
"""
return [
name
for name, node in self.nodes.items()
if node["type"] == "sample" and not node["is_observed"]
]
@property
def reparameterized_nodes(self) -> List[str]:
"""
:return: a list of names of sample sites whose stochastic functions
are reparameterizable primitive distributions
"""
return [
name
for name, node in self.nodes.items()
if node["type"] == "sample"
and not node["is_observed"]
and getattr(node["fn"], "has_rsample", False)
]
@property
def nonreparam_stochastic_nodes(self) -> List[str]:
"""
:return: a list of names of sample sites whose stochastic functions
are not reparameterizable primitive distributions
"""
return list(set(self.stochastic_nodes) - set(self.reparameterized_nodes))
    def iter_stochastic_nodes(self) -> Iterator[Tuple[str, "Message"]]:
"""
        :return: an iterator over ``(name, node)`` pairs for the non-observed sample sites in the trace.
"""
for name, node in self.nodes.items():
if node["type"] == "sample" and not node["is_observed"]:
yield name, node
    def symbolize_dims(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None:
"""
        Assign unique symbols to all tensor dimensions: even ``opt_einsum``
        symbols are allocated to vectorized ``plate`` dimensions and odd
        symbols to enumeration dimensions.
"""
plate_to_symbol = {} if plate_to_symbol is None else plate_to_symbol
symbol_to_dim = {}
for site in self.nodes.values():
if site["type"] != "sample":
continue
# allocate even symbols for plate dims
dim_to_symbol: Dict[int, str] = {}
for frame in site["cond_indep_stack"]:
if frame.vectorized:
assert frame.dim is not None
if frame.name in plate_to_symbol:
symbol = plate_to_symbol[frame.name]
else:
symbol = opt_einsum.get_symbol(2 * len(plate_to_symbol))
plate_to_symbol[frame.name] = symbol
symbol_to_dim[symbol] = frame.dim
dim_to_symbol[frame.dim] = symbol
# allocate odd symbols for enum dims
assert site["infer"] is not None
for dim, id_ in site["infer"].get("_dim_to_id", {}).items():
symbol = opt_einsum.get_symbol(1 + 2 * id_)
symbol_to_dim[symbol] = dim
dim_to_symbol[dim] = symbol
enum_dim = site["infer"].get("_enumerate_dim")
if enum_dim is not None:
site["infer"]["_enumerate_symbol"] = dim_to_symbol[enum_dim]
site["infer"]["_dim_to_symbol"] = dim_to_symbol
self.plate_to_symbol = plate_to_symbol
self.symbol_to_dim = symbol_to_dim
    def pack_tensors(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None:
"""
        Computes packed representations of tensors in the trace, storing the
        results in each sample site's ``"packed"`` dict.
This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`.
"""
self.symbolize_dims(plate_to_symbol)
for site in self.nodes.values():
if site["type"] != "sample":
continue
assert site["infer"] is not None
dim_to_symbol = site["infer"]["_dim_to_symbol"]
packed = site.setdefault("packed", {})
try:
packed["mask"] = pack(site["mask"], dim_to_symbol)
if "score_parts" in site:
log_prob, score_function, entropy_term = site["score_parts"]
log_prob = pack(log_prob, dim_to_symbol)
score_function = pack(score_function, dim_to_symbol)
entropy_term = pack(entropy_term, dim_to_symbol)
packed["score_parts"] = ScoreParts(
log_prob, score_function, entropy_term
)
packed["log_prob"] = log_prob
packed["unscaled_log_prob"] = pack(
site["unscaled_log_prob"], dim_to_symbol
)
elif "log_prob" in site:
packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
packed["unscaled_log_prob"] = pack(
site["unscaled_log_prob"], dim_to_symbol
)
except ValueError as e:
_, exc_value, traceback = sys.exc_info()
shapes = self.format_shapes(last_site=site["name"])
raise ValueError(
"Error while packing tensors at site '{}':\n {}\n{}".format(
site["name"], exc_value, shapes
)
).with_traceback(traceback) from e
def _format_table(rows: List[List[Optional[str]]]) -> str:
"""
    Formats a table using ``None`` as the separator between column groups;
    the first two column groups are right-justified and the last is
    left-justified.
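    For example (a sketch; each ``None`` starts a new column group):
    >>> print(_format_table([  # doctest: +SKIP
    ...     ["x", "y", None, "10", None, "left"],
    ...     ["z", None, "7", None, "l"],
    ... ]))
    x y 10 left
      z  7    l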
"""
# compute column widths
column_widths = [0, 0, 0]
for row in rows:
widths = [0, 0, 0]
j = 0
for cell in row:
if cell is None:
j += 1
else:
widths[j] += 1
for j in range(3):
column_widths[j] = max(column_widths[j], widths[j])
# justify columns
justified_rows: List[List[str]] = []
for row in rows:
cols: List[List[str]] = [[], [], []]
j = 0
for cell in row:
if cell is None:
j += 1
else:
cols[j].append(cell)
cols = [
(
[""] * (width - len(col)) + col
if direction == "r"
else col + [""] * (width - len(col))
)
for width, col, direction in zip(column_widths, cols, "rrl")
]
justified_rows.append(sum(cols, []))
# compute cell widths
cell_widths = [0] * len(justified_rows[0])
for justified_row in justified_rows:
for j, cell in enumerate(justified_row):
cell_widths[j] = max(cell_widths[j], len(cell))
# justify cells
return "\n".join(
" ".join(cell.rjust(width) for cell, width in zip(justified_row, cell_widths))
for justified_row in justified_rows
)