Skip to content

Commit

Permalink
refine plotting script
Browse files Browse the repository at this point in the history
  • Loading branch information
jdolence committed Nov 7, 2023
1 parent cda4e71 commit b1aa06e
Showing 1 changed file with 82 additions and 30 deletions.
112 changes: 82 additions & 30 deletions scripts/python/packages/parthenon_plot_trace/plot_trace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
import re
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import itertools
from collections import OrderedDict
from argparse import ArgumentParser
Expand All @@ -23,6 +25,14 @@
help="Final step to include (inclusive)",
)

parser.add_argument(
"--outfile",
dest="outfile",
type=str,
default="NOT_SET",
help="To dump the plot to a file, specify the name here",
)

parser.add_argument("files", type=str, nargs="+", help="trace files to plot")


Expand Down Expand Up @@ -118,51 +128,93 @@ def plot_trace(self, ax, colorMap, hatchMap):
)


def main(files, step_start, step_stop):
trace = []
for f in files:
print("Getting trace", f, end="")
trace.append(Trace(f, step_start, step_stop))
print(" done!")
# get max number of functions
all_funcs = OrderedDict()
for t in trace:
for key in t.region_names():
all_funcs[key] = ""

num_colors = len(all_funcs)
def plot_traces(traces, functions, outfile):
num_colors = len(functions)
cm = plt.get_cmap("tab20")
hatch = ["", "--", "/", "\\", "+", "x"]
num_hatches = len(hatch)
colorMap = {}
hatchMap = {}
cindex = 0
for f, dum in all_funcs.items():
for f, dum in functions.items():
colorMap[f] = cm((cindex + 0.5) / num_colors)
hatchMap[f] = hatch[cindex % num_hatches]
cindex += 1
fig, ax = plt.subplots(figsize=(18, 12))
for t in trace:
print("Plotting trace", t.rank, end="")
t.plot_trace(ax, colorMap, hatchMap)
print(" done!")

min_rank = 999999
max_rank = 0
min_time = 999999.0
max_time = 0.0
for f, dum in functions.items():
if f == "StepTimer":
continue
patches = []
for t in traces:
for i in range(len(t.regions[f].start)):
min_rank = min(min_rank, t.rank)
max_rank = max(max_rank, t.rank)
min_time = min(min_time, t.regions[f].start[i])
max_time = max(
max_time, t.regions[f].start[i] + t.regions[f].duration[i]
)
patches.append(
Rectangle(
(t.regions[f].start[i], t.rank - 0.25),
t.regions[f].duration[i],
0.5,
)
)
pc = PatchCollection(
patches, linewidth=0, facecolor=colorMap[f], hatch=hatchMap[f]
)
ax.add_collection(pc)
plt.xlim(min_time, max_time)
plt.ylim(min_rank - 0.5, max_rank + 0.5)
plt.xlabel("Time (s)")
plt.ylabel("Rank")
# box = ax.get_position()
# ax.set_position([box.x0, box.y0 + box.height*0.2, box.width, box.height * 0.8])
handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
ax.legend(
by_label.values(),
by_label.keys(),
loc="upper center",
bbox_to_anchor=(0, -0.02, 1, -0.02),
ncol=3,
plt.yticks([i for i in range(min_rank, max_rank + 1)])
handles = []
for f, dum in functions.items():
if f == "StepTimer":
continue
handles.append(
Rectangle(
(min_time, min_rank),
0.0,
0.0,
linewidth=0,
edgecolor="k",
facecolor=colorMap[f],
hatch=hatchMap[f],
label=f,
)
)
plt.legend(
loc="upper center", handles=handles, bbox_to_anchor=(0, -0.02, 1, -0.02), ncol=3
)
plt.tight_layout()
plt.show()
if outfile == "NOT_SET":
plt.show()
else:
plt.savefig(outfile, dpi=300)


def main(files, step_start, step_stop, outfile):
trace = []
for f in files:
print("Getting trace", f, end="")
trace.append(Trace(f, step_start, step_stop))
print(" done!")
# get max number of functions
all_funcs = OrderedDict()
for t in trace:
for key in t.region_names():
all_funcs[key] = ""

plot_traces(trace, all_funcs, outfile)


if __name__ == "__main__":
args = parser.parse_args()
main(args.files, args.step_start, args.step_stop)
main(args.files, args.step_start, args.step_stop, args.outfile)

0 comments on commit b1aa06e

Please sign in to comment.