Skip to content

Commit 7b72ffe

Browse files
committed
Fix types
1 parent 1ef0321 commit 7b72ffe

5 files changed

Lines changed: 23 additions & 21 deletions

File tree

probe_py/probe_py/analysis.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import collections
22
from .ptypes import ProbeLog, HbGraph, OpQuad
33
from .ops import CloneOp, WaitOp
4-
from .hb_graph import HbGraph
54

65

76
def get_max_parallelism_latest(hb_graph: HbGraph, probe_log: ProbeLog) -> int:

probe_py/probe_py/dataflow_graph.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing
99
import warnings
1010
import networkx
11+
import tqdm
1112
from . import graph_utils
1213
from .hb_graph_accesses import hb_graph_to_accesses
1314
from . import ops
@@ -84,26 +85,26 @@ def ensure_state(quad: ptypes.OpQuad, desired_state: PidState) -> ptypes.OpQuint
8485
inode_to_paths[access.inode].add(access.path)
8586
version = InodeVersionNode(access.inode, version_num)
8687
next_version = InodeVersionNode(access.inode, version_num + 1)
87-
ensure_state(access.op_node, PidState.READING if access.mode.is_side_effect_free() else PidState.WRITING)
88+
ensure_state(access.op_node, PidState.READING if access.mode.is_side_effect_free else PidState.WRITING)
8889
if (op_node := last_op_in_process.get(access.op_node.pid)) is None:
8990
warnings.warn(ptypes.UnusualProbeLog(f"Can't find last node from process {access.op_node.pid}"))
9091
continue
9192
match access.mode:
92-
case AccessMode.WRITE:
93-
if access.phase == Phase.BEGIN:
93+
case ptypes.AccessMode.WRITE:
94+
if access.phase == ptypes.Phase.BEGIN:
9495
dataflow_graph.add_edge(op_node, next_version)
9596
dataflow_graph.add_edge(version, next_version)
96-
case AccessMode.TRUNCATE_WRITE:
97-
if access.phase == Phase.END:
97+
case ptypes.AccessMode.TRUNCATE_WRITE:
98+
if access.phase == ptypes.Phase.END:
9899
dataflow_graph.add_edge(op_node, next_version)
99-
case AccessMode.READ_WRITE:
100-
if access.phase == Phase.BEGIN:
100+
case ptypes.AccessMode.READ_WRITE:
101+
if access.phase == ptypes.Phase.BEGIN:
101102
dataflow_graph.add_edge(version, op_node)
102-
if access.phase == Phase.END:
103+
if access.phase == ptypes.Phase.END:
103104
dataflow_graph.add_edge(op_node, next_version)
104105
dataflow_graph.add_edge(version, next_version)
105-
case AccessMode.READ | AccessMode.EXEC | AccessMode.DLOPEN:
106-
if access.phase == Phase.BEGIN:
106+
case ptypes.AccessMode.READ | ptypes.AccessMode.EXEC | ptypes.AccessMode.DLOPEN:
107+
if access.phase == ptypes.Phase.BEGIN:
107108
dataflow_graph.add_edge(version, op_node)
108109
case _:
109110
raise TypeError()
@@ -233,9 +234,9 @@ def label_nodes(
233234
):
234235
data = dataflow_graph.nodes(data=True)[node]
235236
match node:
236-
case hb_graph.OpNode():
237+
case ptypes.OpQuad():
237238
data["shape"] = "oval"
238-
op = probe_log.get_op(node.pid, node.exec_no, node.pid.main_thread(), 0)
239+
op = probe_log.get_op(ptypes.OpQuad(node.pid, node.exec_no, node.pid.main_thread(), 0))
239240
if node.op_no == 0:
240241
count[(node.pid, node.exec_no)] = 1
241242
if node.exec_no != 0:

probe_py/probe_py/graph_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ def map_nodes(
8686
) -> networkx.DiGraph[_Node2]:
8787
dct = {node: function(node) for node in graph.nodes()}
8888
assert util.all_unique(dct.values()), util.duplicates(dct.values())
89-
ret = networkx.relabel_nodes(graph, dct)
90-
return typing.cast(networkx.DiGraph[_Node2], ret)
89+
return networkx.relabel_nodes(graph, dct)
9190

9291

9392
def serialize_graph(
@@ -99,7 +98,7 @@ def serialize_graph(
9998
if name_mapper is None:
10099
name_mapper = typing.cast(
101100
typing.Callable[[_Node], str],
102-
lambda node: graph.nodes[node].get("id", str(node)),
101+
lambda node: graph.nodes(data=True)[node].get("id", str(node)),
103102
)
104103
graph2 = map_nodes(name_mapper, graph)
105104
pydot_graph = networkx.drawing.nx_pydot.to_pydot(graph2)
@@ -376,7 +375,9 @@ def add_edge(self, source: _Node, target: _Node) -> None:
376375
self.dag_tc.add_edge(descendant_of_source, descendant_of_target)
377376

378377

379-
def get_faces(planar_graph: networkx.PlanarEmbedding[_Node]) -> frozenset[tuple[_Node, ...]]:
378+
def get_faces(
379+
planar_graph: networkx.PlanarEmbedding[_Node], # type: ignore
380+
) -> frozenset[tuple[_Node, ...]]:
380381
faces = set()
381382
covered_half_edges = set()
382383
for half_edge in planar_graph.edges():
@@ -475,7 +476,7 @@ def topological_sort_depth_first(
475476
) -> typing.Iterable[_Node]:
476477
"""Topological sort that breaks ties by depth first, and then by lowest child score."""
477478
queue = util.PriorityQueue[_Node, tuple[int, int]](
478-
(node, (typing.cast(int, dag.in_degree(node)), 0))
479+
(node, (dag.in_degree(node), 0))
479480
for node in dag.nodes()
480481
)
481482
counter = 0

probe_py/probe_py/hb_graph_accesses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ def openfd(
5959
match op_data:
6060
case ops.InitExecEpochOp():
6161
if 0 not in proc_fd_to_fd[node.pid]:
62-
yield from openfd(0, ptypes.AccessMode.READ, False, node, op_data.stdin)
62+
yield from openfd(0, ptypes.AccessMode.READ, False, node, op_data.std_in)
6363
if 1 not in proc_fd_to_fd[node.pid]:
64-
yield from openfd(1, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.stdout)
64+
yield from openfd(1, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.std_out)
6565
if 2 not in proc_fd_to_fd[node.pid]:
66-
yield from openfd(2, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.stderr)
66+
yield from openfd(2, ptypes.AccessMode.TRUNCATE_WRITE, False, node, op_data.std_err)
6767
case ops.OpenOp():
6868
if op_data.ferrno == 0:
6969
mode = ptypes.AccessMode.from_open_flags(op_data.flags)

probe_py/probe_py/ptypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import socket
1010
import stat
1111
import typing
12+
import networkx
1213
import numpy
1314
from . import ops
1415
from . import consts

0 commit comments

Comments
 (0)