diff --git a/probe_py/probe_py/analysis.py b/probe_py/probe_py/analysis.py index d671c0ef..b64e0316 100644 --- a/probe_py/probe_py/analysis.py +++ b/probe_py/probe_py/analysis.py @@ -1,7 +1,6 @@ import collections -from .ptypes import ProbeLog +from .ptypes import ProbeLog, HbGraph, OpQuad from .ops import CloneOp, WaitOp -from .hb_graph import HbGraph, OpNode def get_max_parallelism_latest(hb_graph: HbGraph, probe_log: ProbeLog) -> int: @@ -10,14 +9,14 @@ def get_max_parallelism_latest(hb_graph: HbGraph, probe_log: ProbeLog) -> int: counter = 1 max_counter = 1 start_node = [node for node in hb_graph.nodes() if hb_graph.in_degree(node) == 0][0] - queue = collections.deque[tuple[OpNode, OpNode | None]]([(start_node, None)]) # (current_node, parent_node) + queue = collections.deque[tuple[OpQuad, OpQuad | None]]([(start_node, None)]) # (current_node, parent_node) while queue: node, parent = queue.popleft() if node in visited: continue if parent: - parent_op = probe_log.get_op(*parent.op_quad()).data - node_op = probe_log.get_op(*node.op_quad()).data + parent_op = probe_log.get_op(parent).data + node_op = probe_log.get_op(node).data visited.add(node) diff --git a/probe_py/probe_py/dataflow_graph.py b/probe_py/probe_py/dataflow_graph.py index 552c82a0..abb79aba 100644 --- a/probe_py/probe_py/dataflow_graph.py +++ b/probe_py/probe_py/dataflow_graph.py @@ -1,5 +1,4 @@ from __future__ import annotations -import tqdm import collections import dataclasses import enum @@ -10,8 +9,7 @@ import warnings import networkx from . import graph_utils -from . import hb_graph -from .hb_graph_accesses import hb_graph_to_accesses, Access, AccessMode, Phase +from .hb_graph_accesses import hb_graph_to_accesses from . import ops from . 
import ptypes @@ -19,21 +17,6 @@ _Node = typing.TypeVar("_Node") -@dataclasses.dataclass(frozen=True) -class AccessEpoch[_Node]: - """An access epoch is a set of nodes, denoted by a segment, in which the node may be accessed.""" - mode: AccessMode - bounds: graph_utils.Segment[_Node] - version: int | None = None - - -@dataclasses.dataclass(frozen=True) -class ExecNode: - """An exec, denoted by Pid and ExecNo""" - pid: ptypes.Pid - exec_no: ptypes.ExecNo - - @dataclasses.dataclass(frozen=True) class InodeVersionNode: """A particular version of the inode""" @@ -45,18 +28,16 @@ def __str__(self) -> str: if typing.TYPE_CHECKING: - DataflowGraph: typing.TypeAlias = networkx.DiGraph[hb_graph.OpNode | InodeVersionNode] - CompressedDataflowGraph: typing.TypeAlias = networkx.DiGraph[hb_graph.OpNode | frozenset[InodeVersionNode]] - EpochGraph: typing.TypeAlias = networkx.DiGraph[AccessEpoch[hb_graph.OpNode]] + DataflowGraph: typing.TypeAlias = networkx.DiGraph[ptypes.OpQuint | InodeVersionNode] + CompressedDataflowGraph: typing.TypeAlias = networkx.DiGraph[ptypes.OpQuint | frozenset[InodeVersionNode]] else: DataflowGraph = networkx.DiGraph CompressedDataflowGraph = networkx.DiGraph - EpochGraph = networkx.DiGraph def accesses_to_dataflow_graph( probe_log: ptypes.ProbeLog, - accesses_and_nodes: list[Access | hb_graph.OpNode], + accesses_and_quads: list[ptypes.Access | ptypes.OpQuad], ) -> tuple[DataflowGraph, typing.Mapping[ptypes.Inode, frozenset[pathlib.Path]]]: """Turn a list of accesses into a dataflow graph, by assigning a version at every access.""" @@ -65,84 +46,97 @@ class PidState(enum.IntEnum): WRITING = enum.auto() parent_pid = probe_log.get_parent_pid_map() - pid_to_state = collections.defaultdict[ptypes.Pid, PidState](lambda: PidState.READING) - last_op_in_process = dict[ptypes.Pid, hb_graph.OpNode]() + pid_to_state = dict[ptypes.Pid, PidState]() + last_op_in_process = dict[ptypes.Pid, ptypes.OpQuint]() inode_to_version = collections.defaultdict[ptypes.Inode, 
int](lambda: 0) inode_to_paths = collections.defaultdict[ptypes.Inode, set[pathlib.Path]](set) dataflow_graph = DataflowGraph() - def add_node(node: hb_graph.OpNode) -> None: - pid_to_state[node.pid] = PidState.READING - if program_order_predecessor := last_op_in_process.get(node.pid): - dataflow_graph.add_edge(program_order_predecessor, node) + def add_quad(quad: ptypes.OpQuad, label: str) -> None: + pid_to_state[quad.pid] = PidState.READING + if program_order_predecessor := last_op_in_process.get(quad.pid): + quint = program_order_predecessor.deduplicate(quad) + dataflow_graph.add_edge(program_order_predecessor, quint, label=label + " (from pred)") else: - if parent := parent_pid.get(node.pid): - dataflow_graph.add_edge(last_op_in_process[parent], node) + quint = ptypes.OpQuint.from_quad(quad) + if parent := parent_pid.get(quad.pid): + dataflow_graph.add_edge(last_op_in_process[parent], quint, label=label + " (from parent)") else: pass - # Found initial node of root proc - last_op_in_process[node.pid] = node + # Found initial quad of root proc + last_op_in_process[quad.pid] = quint - def ensure_state(node: hb_graph.OpNode, desired_state: PidState) -> None: - if desired_state == PidState.WRITING and pid_to_state[node.pid] == PidState.READING: + def ensure_state(quad: ptypes.OpQuad, desired_state: PidState) -> ptypes.OpQuint: + if quad.pid not in pid_to_state: + warnings.warn(ptypes.UnusualProbeLog( + f"Encountered {quad}, but there are no nodes on process {quad.pid}.", + )) + add_quad(quad, "init") + if desired_state == PidState.WRITING and pid_to_state[quad.pid] == PidState.READING: # Reading -> writing for free - pid_to_state[node.pid] = PidState.WRITING - elif desired_state == PidState.READING and pid_to_state[node.pid] == PidState.WRITING: - # Writing -> reading by starting a new node. 
- add_node(node) - assert pid_to_state[node.pid] == desired_state + pid_to_state[quad.pid] = PidState.WRITING + elif desired_state == PidState.READING and pid_to_state[quad.pid] == PidState.WRITING: + # Writing -> reading by starting a new quad. + add_quad(quad, "r→w") + assert pid_to_state[quad.pid] == desired_state + return last_op_in_process[quad.pid] - for access_or_node in accesses_and_nodes: - match access_or_node: - case Access(): - access = access_or_node + for access_or_quad in accesses_and_quads: + match access_or_quad: + case ptypes.Access(): + access = access_or_quad version_num = inode_to_version[access.inode] inode_to_paths[access.inode].add(access.path) version = InodeVersionNode(access.inode, version_num) next_version = InodeVersionNode(access.inode, version_num + 1) - ensure_state(access.op_node, PidState.READING if access.mode.is_side_effect_free() else PidState.WRITING) - if (op_node := last_op_in_process.get(access.op_node.pid)) is None: - warnings.warn(ptypes.UnusualProbeLog(f"Can't find last node from process {access.op_node.pid}")) - continue match access.mode: - case AccessMode.WRITE: - if access.phase == Phase.BEGIN: - dataflow_graph.add_edge(op_node, next_version) - dataflow_graph.add_edge(version, next_version) - case AccessMode.TRUNCATE_WRITE: - if access.phase == Phase.END: - dataflow_graph.add_edge(op_node, next_version) - case AccessMode.READ_WRITE: - if access.phase == Phase.BEGIN: - dataflow_graph.add_edge(version, op_node) - if access.phase == Phase.END: - dataflow_graph.add_edge(op_node, next_version) - dataflow_graph.add_edge(version, next_version) - case AccessMode.READ | AccessMode.EXEC | AccessMode.DLOPEN: - if access.phase == Phase.BEGIN: - dataflow_graph.add_edge(version, op_node) + case ptypes.AccessMode.WRITE: + if access.phase == ptypes.Phase.BEGIN: + quint = ensure_state(access.op_node, PidState.WRITING) + dataflow_graph.add_edge(quint, next_version, label="mutating write") + dataflow_graph.add_edge(version, 
next_version, label="mutating write") + inode_to_version[access.inode] += 1 + case ptypes.AccessMode.TRUNCATE_WRITE: + if access.phase == ptypes.Phase.END: + quint = ensure_state(access.op_node, PidState.WRITING) + dataflow_graph.add_edge(quint, next_version, label="truncating write") + inode_to_version[access.inode] += 1 + case ptypes.AccessMode.READ_WRITE: + if access.phase == ptypes.Phase.BEGIN: + quint = ensure_state(access.op_node, PidState.READING) + dataflow_graph.add_edge(version, quint, label="read & write") + if access.phase == ptypes.Phase.END: + quint = ensure_state(access.op_node, PidState.WRITING) + dataflow_graph.add_edge(quint, next_version, label="read/write") + dataflow_graph.add_edge(version, next_version, label="read/write") + inode_to_version[access.inode] += 1 + case ptypes.AccessMode.READ | ptypes.AccessMode.EXEC | ptypes.AccessMode.DLOPEN: + if access.phase == ptypes.Phase.BEGIN: + quint = ensure_state(access.op_node, PidState.READING) + dataflow_graph.add_edge(version, quint, label="read") case _: raise TypeError() - case hb_graph.OpNode(): - node = access_or_node - op_data = probe_log.get_op(*node.op_quad()).data + case ptypes.OpQuad(): + quad = access_or_quad + op_data = probe_log.get_op(quad).data match op_data: # us -> our child # Therefore, we have to be in writing mode case ops.CloneOp(): if op_data.task_type == ptypes.TaskType.TASK_PID and not (op_data.flags & os.CLONE_THREAD): - ensure_state(node, PidState.WRITING) + ensure_state(quad, PidState.WRITING) case ops.SpawnOp(): - ensure_state(node, PidState.WRITING) + ensure_state(quad, PidState.WRITING) case ops.InitExecEpochOp(): - add_node(node) + add_quad(quad, "init") + inode_to_paths2 = {inode: frozenset(paths) for inode, paths in inode_to_paths.items()} return dataflow_graph, inode_to_paths2 def hb_graph_to_dataflow_graph2( probe_log: ptypes.ProbeLog, - hbg: hb_graph.HbGraph, + hbg: ptypes.HbGraph, check: bool = False, ) -> tuple[DataflowGraph, typing.Mapping[ptypes.Inode, 
frozenset[pathlib.Path]]]: accesses = list(hb_graph_to_accesses(probe_log, hbg)) @@ -160,8 +154,8 @@ def combine_indistinguishable_inodes( else: warnings.warn(ptypes.UnusualProbeLog("Dataflow graph is cyclic")) def same_neighbors( - node0: hb_graph.OpNode | InodeVersionNode, - node1: hb_graph.OpNode | InodeVersionNode, + node0: ptypes.OpQuad | InodeVersionNode, + node1: ptypes.OpQuad | InodeVersionNode, ) -> bool: return ( isinstance(node0, InodeVersionNode) @@ -172,10 +166,10 @@ def same_neighbors( and frozenset(dataflow_graph.successors(node0)) == frozenset(dataflow_graph.successors(node1)) ) - def node_mapper(node_set: frozenset[hb_graph.OpNode | InodeVersionNode]) -> hb_graph.OpNode | frozenset[InodeVersionNode]: + def node_mapper(node_set: frozenset[ptypes.OpQuint | InodeVersionNode]) -> ptypes.OpQuint | frozenset[InodeVersionNode]: first_node = next(iter(node_set)) - if isinstance(first_node, hb_graph.OpNode): - assert all(isinstance(node, hb_graph.OpNode) for node in node_set) + if isinstance(first_node, ptypes.OpQuint): + assert all(isinstance(node, ptypes.OpQuint) for node in node_set) return first_node else: assert all(isinstance(node, InodeVersionNode) for node in node_set) @@ -241,16 +235,21 @@ def label_nodes( ) -> None: count = dict[tuple[ptypes.Pid, ptypes.ExecNo], int]() root_pid = probe_log.get_root_pid() - for node in tqdm.tqdm( - networkx.topological_sort(dataflow_graph), - total=len(dataflow_graph), - desc="Labelling DFG nodes", - ): + if networkx.is_directed_acyclic_graph(dataflow_graph): + nodes = list(networkx.topological_sort(dataflow_graph)) + cycle = [] + else: + nodes = list(dataflow_graph.nodes()) + cycle = list(networkx.find_cycle(dataflow_graph)) + warnings.warn(ptypes.UnusualProbeLog( + "Dataflow graph contains a cycle (marked in red).", + )) + for node in nodes: data = dataflow_graph.nodes(data=True)[node] match node: - case hb_graph.OpNode(): + case ptypes.OpQuad(): data["shape"] = "oval" - op = probe_log.get_op(node.pid, 
node.exec_no, node.pid.main_thread(), 0) + op = probe_log.get_op(node) if node.op_no == 0: count[(node.pid, node.exec_no)] = 1 if node.exec_no != 0: @@ -273,7 +272,7 @@ def label_nodes( data["label"] = "" if (node.pid, node.exec_no) not in count: warnings.warn(ptypes.UnusualProbeLog( - f"{node.pid, node.exec_no} never counted before" + f"{node.pid, node.exec_no} never counted before", )) count[(node.pid, node.exec_no)] = 99 count[(node.pid, node.exec_no)] += 1 @@ -291,7 +290,7 @@ def shorten_path(input: pathlib.Path) -> str: inode_labels = [] for inode_version in inode_versions[:max_inodes_per_set]: inode_label = [] - inode_label.append(f"{inode_version.inode.number} v{inode_version.version}") + inode_label.append(f"{inode_version.inode} v{inode_version.version}") paths = inodes_to_path.get(inode_version.inode, frozenset[pathlib.Path]()) for path in sorted(paths, key=lambda path: len(str(path)))[:max_paths_per_inode]: inode_label.append(shorten_path(path)) @@ -301,3 +300,5 @@ def shorten_path(input: pathlib.Path) -> str: data["label"] = "\n".join(inode_labels) data["shape"] = "rectangle" data["id"] = str(hash(node)) + for a, b in cycle: + dataflow_graph.edges[a, b]["color"] = "red" # type: ignore diff --git a/probe_py/probe_py/file_closure.py b/probe_py/probe_py/file_closure.py index af39ecc6..768a40a1 100644 --- a/probe_py/probe_py/file_closure.py +++ b/probe_py/probe_py/file_closure.py @@ -10,6 +10,7 @@ import warnings import pathlib import typing +from . 
import ptypes from .ptypes import ProbeLog, initial_exec_no, InodeVersion, Pid from .ops import Path, ChdirOp, OpenOp, CloseOp, InitExecEpochOp, ExecOp, Op from .consts import AT_FDCWD @@ -230,7 +231,9 @@ def copy_file_closure( console.print(f"Skipping {resolved_path}") elif resolved_path.exists(): if ino_ver is not None and InodeVersion.from_local_path(resolved_path) != ino_ver: - warnings.warn(f"{resolved_path} changed in between the time of `probe record` and now.") + warnings.warn(ptypes.UnusualProbeLog( + f"{resolved_path} changed in between the time of `probe record` and now.", + )) if resolved_path.is_dir(): destination_path.mkdir(exist_ok=True, parents=True) elif copy: diff --git a/probe_py/probe_py/graph_utils.py b/probe_py/probe_py/graph_utils.py index 40a163eb..01241da5 100644 --- a/probe_py/probe_py/graph_utils.py +++ b/probe_py/probe_py/graph_utils.py @@ -1,7 +1,8 @@ from __future__ import annotations -import collections.abc -import tqdm +import abc +import collections import dataclasses +import datetime import itertools import typing import pathlib @@ -19,7 +20,7 @@ @dataclasses.dataclass(frozen=True) class Segment(typing.Generic[_CoNode]): - dag_tc: LazyTransitiveClosure[_CoNode] + dag_tc: ReachabilityOracle[_CoNode] upper_bound: frozenset[_CoNode] lower_bound: frozenset[_CoNode] @@ -37,12 +38,12 @@ def __post_init__(self) -> None: assert not unbounded, \ f"{unbounded} in self.lower_bound is not a descendant of any in {self.upper_bound=}" - def nodes(self) -> frozenset[_CoNode]: - return self.dag_tc.between(self.upper_bound, self.lower_bound) + def nodes(self) -> collections.abc.Iterable[_CoNode]: + return self.dag_tc.nodes_between(self.upper_bound, self.lower_bound) def overlaps(self, other: Segment[_CoNode]) -> bool: assert self.dag_tc is other.dag_tc - return bool(self.nodes() & other.nodes()) + return bool(frozenset(self.nodes()) & frozenset(other.nodes())) @staticmethod def union(segments: typing.Sequence[Segment[_CoNode]]) -> Segment[_CoNode]: 
@@ -83,7 +84,7 @@ def map_nodes( graph: networkx.DiGraph[_Node], check: bool = True, ) -> networkx.DiGraph[_Node2]: - dct = {node: function(node) for node in tqdm.tqdm(graph.nodes(), desc="nodes")} + dct = {node: function(node) for node in graph.nodes()} assert util.all_unique(dct.values()), util.duplicates(dct.values()) return networkx.relabel_nodes(graph, dct) @@ -95,8 +96,10 @@ def serialize_graph( cluster_labels: collections.abc.Mapping[str, str] = {}, ) -> None: if name_mapper is None: - nodes_data = graph.nodes(data=True) - name_mapper = typing.cast(typing.Callable[[_Node], str], lambda node: nodes_data[node].get("id", str(node))) + name_mapper = typing.cast( + typing.Callable[[_Node], str], + lambda node: graph.nodes(data=True)[node].get("id", str(node)), + ) graph2 = map_nodes(name_mapper, graph) pydot_graph = networkx.drawing.nx_pydot.to_pydot(graph2) @@ -150,29 +153,47 @@ def replace(digraph: networkx.DiGraph[_Node], old: _Node, new: _Node) -> None: def bfs_with_pruning( digraph: networkx.DiGraph[_Node], start: _Node, -) -> typing.Generator[_Node, bool, None]: - """BFS but send False to prune this branch""" + left_to_right: bool = False +) -> typing.Generator[_Node | None, bool | None, None]: + """BFS but send False to prune this branch + + traversal = bfs_with_pruning + for node in traversal: + # work on node + traversal.send(condition) # send True to descend or False to prune + """ queue = [start] while queue: node = queue.pop() + # When we yield, we do the body of the client's for-loop with "node" + # Until they do bfs.send(...) + # At which point we resume continue_with_children = yield node + # Now we resumed. + # When we yield this time, the caller's bfs.send(...) returns "None" + should_be_none = yield None + # Now the for-loop has wrapped around and we are back here. 
+ assert should_be_none is None if continue_with_children: - queue.extend(digraph.successors(node)) + children = list(digraph.successors(node)) + if left_to_right: + children = children[::-1] + queue.extend(children) -def get_root(dag: networkx.DiGraph[_Node]) -> _Node: - roots = get_roots(dag) - if len(roots) != 1: - raise RuntimeError(f"No roots or too many roots: {roots}") - else: - return roots[0] +def get_sources(dag: networkx.DiGraph[_Node]) -> list[_Node]: + return [ + node + for node in dag.nodes() + if dag.in_degree(node) == 0 + ] -def get_roots(dag: networkx.DiGraph[_Node]) -> list[_Node]: +def get_sinks(dag: networkx.DiGraph[_Node]) -> list[_Node]: return [ node for node in dag.nodes() - if dag.in_degree(node) == 0 + if dag.out_degree(node) == 0 ] @@ -185,142 +206,290 @@ def randomly_sample_edges(inp: networkx.DiGraph[_Node], factor: float, seed: int return out -class LazyTransitiveClosure(typing.Generic[_Node]): - _dag: networkx.DiGraph[_Node] - _topological_generations: list[list[_Node]] - _rank: typing.Mapping[_Node, int] - _descendants: dict[_Node, tuple[int, set[_Node]]] - - def __init__(self, dag: networkx.DiGraph[_Node]) -> None: - self._dag = dag - self._topological_generations = [ - list(layer) - for layer in networkx.topological_generations(self._dag) - ] - self._rank = { - node: layer_no - for layer_no, layer in enumerate(self._topological_generations) - for node in layer - } - self._descendants = {} - - def between(self, upper_bounds: frozenset[_Node], lower_bounds: frozenset[_Node]) -> frozenset[_Node]: - max_rank = max(self._rank[lower_bound] for lower_bound in lower_bounds) - descendants = set().union(*( - self.descendants(upper_bound, max_rank) - for upper_bound in upper_bounds - )) - return frozenset({ - node - for node in descendants - if self.descendants(node, max_rank) & lower_bounds - }) +class ReachabilityOracle(abc.ABC, typing.Generic[_Node]): + """ + This datastructure answers reachability queries, is A reachable from B in dag. 
- def non_ancestors(self, candidates: frozenset[_Node], lower_bound: frozenset[_Node]) -> frozenset[_Node]: - max_rank = max(self._rank[bound] for bound in lower_bound) - return frozenset({ - candidate - for candidate in candidates - lower_bound - if not self.descendants(candidate, max_rank) & lower_bound - }) + If you had only 1 reachability query, it would be best to DFS the graph from B, looking for A. + DFS might have to traverse the whole graph and touch every edge, O(V+E). + In fact, when A is _not_ a descendant of B (but we don't know that yet), if B is high up, then DFS approaches its worst case. + Let's say you have N queries, resulting in O(N(V+E)) to complete all queries. - def non_descendants(self, candidates: frozenset[_Node], upper_bound: frozenset[_Node]) -> frozenset[_Node]: - max_rank = max(self._rank[candidate] for candidate in candidates) - descendants = set().union(*( - self.descendants(upper_bound, max_rank) - for upper_bound in upper_bound - )) - return frozenset(candidates - descendants - upper_bound) + If N gets to be larger than V, you're better off pre-computing reachability ahead of time. + Because DFS tells you "all of the Bs descendent from A", we need to do DFS for each node as a source, resulting in, O(V(V+E)). + This is conveniently implemented as [`networkx.transitive_closure`][source code]. - def is_antichain(self, nodes: typing.Iterable[_Node]) -> bool: - max_rank = max(self._rank[node] for node in nodes) + [source code]: https://networkx.org/documentation/stable/_modules/networkx/algorithms/dag.html#transitive_closure + + However, if V is on the order of 10^4 (E must be at least V for a connected graph), then V^2 could be terribly slow. + There are more efficient datastructures for answering N queries, often involving some kind of preprocessing. + This class encapsulate the preprocessing datastructure, and offers a method to answer reachability. 
+ """ + + @staticmethod + @abc.abstractmethod + def create(dag: networkx.DiGraph[_Node]) -> ReachabilityOracle[_Node]: + ... + + @abc.abstractmethod + def is_reachable(self, u: _Node, v: _Node) -> bool: + pass + + @abc.abstractmethod + def nodes_between( + self, + upper_bounds: collections.abc.Iterable[_Node], + lower_bounds: collections.abc.Iterable[_Node], + ) -> collections.abc.Iterable[_Node]: + ... + + @abc.abstractmethod + def add_edge(self, u: _Node, v: _Node) -> None: + """Keep datastructure up-to-date""" + + def is_antichain(self, nodes: collections.abc.Iterable[_Node]) -> bool: return all( - node0 not in self.descendants(node1, max_rank) and node1 not in self.descendants(node0, max_rank) + not self.is_reachable(node0, node1) for node0, node1 in itertools.combinations(nodes, 2) ) - def get_bottommost(self, nodes: collections.abc.Set[_Node]) -> frozenset[_Node]: - max_rank = max(self._rank[node] for node in nodes) - bottommost_nodes = set[_Node]() - sorted_nodes = self.sorted(nodes)[::-1] - for node in sorted_nodes: - if not self.descendants(node, max_rank) & bottommost_nodes: - bottommost_nodes.add(node) - return frozenset(bottommost_nodes) - - def get_uppermost(self, nodes: collections.abc.Set[_Node]) -> frozenset[_Node]: - max_rank = max(self._rank[node] for node in nodes) + def sorted(self, nodes: collections.abc.Iterable[_Node]) -> collections.abc.Sequence[_Node]: + dag: networkx.DiGraph[_Node] = networkx.DiGraph() + dag.add_nodes_from(nodes) + dag.add_edges_from([ + (source, target) + for source in nodes + for target in nodes + if self.is_reachable(source, target) + and source != target + ]) + return list(networkx.topological_sort(dag)) + + def get_uppermost(self, nodes: collections.abc.Iterable[_Node]) -> frozenset[_Node]: uppermost_nodes = set[_Node]() covered_nodes = set[_Node]() sorted_nodes = self.sorted(nodes) - for node in sorted_nodes: - if node not in covered_nodes: - uppermost_nodes.add(node) - covered_nodes.update(self.descendants(node, 
max_rank)) + for i, candidate in enumerate(sorted_nodes): + if candidate not in covered_nodes: + uppermost_nodes.add(candidate) + covered_nodes.update( + descendant + for descendant in sorted_nodes[i+1:] + if self.is_reachable(candidate, descendant) + ) + assert all( + any( + uppermost_node == node or self.is_reachable(uppermost_node, node) + for uppermost_node in uppermost_nodes) + for node in nodes + ) + assert not any( + self.is_reachable(a, b) + for a in uppermost_nodes + for b in uppermost_nodes + ) return frozenset(uppermost_nodes) - def sorted(self, nodes: typing.Iterable[_Node]) -> list[_Node]: - return sorted(nodes, key=self._rank.__getitem__) - - def is_reachable(self, src: _Node, dst: _Node) -> bool: - return dst in self.descendants(src, self._rank[dst]) + def get_bottommost(self, nodes: collections.abc.Iterable[_Node]) -> frozenset[_Node]: + bottom_nodes = set[_Node]() + covered_nodes = set[_Node]() + sorted_nodes = self.sorted(nodes)[::-1] + for i, candidate in enumerate(sorted_nodes): + if candidate not in covered_nodes: + bottom_nodes.add(candidate) + covered_nodes.update( + ancestor + for ancestor in sorted_nodes[i+1:] + if self.is_reachable(ancestor, candidate) + ) + assert all( + any( + bottom_node == node or self.is_reachable(node, bottom_node) + for bottom_node in bottom_nodes + ) + for node in nodes + ) + assert not any( + self.is_reachable(a, b) + for a in bottom_nodes + for b in bottom_nodes + ) + return frozenset(bottom_nodes) - def descendants(self, src: _Node, rank: int) -> frozenset[_Node]: - # Read the code as if descendants is True. - # Note that everythign is exactly reversed if descendants was false. 
- descendants = self._descendants - successors = self._dag.successors - def in_range(input_rank: int) -> bool: - return input_rank <= rank + def non_ancestors( + self, + candidates: collections.abc.Iterable[_Node], + lower_bounds: collections.abc.Iterable[_Node], + ) -> collections.abc.Iterable[_Node]: + return frozenset({ + candidate + for candidate in candidates + if not any( + self.is_reachable(candidate, lower_bound) + for lower_bound in lower_bounds + ) + }) - # stack will hold _paths from src_ not nodes - stack = [ - [src] - ] + def non_descendants( + self, + candidates: collections.abc.Iterable[_Node], + upper_bounds: collections.abc.Iterable[_Node], + ) -> collections.abc.Iterable[_Node]: + return frozenset({ + candidate + for candidate in candidates + if not any( + self.is_reachable(upper_bound, candidate) + for upper_bound in upper_bounds + ) + }) - # Do DFS - while stack: - path = stack[-1] - assert path[0] == src - node = path[-1] - if node in descendants and in_range(descendants[node][0]): - # already pre-computed, no work to do. 
- stack.pop() +@dataclasses.dataclass(frozen=True) +class PrecomputedReachabilityOracle(ReachabilityOracle[_Node]): + dag_tc: networkx.DiGraph[_Node] + @staticmethod + def create(dag: networkx.DiGraph[_Node]) -> PrecomputedReachabilityOracle[_Node]: + start = datetime.datetime.now() + print("Computing transitive closure") + ret = typing.cast(networkx.DiGraph[_Node], networkx.transitive_closure(dag)) + duration = datetime.datetime.now() - start + print(f"Done computing in {duration.total_seconds():.1f}sec") + return PrecomputedReachabilityOracle(ret) + + def is_reachable(self, u: _Node, v: _Node) -> bool: + return v in self.dag_tc.successors(u) + + def nodes_between( + self, + upper_bounds: collections.abc.Iterable[_Node], + lower_bounds: collections.abc.Iterable[_Node], + ) -> collections.abc.Iterable[_Node]: + raise NotImplementedError() + + def add_edge(self, source: _Node, target: _Node) -> None: + if target not in self.dag_tc.successors(source): + for descendant_of_source in [*self.dag_tc.successors(source), source]: + for descendant_of_target in [*self.dag_tc.successors(target), target]: + self.dag_tc.add_edge(descendant_of_source, descendant_of_target) + + +def get_faces( + planar_graph: networkx.PlanarEmbedding[_Node], # type: ignore +) -> frozenset[tuple[_Node, ...]]: + faces = set() + covered_half_edges = set() + for half_edge in planar_graph.edges(): + if half_edge not in covered_half_edges: + covered_half_edges.add(half_edge) + face = planar_graph.traverse_face(*half_edge) + faces.add(tuple(face)) + if len(face) > 1: + for a, b in [*zip(face[:-1], face[1:]), (face[-1], face[0])]: + covered_half_edges.add((a, b)) + return frozenset(faces) + + +def add_edge_without_cycle( + dag: networkx.DiGraph[_Node], + source: _Node, + target: _Node, + reachability_oracle: ReachabilityOracle[_Node] | None = None, +) -> collections.abc.Sequence[tuple[_Node, _Node]]: + """ + Add an edge from source to the earliest descendants of target without creating a cycle. 
+ + Consider the graph: + + 0 -> 10, 20, 30; + 10 -> 11; + 20 -> 21; + 30 -> 31; + + If we add the edge 31 -> 0, that would create a cycle. + So we look at the children of 0. + We add the edge 31 -> 10. + We add the edge 31 -> 20. + We don't add 31 -> 30, because that would create a cycle. + We recurse into 31's aunts and uncles, add edges to them, etc. + """ + + if reachability_oracle is None: + assert networkx.is_directed_acyclic_graph(dag) + reachability_oracle = PrecomputedReachabilityOracle.create(dag) + + if reachability_oracle.is_reachable(source, target): + # No cycle would be made anyway. + # Easy. + return [(source, target)] + else: + edges = [] + # Start from target + # See if each descendant can be used as a proxy for target. + # I.e., source -> proxy_target. + # If not, we will have to recurse into its children until a suitable target is found or the original source is found. + bfs = bfs_with_pruning(dag, target) + for proxy_target in bfs: + assert proxy_target is not None + if reachability_oracle.is_reachable(proxy_target, source): + # Upstream of source + # An edge here would create a cycle. + # We will recurse into the children to find a suitable proxy target. + bfs.send(True) + elif reachability_oracle.is_reachable(source, proxy_target) or proxy_target == source: + # Downstream of target (or equal to target). + # Time to stop. 
+ bfs.send(False) else: - # Not already precomputed - # Recurse into successors - # But only those in range - successors_in_range = { - successor - for successor in successors(node) - if in_range(self._rank[successor]) - } - noncomputed_successors_in_range = { - successor - for successor in successors_in_range - if successor not in descendants[successor][1] or not in_range(descendants[successor][0]) - } - if noncomputed_successors_in_range: - for successor in noncomputed_successors_in_range: - stack.append([*path, successor]) - else: - descendants[node] = ( - rank, - set.union(*(descendants[successor][1] for successor in successors_in_range)) if successors_in_range else set(), - ) - - return frozenset(descendants[node][1]) - - -def dag_transitive_closure(dag: networkx.DiGraph[_Node]) -> networkx.DiGraph[_Node]: - tc: networkx.DiGraph[_Node] = networkx.DiGraph() - node_order = list(networkx.topological_sort(dag))[::-1] - for src in tqdm.tqdm(node_order, desc="TC"): - tc.add_node(src) - for child in dag.successors(src): - tc.add_edge(src, child) - for grandchild in dag.successors(child): - tc.add_edge(src, grandchild) - return tc + # Neither upstpream nor downstream. + # We can put an edge here and quit. 
+ edges.append((source, proxy_target)) + bfs.send(False) + # checking: + dag2 = dag.copy() + dag2.add_edges_from(edges) + assert networkx.is_directed_acyclic_graph(dag2) + return edges + + +def splice_out_nodes( + input_dag: networkx.DiGraph[_Node], + should_splice: typing.Callable[[_Node], bool], +) -> networkx.DiGraph[_Node]: + output_dag = input_dag.copy() + for node in list(input_dag.nodes()): + if should_splice(node): + output_dag.add_edges_from([ + (predecessor, successor) + for predecessor in output_dag.predecessors(node) + for successor in output_dag.predecessors(node) + if predecessor != node and successor != node + ]) + output_dag.remove_node(node) + return output_dag + + +def topological_sort_depth_first( + dag: networkx.DiGraph[_Node], + score_children: typing.Callable[[_Node, _Node], int] = lambda _parent, _child: 0, +) -> typing.Iterable[_Node]: + """Topological sort that breaks ties by depth first, and then by lowest child score.""" + queue = util.PriorityQueue[_Node, tuple[int, int]]( + (node, (dag.in_degree(node), 0)) + for node in dag.nodes() + ) + counter = 0 + while queue: + (in_degree, tie_breaker), node = queue.pop() + if in_degree == 0: + yield node + # Since we handled the parent, we essentially removed it from the graph + # decrementing the in-degree of its children by one. + # To make it be depth first, we make it "win" all ties, among currently existing entries. 
+ for child in sorted(dag.successors(node), key=lambda child: score_children(node, child)): + in_degree, tie_breaker = queue[child] + queue[child] = (in_degree - 1, -counter) + else: + raise RuntimeError(f"Cycle exists and includes {node}") + counter += 1 diff --git a/probe_py/probe_py/hb_graph.py b/probe_py/probe_py/hb_graph.py index 1623788d..960cfbae 100644 --- a/probe_py/probe_py/hb_graph.py +++ b/probe_py/probe_py/hb_graph.py @@ -1,14 +1,14 @@ import os -import dataclasses import shlex import textwrap import typing import warnings import networkx import tqdm -from .ptypes import TaskType, Pid, ExecNo, Tid, ProbeLog, initial_exec_no, InvalidProbeLog, InodeVersion, UnusualProbeLog -from .ops import CloneOp, ExecOp, WaitOp, OpenOp, SpawnOp, InitExecEpochOp, InitThreadOp, Op, CloseOp, DupOp +from .ptypes import TaskType, Pid, ExecNo, Tid, ProbeLog, initial_exec_no, InvalidProbeLog, InodeVersion, OpQuad, HbGraph +from .ops import CloneOp, ExecOp, WaitOp, OpenOp, SpawnOp, InitExecEpochOp, InitThreadOp, Op, CloseOp, DupOp, StatOp from . import graph_utils +from . import ptypes """ HbGraph stands for "Happened-Before graph". 
@@ -22,36 +22,13 @@ """ -@dataclasses.dataclass(frozen=True) -class OpNode: - pid: Pid - exec_no: ExecNo - tid: Tid - op_no: int - - def thread_triple(self) -> tuple[Pid, ExecNo, Tid]: - return (self.pid, self.exec_no, self.tid) - - def op_quad(self) -> tuple[Pid, ExecNo, Tid, int]: - return (self.pid, self.exec_no, self.tid, self.op_no) - - def __str__(self) -> str: - return f"PID {self.pid} Exec {self.exec_no} TID {self.tid} op {self.op_no}" - - -if typing.TYPE_CHECKING: - HbGraph: typing.TypeAlias = networkx.DiGraph[OpNode] -else: - HbGraph = networkx.DiGraph - - def probe_log_to_hb_graph(probe_log: ProbeLog) -> HbGraph: hb_graph = HbGraph() _create_program_order_edges(probe_log, hb_graph) # Hook up synchronization edges - for node in tqdm.tqdm(hb_graph.nodes(), "sync edges"): + for node in hb_graph.nodes(): _create_clone_edges(node, probe_log, hb_graph) _create_exec_edges(node, probe_log, hb_graph) _create_spawn_edges(node, probe_log, hb_graph) @@ -67,38 +44,38 @@ def probe_log_to_hb_graph(probe_log: ProbeLog) -> HbGraph: def retain_only( probe_log: ProbeLog, full_hb_graph: HbGraph, - retain_node_predicate: typing.Callable[[OpNode, Op], bool], + retain_node_predicate: typing.Callable[[OpQuad, Op], bool], ) -> HbGraph: """Retains only nodes satisfying the predicate.""" reduced_hb_graph = HbGraph() - last_in_process = dict[tuple[Pid, ExecNo, Tid], OpNode]() - incoming_to_process = dict[tuple[Pid, ExecNo, Tid], list[tuple[OpNode, typing.Mapping[str, typing.Any]]]]() + last_in_process = dict[tuple[Pid, ExecNo, Tid], OpQuad]() + incoming_to_process = dict[tuple[Pid, ExecNo, Tid], list[tuple[OpQuad, typing.Mapping[str, typing.Any]]]]() for node in tqdm.tqdm( networkx.topological_sort(full_hb_graph), total=len(full_hb_graph), desc="retain", ): - node_triple = (node.pid, node.exec_no, node.tid) + thread = node.thread_triple() # If node satisfies predicate, copy node into new graph - if retain_node_predicate(node, probe_log.get_op(*node.op_quad())): + if 
retain_node_predicate(node, probe_log.get_op(node)): node_data = full_hb_graph.nodes(data=True)[node] reduced_hb_graph.add_node(node, **node_data) # Add link from previous node in this process, if any # Note that iteration is in topo order, # so this node happens-after the node of previous iterations. - if previous_node := last_in_process.get(node_triple): + if previous_node := last_in_process.get(thread): reduced_hb_graph.add_edge(previous_node, node) - last_in_process[node_triple] = node + last_in_process[thread] = node # Link up any out-of-process predecessors we accumulated up to this node - incoming = incoming_to_process.get(node_triple, []) + incoming = incoming_to_process.get(thread, []) for (predecessor, edge_data) in incoming: reduced_hb_graph.add_edge(predecessor, node, **edge_data) if incoming: - del incoming_to_process[node_triple] + del incoming_to_process[thread] # Accumulate out-of-process predecessors # Since we iterate in topo order, @@ -111,10 +88,15 @@ def retain_only( # none of the prior nodes in this process were retained, # so the edge doesn't synchronize any retained nodes. # In such case, we don't need to create an edge. 
- if successor_triple != node_triple and (previous_node := last_in_process.get(node_triple)): + if successor_triple != thread and (previous_node := last_in_process.get(thread)): edge_data = full_hb_graph.get_edge_data(node, successor) incoming_to_process.setdefault(successor_triple, []).append((previous_node, edge_data)) + for thread, incoming in incoming_to_process.items(): + for node, edge_data in incoming: + if predecessor2 := last_in_process.get(thread): + reduced_hb_graph.add_edge(predecessor2, node, **edge_data) + validate_hb_graph(reduced_hb_graph, False) return reduced_hb_graph @@ -123,17 +105,24 @@ def retain_only( def validate_hb_graph(hb_graph: HbGraph, validate_roots: bool) -> None: if not networkx.is_directed_acyclic_graph(hb_graph): cycle = list(networkx.find_cycle(hb_graph)) - warnings.warn(UnusualProbeLog(f"Found a cycle in hb graph: {cycle}")) + warnings.warn(ptypes.UnusualProbeLog( + f"Found a cycle in hb graph: {cycle}", + )) if validate_roots: - graph_utils.get_root(hb_graph) - # TODO: Check that root pid and/or parent-pid is as expected. + sources = graph_utils.get_sources(hb_graph) + if len(sources) > 1: + warnings.warn(ptypes.UnusualProbeLog( + f"Too many sources {sources}" + )) + + # TODO: Check that root pid and/or parent-pid is as expected. 
def _create_program_order_edges(probe_log: ProbeLog, hb_graph: HbGraph) -> None: if not probe_log.processes: raise InvalidProbeLog("No processes tracked") - for pid, process in tqdm.tqdm(probe_log.processes.items(), "processes program order"): + for pid, process in probe_log.processes.items(): if not process.execs: raise InvalidProbeLog(f"No exec epochs tracked for pid {pid}") for exec_no, exec_epoch in process.execs.items(): @@ -143,7 +132,7 @@ def _create_program_order_edges(probe_log: ProbeLog, hb_graph: HbGraph) -> None: if not thread.ops: raise InvalidProbeLog(f"No ops tracked for thread {tid}") nodes = [ - OpNode(pid, exec_no, tid, op_no) + OpQuad(pid, exec_no, tid, op_no) for op_no, op in enumerate(thread.ops) ] assert nodes @@ -154,28 +143,28 @@ def _create_program_order_edges(probe_log: ProbeLog, hb_graph: HbGraph) -> None: hb_graph.add_edges_from(zip(nodes[:-1], nodes[1:])) -def _create_clone_edges(node: OpNode, probe_log: ProbeLog, hb_graph: HbGraph) -> None: - op = probe_log.get_op(*node.op_quad()) +def _create_clone_edges(node: OpQuad, probe_log: ProbeLog, hb_graph: HbGraph) -> None: + op = probe_log.get_op(node) if isinstance(op.data, CloneOp) and op.data.ferrno == 0: match op.data.task_type: case TaskType.TASK_TID: target_tid = Tid(op.data.task_id) if target_tid not in probe_log.processes[node.pid].execs[node.exec_no].threads: - warnings.warn(UnusualProbeLog( - f"Clone points to a thread {target_tid} we didn't track" + warnings.warn(ptypes.UnusualProbeLog( + f"Clone ({node}) points to a thread {target_tid} we didn't track" )) else: - target = OpNode(node.pid, node.exec_no, target_tid, 0) + target = OpQuad(node.pid, node.exec_no, target_tid, 0) assert hb_graph.has_node(target) hb_graph.add_edge(node, target) case TaskType.TASK_PID: target_pid = Pid(op.data.task_id) if target_pid not in probe_log.processes: - warnings.warn(UnusualProbeLog( - f"Clone points to a process {target_pid} we didn't track {probe_log.processes.keys()}" + 
warnings.warn(ptypes.UnusualProbeLog( + f"Clone ({node}) points to a process {target_pid} we didn't track {probe_log.processes.keys()}" )) else: - target = OpNode(target_pid, initial_exec_no, target_pid.main_thread(), 0) + target = OpQuad(target_pid, initial_exec_no, target_pid.main_thread(), 0) assert hb_graph.has_node(target) hb_graph.add_edge(node, target) case TaskType.TASK_PTHREAD | TaskType.TASK_ISO_C_THREAD: @@ -192,36 +181,40 @@ def get_first_task_nodes( task_type: int, task_id: int, reverse: bool, -) -> list[OpNode]: +) -> list[OpQuad]: targets = [] for tid, thread in probe_log.processes[pid].execs[exec_no].threads.items(): for op_no, other_op in enumerate(reversed(thread.ops) if reverse else thread.ops): if (task_type == TaskType.TASK_PTHREAD and other_op.pthread_id == task_id) or \ (task_type == TaskType.TASK_ISO_C_THREAD and other_op.iso_c_thread_id == task_id): - targets.append(OpNode(pid, exec_no, tid, op_no)) + targets.append(OpQuad(pid, exec_no, tid, op_no)) break return targets -def _create_wait_edges(node: OpNode, probe_log: ProbeLog, hb_graph: HbGraph) -> None: - op = probe_log.get_op(*node.op_quad()) +def _create_wait_edges(node: OpQuad, probe_log: ProbeLog, hb_graph: HbGraph) -> None: + op = probe_log.get_op(node) if isinstance(op.data, WaitOp) and op.data.ferrno == 0: match op.data.task_type: case TaskType.TASK_TID: target_tid = Tid(op.data.task_id) if target_tid not in probe_log.processes[node.pid].execs[node.exec_no].threads: - warnings.warn(f"Wait ({node}) points to a thread {target_tid} we didn't track") + warnings.warn(ptypes.UnusualProbeLog( + f"Wait ({node}) points to a thread {target_tid} we didn't track", + )) else: - target = OpNode(node.pid, node.exec_no, target_tid, len(probe_log.processes[node.pid].execs[node.exec_no].threads[target_tid].ops) - 1) + target = OpQuad(node.pid, node.exec_no, target_tid, len(probe_log.processes[node.pid].execs[node.exec_no].threads[target_tid].ops) - 1) hb_graph.add_edge(target, node) case 
TaskType.TASK_PID: target_pid = Pid(op.data.task_id) if target_pid not in probe_log.processes: - warnings.warn(f"Wait ({node}) points to a process {target_pid} we didn't track") + warnings.warn(ptypes.UnusualProbeLog( + f"Wait ({node}) points to a process {target_pid} we didn't track", + )) else: last_exec_no = max(probe_log.processes[target_pid].execs.keys()) last_op_no = len(probe_log.processes[target_pid].execs[last_exec_no].threads[target_pid.main_thread()].ops) - 1 - target = OpNode(target_pid, last_exec_no, target_pid.main_thread(), last_op_no) + target = OpQuad(target_pid, last_exec_no, target_pid.main_thread(), last_op_no) assert hb_graph.has_node(target) hb_graph.add_edge(target, node) case TaskType.TASK_PTHREAD | TaskType.TASK_ISO_C_THREAD: @@ -231,30 +224,30 @@ def _create_wait_edges(node: OpNode, probe_log: ProbeLog, hb_graph: HbGraph) -> hb_graph.add_edge(target, node) -def _create_exec_edges(node: OpNode, probe_log: ProbeLog, hb_graph: HbGraph) -> None: - op = probe_log.get_op(*node.op_quad()) +def _create_exec_edges(node: OpQuad, probe_log: ProbeLog, hb_graph: HbGraph) -> None: + op = probe_log.get_op(node) if isinstance(op.data, ExecOp) and op.data.ferrno == 0: next_exec_no = node.exec_no.next() if next_exec_no not in probe_log.processes[node.pid].execs: - warnings.warn(UnusualProbeLog( - f"Exec ({node}) points to an exec epoch {next_exec_no} we didn't track" + warnings.warn(ptypes.UnusualProbeLog( + f"Exec points to an exec epoch {next_exec_no} we didn't track" )) else: - target = OpNode(node.pid, next_exec_no, node.pid.main_thread(), 0) + target = OpQuad(node.pid, next_exec_no, node.pid.main_thread(), 0) assert hb_graph.has_node(target) hb_graph.add_edge(node, target) -def _create_spawn_edges(node: OpNode, probe_log: ProbeLog, hb_graph: HbGraph) -> None: - op = probe_log.get_op(*node.op_quad()) +def _create_spawn_edges(node: OpQuad, probe_log: ProbeLog, hb_graph: HbGraph) -> None: + op = probe_log.get_op(node) if isinstance(op.data, SpawnOp) and 
op.data.ferrno == 0: child_pid = Pid(op.data.child_pid) if child_pid not in probe_log.processes: - warnings.warn(UnusualProbeLog( + warnings.warn(ptypes.UnusualProbeLog( f"Spawn ({node}) points to a pid {child_pid} we didn't track" )) else: - target = OpNode(child_pid, initial_exec_no, child_pid.main_thread(), 0) + target = OpQuad(child_pid, initial_exec_no, child_pid.main_thread(), 0) assert hb_graph.has_node(target) hb_graph.add_edge(node, target) @@ -264,34 +257,38 @@ def _create_other_thread_edges(probe_log: ProbeLog, hb_graph: HbGraph) -> None: for pid, process in probe_log.processes.items(): for exec_no, exec_epoch in process.execs.items(): for tid, thread in exec_epoch.threads.items(): - first_op_main_thread = OpNode(pid, exec_no, pid.main_thread(), 0) - last_op_main_thread = OpNode(pid, exec_no, pid.main_thread(), len(exec_epoch.threads[pid.main_thread()].ops) - 1) + first_op_main_thread = OpQuad(pid, exec_no, pid.main_thread(), 0) + last_op_main_thread = OpQuad(pid, exec_no, pid.main_thread(), len(exec_epoch.threads[pid.main_thread()].ops) - 1) if tid != pid.main_thread(): - first_op = OpNode(pid, exec_no, tid, 0) - last_op = OpNode(pid, exec_no, tid, len(thread.ops) - 1) + first_op = OpQuad(pid, exec_no, tid, 0) + last_op = OpQuad(pid, exec_no, tid, len(thread.ops) - 1) if len(list(hb_graph.predecessors(first_op))) == 0: hb_graph.add_edge(first_op_main_thread, first_op) if last_op_main_thread != first_op_main_thread and len(list(hb_graph.successors(last_op))) == 0: if last_op_main_thread not in hb_graph.predecessors(last_op): hb_graph.add_edge(last_op, last_op_main_thread) else: - warnings.warn(UnusualProbeLog( + warnings.warn(ptypes.UnusualProbeLog( f"I want to add an edge from last op of {tid} to main thread {pid}, but that would create a cycle;" f"the last op of {pid} is likely the clone that creates {tid}" )) def label_nodes(probe_log: ProbeLog, hb_graph: HbGraph, add_op_no: bool = False) -> None: - for node, data in 
tqdm.tqdm(hb_graph.nodes(data=True), "HBG label"): - op = probe_log.get_op(*node.op_quad()) + for node, data in hb_graph.nodes(data=True): + op = probe_log.get_op(node) + data.setdefault("label", "") + data["cluster"] = str(node.pid) + if add_op_no: + data["label"] += f"{node.op_no}: " if len(list(hb_graph.predecessors(node))) == 0: - data["label"] = "root" + data["label"] += "root" elif isinstance(op.data, InitExecEpochOp): - data["label"] = f"PID {node.pid} exec {node.exec_no}" + data["label"] += f"PID {node.pid} exec {node.exec_no}" elif isinstance(op.data, InitThreadOp): - data["label"] = f"TID {node.tid}" + data["label"] += f"TID {node.tid}" elif isinstance(op.data, ExecOp): - data["label"] = textwrap.fill( + data["label"] += textwrap.fill( "exec " + textwrap.shorten( shlex.join([ textwrap.shorten( @@ -306,22 +303,25 @@ def label_nodes(probe_log: ProbeLog, hb_graph: HbGraph, add_op_no: bool = False) ) elif isinstance(op.data, OpenOp): access = {os.O_RDONLY: "readable", os.O_WRONLY: "writable", os.O_RDWR: "read/writable"}[op.data.flags & os.O_ACCMODE] - data["label"] = f"Open ({access}) {op.data.path.path.decode(errors='backslashreplace')}" - data["label"] += f" fd={op.data.fd}" + data["label"] += f"Open ({access}) fd={op.data.fd}" + data["label"] += f"\n{InodeVersion.from_probe_path(op.data.path).inode!s}" + data["label"] += f"\n{op.data.path.path.decode()}" + elif isinstance(op.data, StatOp): + data["label"] += "Stat" data["label"] += f"\n{InodeVersion.from_probe_path(op.data.path).inode!s}" data["label"] += f"\n{op.data.path.path.decode()}" elif isinstance(op.data, CloseOp): - data["label"] = f"Close fd={op.data.fd}" + data["label"] += f"Close fd={op.data.fd}" + data["label"] += f"\n{InodeVersion.from_probe_path(op.data.path).inode!s}" + data["label"] += f"\n{op.data.path.path.decode()}" elif isinstance(op.data, DupOp): - data["label"] = f"DupOp fd={op.data.old} → fd={op.data.new}" + data["label"] += f"DupOp fd={op.data.old} → fd={op.data.new}" else: - 
data["label"] = f"{op.data.__class__.__name__}" + data["label"] += f"{op.data.__class__.__name__}" data["labelfontsize"] = 8 if getattr(op.data, "ferrno", 0) != 0: data["label"] += " (failed)" data["color"] = "red" - if add_op_no: - data["label"] = f"{node.op_no}: " + data["label"] for node0, node1, data in hb_graph.edges(data=True): if node0.pid != node1.pid or node0.tid != node1.tid: @@ -331,4 +331,6 @@ def label_nodes(probe_log: ProbeLog, hb_graph: HbGraph, add_op_no: bool = False) cycle = list(networkx.find_cycle(hb_graph)) for a, b in cycle: hb_graph.get_edge_data(a, b)["color"] = "red" - warnings.warn(UnusualProbeLog("Hb graph has cycle (shown in red)")) + warnings.warn(ptypes.UnusualProbeLog( + "Cycle shown in red", + )) diff --git a/probe_py/probe_py/hb_graph_accesses.py b/probe_py/probe_py/hb_graph_accesses.py index 3a3a9237..734bf137 100644 --- a/probe_py/probe_py/hb_graph_accesses.py +++ b/probe_py/probe_py/hb_graph_accesses.py @@ -1,147 +1,100 @@ import collections import dataclasses -import enum import os import pathlib import warnings -import networkx -import tqdm -from . import hb_graph +from . import graph_utils from . import ops from . 
import ptypes -class AccessMode(enum.IntEnum): - """In what way are we accessing the inode version?""" - EXEC = enum.auto() - DLOPEN = enum.auto() - READ = enum.auto() - WRITE = enum.auto() - READ_WRITE = enum.auto() - TRUNCATE_WRITE = enum.auto() - - def is_side_effect_free(self) -> bool: - return self in {AccessMode.EXEC, AccessMode.DLOPEN, AccessMode.READ} - - @staticmethod - def from_open_flags(flags: int) -> "AccessMode": - access_mode = flags & os.O_ACCMODE - if access_mode == os.O_RDONLY: - return AccessMode.READ - elif flags & (os.O_TRUNC | os.O_CREAT): - return AccessMode.TRUNCATE_WRITE - elif access_mode == os.O_WRONLY: - return AccessMode.WRITE - elif access_mode == os.O_RDWR: - return AccessMode.READ_WRITE - else: - raise ptypes.InvalidProbeLog(f"Invalid open flags: 0x{flags:x}") - - -class Phase(enum.StrEnum): - BEGIN = enum.auto() - END = enum.auto() - - -@dataclasses.dataclass -class Access: - phase: Phase - mode: AccessMode - inode: ptypes.Inode - path: pathlib.Path - op_node: hb_graph.OpNode - fd: int | None - - def hb_graph_to_accesses( probe_log: ptypes.ProbeLog, - hbg: hb_graph.HbGraph, -) -> collections.abc.Iterator[Access | hb_graph.OpNode]: + hbg: ptypes.HbGraph, +) -> collections.abc.Iterator[ptypes.Access | ptypes.OpQuad]: """Reduces a happens-before graph to an ordered list of accesses in one possible schedule.""" @dataclasses.dataclass class FileDescriptor2: - mode: AccessMode + mode: ptypes.AccessMode inode: ptypes.Inode path: pathlib.Path cloexec: bool proc_fd_to_fd = collections.defaultdict[ptypes.Pid, dict[int, FileDescriptor2]](dict) - def close(fd: int, node: hb_graph.OpNode) -> collections.abc.Iterator[Access]: + def close(fd: int, node: ptypes.OpQuad) -> collections.abc.Iterator[ptypes.Access]: if file_desc := proc_fd_to_fd[node.pid].get(fd): file_desc = proc_fd_to_fd[node.pid][fd] - yield Access(Phase.END, file_desc.mode, file_desc.inode, file_desc.path, node, fd) + yield ptypes.Access(ptypes.Phase.END, file_desc.mode, 
file_desc.inode, file_desc.path, node, fd) del proc_fd_to_fd[node.pid][fd] else: warnings.warn(ptypes.UnusualProbeLog( - f"{node} successfully closed an FD {fd} we never traced. This could come from pipe or pipe2." + f"{node} successfully closed an FD {fd} we never traced.", )) def openfd( fd: int, - mode: AccessMode, + mode: ptypes.AccessMode, cloexec: bool, - node: hb_graph.OpNode, + node: ptypes.OpQuad, path: ops.Path, - ) -> collections.abc.Iterator[Access]: + ) -> collections.abc.Iterator[ptypes.Access]: inode = ptypes.InodeVersion.from_probe_path(path).inode if fd in proc_fd_to_fd[node.pid]: warnings.warn(ptypes.UnusualProbeLog( - f"Prior to {node}, process closed FD {fd} without our knowledge." + f"FD {fd} closed was without our knowledge before {node}.", )) yield from close(fd, node) parsed_path = pathlib.Path(path.path.decode()) proc_fd_to_fd[node.pid][fd] = FileDescriptor2(mode, inode, parsed_path, cloexec) - yield Access(Phase.BEGIN, mode, inode, parsed_path, node, fd) + yield ptypes.Access(ptypes.Phase.BEGIN, mode, inode, parsed_path, node, fd) - interesting_op_types = (ops.OpenOp, ops.CloseOp, ops.DupOp, ops.ExecOp, ops.SpawnOp, ops.InitExecEpochOp, ops.CloneOp) - reduced_hb_graph = hb_graph.retain_only( - probe_log, - hbg, - lambda node, op: isinstance(op.data, interesting_op_types) and getattr(op.data, "ferrno", 0) == 0, - ) - - root_pid = probe_log.get_root_pid() - for node in tqdm.tqdm( - networkx.topological_sort(reduced_hb_graph), - total=len(reduced_hb_graph), - desc="Finding DFG", + for node in graph_utils.topological_sort_depth_first( + hbg, + score_children=lambda parent, child: 0 if parent.pid == child.pid else 1 if parent.pid < child.pid else 2, ): yield node - op_data = probe_log.get_op(*node.op_quad()).data + op_data = probe_log.get_op(node).data match op_data: case ops.InitExecEpochOp(): - if node.exec_no == ptypes.initial_exec_no and node.pid == root_pid: - yield from openfd(0, AccessMode.READ, False, node, op_data.std_in) - yield from 
openfd(1, AccessMode.TRUNCATE_WRITE, False, node, op_data.std_out) - yield from openfd(2, AccessMode.TRUNCATE_WRITE, False, node, op_data.std_err) + if 0 not in proc_fd_to_fd[node.pid]: + yield from openfd(0, ptypes.AccessMode.READ, False, node, op_data.std_in) + if 1 not in proc_fd_to_fd[node.pid]: + yield from openfd(1, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.std_out) + if 2 not in proc_fd_to_fd[node.pid]: + yield from openfd(2, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.std_err) case ops.OpenOp(): - mode = AccessMode.from_open_flags(op_data.flags) - cloexec = bool(op_data.flags & os.O_CLOEXEC) - yield from openfd(op_data.fd, mode, cloexec, node, op_data.path) + if op_data.ferrno == 0: + mode = ptypes.AccessMode.from_open_flags(op_data.flags) + cloexec = bool(op_data.flags & os.O_CLOEXEC) + yield from openfd(op_data.fd, mode, cloexec, node, op_data.path) case ops.ExecOp(): - for fd, file_desc in list(proc_fd_to_fd[node.pid].items()): - if file_desc.cloexec: - yield from close(fd, node) - exe_inode = ptypes.InodeVersion.from_probe_path(op_data.path).inode - exe_path = pathlib.Path(op_data.path.path.decode()) - yield Access(Phase.BEGIN, AccessMode.EXEC, exe_inode, exe_path, node, None) - yield Access(Phase.END, AccessMode.EXEC, exe_inode, exe_path, node, None) + if op_data.ferrno == 0: + for fd, file_desc in list(proc_fd_to_fd[node.pid].items()): + if file_desc.cloexec: + yield from close(fd, node) + exe_inode = ptypes.InodeVersion.from_probe_path(op_data.path).inode + exe_path = pathlib.Path(op_data.path.path.decode()) + yield ptypes.Access(ptypes.Phase.BEGIN, ptypes.AccessMode.EXEC, exe_inode, exe_path, node, None) + yield ptypes.Access(ptypes.Phase.END, ptypes.AccessMode.EXEC, exe_inode, exe_path, node, None) case ops.CloseOp(): - yield from close(op_data.fd, node) + if op_data.ferrno == 0: + yield from close(op_data.fd, node) case ops.DupOp(): - if old_file_desc := proc_fd_to_fd[node.pid].get(op_data.old): + if op_data.ferrno == 0: # 
dup2 and dup3 close the new FD, if it was open - if op_data.new in list(proc_fd_to_fd[node.pid]): + # https://www.man7.org/linux/man-pages/man2/dup.2.html + if op_data.new in proc_fd_to_fd[node.pid].keys(): yield from close(op_data.new, node) - proc_fd_to_fd[node.pid][op_data.new] = old_file_desc - else: - warnings.warn(ptypes.UnusualProbeLog( - f"Prior to {node}, process successfully closed an FD {op_data.old} we never traced. This could come from pipe or pipe2." - )) + if old_file_desc := proc_fd_to_fd[node.pid].get(op_data.old): + proc_fd_to_fd[node.pid][op_data.new] = old_file_desc + else: + warnings.warn(ptypes.UnusualProbeLog( + f"{node} successfully duped an FD {op_data.old} (-> {op_data.new}) we never traced.", + )) case ops.CloneOp(): - if op_data.task_type == ptypes.TaskType.TASK_PID and not (op_data.flags & os.CLONE_THREAD): + if op_data.ferrno == 0 and op_data.task_type == ptypes.TaskType.TASK_PID and not (op_data.flags & os.CLONE_THREAD): target = ptypes.Pid(op_data.task_id) if op_data.flags & os.CLONE_FILES: proc_fd_to_fd[target] = proc_fd_to_fd[node.pid] @@ -149,23 +102,11 @@ def openfd( proc_fd_to_fd[target] = {**proc_fd_to_fd[node.pid]} is_last_op_in_process = not any( successor.pid == node.pid - for successor in reduced_hb_graph.successors(node) + for successor in hbg.successors(node) ) if is_last_op_in_process: for fd in list(proc_fd_to_fd[node.pid].keys()): yield from close(fd, node) - # for fd in [0, 1, 2]: - # last_op_idx = len(probe_log.processes[root_pid].execs[ptypes.initial_exec_no].threads[root_pid.main_thread()].ops) - 1 - # last_op = hb_graph.OpNode(root_pid, ptypes.initial_exec_no, root_pid.main_thread(), last_op_idx) - # if fd in proc_fd_to_fd.get(root_pid, {}): - # close(fd, last_op) - for pid, fd_table in proc_fd_to_fd.items(): assert not fd_table, f"somehow we still have open file descriptors at the end. 
{pid} {fd_table}" - - -def verify_access_list( - accesses_and_nodes: list[Access | hb_graph.OpNode] -) -> None: - pass diff --git a/probe_py/probe_py/ptypes.py b/probe_py/probe_py/ptypes.py index 5fa7298f..6247d3d4 100644 --- a/probe_py/probe_py/ptypes.py +++ b/probe_py/probe_py/ptypes.py @@ -9,6 +9,7 @@ import socket import stat import typing +import networkx import numpy from . import ops from . import consts @@ -182,6 +183,38 @@ class Process: execs: typing.Mapping[ExecNo, Exec] +@dataclasses.dataclass(frozen=True) +class OpQuad: + pid: Pid + exec_no: ExecNo + tid: Tid + op_no: int + + def thread_triple(self) -> tuple[Pid, ExecNo, Tid]: + return (self.pid, self.exec_no, self.tid) + + def __str__(self) -> str: + return f"PID {self.pid} Exec {self.exec_no} TID {self.tid} op {self.op_no}" + + +@dataclasses.dataclass(frozen=True) +class OpQuint(OpQuad): + deduplicator: int + + def deduplicate(self, other: OpQuad) -> OpQuint: + if self.quad() != other: + return OpQuint.from_quad(other, 0) + else: + return OpQuint.from_quad(other, self.deduplicator + 1) + + @staticmethod + def from_quad(quad: OpQuad, deduplicator: int = 0) -> OpQuint: + return OpQuint(quad.pid, quad.exec_no, quad.tid, quad.op_no, deduplicator) + + def quad(self) -> OpQuad: + return OpQuad(self.pid, self.exec_no, self.tid, self.op_no) + + @dataclasses.dataclass(frozen=True) class ProbeLog: processes: typing.Mapping[Pid, Process] @@ -193,34 +226,34 @@ class ProbeLog: # I think we should have probe_log.ops[quad] and probe_log.ops -> iterator # Maybe drop probe_log.ops -> iterator - def get_op(self, pid: Pid, exec_no: ExecNo, tid: Tid, op_no: int) -> ops.Op: - return self.processes[pid].execs[exec_no].threads[tid].ops[op_no] + def get_op(self, op: OpQuad) -> ops.Op: + return self.processes[op.pid].execs[op.exec_no].threads[op.tid].ops[op.op_no] - def ops(self) -> typing.Iterator[tuple[Pid, ExecNo, Tid, int, ops.Op]]: + def ops(self) -> typing.Iterator[tuple[OpQuad, ops.Op]]: for pid, process in 
sorted(self.processes.items()): for epoch, exec in sorted(process.execs.items()): for tid, thread in sorted(exec.threads.items()): for op_no, op in enumerate(thread.ops): - yield pid, epoch, tid, op_no, op + yield OpQuad(pid, epoch, tid, op_no), op def get_root_pid(self) -> Pid: - for pid, _, _, _, op in self.ops(): + for quad, op in self.ops(): match op.data: case ops.InitExecEpochOp(): if op.data.parent_pid == self.probe_options.parent_of_root: - return Pid(pid) + return Pid(quad.pid) raise RuntimeError("No root process found") def get_parent_pid_map(self) -> typing.Mapping[Pid, Pid]: parent_pid_map = dict[Pid, Pid]() - for pid, _, _, _, op in self.ops(): + for quad, op in self.ops(): match op.data: case ops.CloneOp(): if op.data.ferrno == 0 and op.data.task_type == TaskType.TASK_PID: - parent_pid_map[Pid(op.data.task_id)] = pid + parent_pid_map[Pid(op.data.task_id)] = quad.pid case ops.SpawnOp(): if op.data.ferrno == 0: - parent_pid_map[Pid(op.data.child_pid)] = pid + parent_pid_map[Pid(op.data.child_pid)] = quad.pid return parent_pid_map def n_ops(self) -> int: @@ -246,3 +279,51 @@ class InvalidProbeLog(Exception): class UnusualProbeLog(Warning): pass + + +class AccessMode(enum.IntEnum): + """In what way are we accessing the inode version?""" + EXEC = enum.auto() + DLOPEN = enum.auto() + READ = enum.auto() + WRITE = enum.auto() + READ_WRITE = enum.auto() + TRUNCATE_WRITE = enum.auto() + + @property + def is_side_effect_free(self) -> bool: + return self in {AccessMode.EXEC, AccessMode.DLOPEN, AccessMode.READ} + + @staticmethod + def from_open_flags(flags: int) -> "AccessMode": + access_mode = flags & os.O_ACCMODE + if access_mode == os.O_RDONLY: + return AccessMode.READ + elif flags & (os.O_TRUNC | os.O_CREAT): + return AccessMode.TRUNCATE_WRITE + elif access_mode == os.O_WRONLY: + return AccessMode.WRITE + elif access_mode == os.O_RDWR: + return AccessMode.READ_WRITE + else: + raise InvalidProbeLog(f"Invalid open flags: 0x{flags:x}") + + +class 
Phase(enum.StrEnum):
+    BEGIN = enum.auto()
+    END = enum.auto()
+
+
+@dataclasses.dataclass
+class Access:
+    phase: Phase
+    mode: AccessMode
+    inode: Inode
+    path: pathlib.Path
+    op_node: OpQuad
+    fd: int | None
+
+if typing.TYPE_CHECKING:
+    HbGraph: typing.TypeAlias = networkx.DiGraph[OpQuad]
+else:
+    HbGraph = networkx.DiGraph
diff --git a/probe_py/probe_py/util.py b/probe_py/probe_py/util.py
index b4945309..c31757ce 100644
--- a/probe_py/probe_py/util.py
+++ b/probe_py/probe_py/util.py
@@ -1,6 +1,8 @@
+import abc
 import collections
 import getpass
 import grp
+import heapq
 import itertools
 import os
 import pathlib
@@ -81,3 +83,79 @@ def decode_nested_object(
         return obj
     else:
         raise TypeError(f"{type(obj)}: {obj}")
+
+
+class Comparable(typing.Protocol):
+    """Protocol for annotating comparable types."""
+
+    @abc.abstractmethod
+    def __lt__(self, other: typing.Self, /) -> bool:
+        ...
+
+
+_Priority = typing.TypeVar("_Priority", bound=Comparable)
+_Task = typing.TypeVar("_Task", bound=collections.abc.Hashable)
+
+
+class PriorityQueue(typing.Generic[_Task, _Priority]):
+    """Minimum-priority queue
+
+    Use getitem and setitem to view and change a task's priority.
+
+    Get/set priority implies an additional constraint that each task can only be
+    in the queue once, and also the tasks should be hashable.
+
+    If the priorities are equal, order of extraction is order of insertion.
+ + https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes + + """ + + _heap: list[tuple[_Priority, int, _Task]] + _priorities: dict[_Task, tuple[_Priority, int]] + _removed: set[int] + _counter: int = 0 + + def __init__(self, initial: typing.Iterable[tuple[_Task, _Priority]] = []) -> None: + self._heap = [] + self._priorities = {} + self._removed = set() + for task, priority in initial: + if task in self._priorities: + raise RuntimeError(f"{task} is in the initial queue twice") + else: + self._heap.append((priority, self._counter, task)) + self._priorities[task] = (priority, self._counter) + self._counter += 1 + heapq.heapify(self._heap) + + def add(self, task: _Task, priority: _Priority) -> None: + if task in self._priorities: + raise RuntimeError(f"{task} is already in priority queue") + else: + self._priorities[task] = (priority, self._counter) + heapq.heappush(self._heap, (priority, self._counter, task)) + self._counter += 1 + + def pop(self) -> tuple[_Priority, _Task]: + counter = None + while counter is None or counter in self._removed: + priority, counter, task = heapq.heappop(self._heap) + return priority, task + + def __bool__(self) -> bool: + while self._heap and self._heap[0][1] in self._removed: + heapq.heappop(self._heap) + return bool(self._heap) + + def __delitem__(self, task: _Task) -> None: + _, counter = self._priorities[task] + del self._priorities[task] + self._removed.add(counter) + + def __getitem__(self, task: _Task) -> _Priority: + return self._priorities[task][0] + + def __setitem__(self, task: _Task, priority: _Priority) -> None: + del self[task] + self.add(task, priority)