Skip to content

Commit 6dde178

Browse files
scripts: fix compare-llama-bench commit hash logic (ggml-org#11891)
1 parent fc10c38 commit 6dde178

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

scripts/compare-llama-bench.py

+29-16
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,22 @@
124124

125125
connection = sqlite3.connect(input_file)
126126
cursor = connection.cursor()
127+
128+
build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
129+
build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
130+
131+
if build_len_min != build_len_max:
132+
logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
133+
"Try purging the the database of old commits.")
134+
cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
135+
136+
build_len: int = build_len_min
137+
127138
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
139+
builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
128140

129-
commit_short_len = len(builds[0][0])
141+
if not builds:
142+
raise RuntimeError(f"{input_file} does not contain any builds.")
130143

131144
try:
132145
repo = git.Repo(".", search_parent_directories=True)
@@ -140,11 +153,11 @@ def find_parent_in_data(commit: git.Commit):
140153
seen_hexsha8 = set()
141154
while heap:
142155
depth, current_commit = heapq.heappop(heap)
143-
current_hexsha8 = commit.hexsha[:commit_short_len]
144-
if (current_hexsha8,) in builds:
156+
current_hexsha8 = commit.hexsha[:build_len]
157+
if current_hexsha8 in builds:
145158
return current_hexsha8
146159
for parent in commit.parents:
147-
parent_hexsha8 = parent.hexsha[:commit_short_len]
160+
parent_hexsha8 = parent.hexsha[:build_len]
148161
if parent_hexsha8 not in seen_hexsha8:
149162
seen_hexsha8.add(parent_hexsha8)
150163
heapq.heappush(heap, (depth + 1, parent))
@@ -158,48 +171,48 @@ def get_all_parent_hexsha8s(commit: git.Commit):
158171

159172
while unvisited:
160173
current_commit = unvisited.pop(0)
161-
visited.append(current_commit.hexsha[:commit_short_len])
174+
visited.append(current_commit.hexsha[:build_len])
162175
for parent in current_commit.parents:
163-
if parent.hexsha[:commit_short_len] not in visited:
176+
if parent.hexsha[:build_len] not in visited:
164177
unvisited.append(parent)
165178

166179
return visited
167180

168181

169-
def get_commit_name(hexsha8):
182+
def get_commit_name(hexsha8: str):
170183
"""Helper function to find a human-readable name for a commit if possible."""
171184
if repo is None:
172185
return hexsha8
173186
for h in repo.heads:
174-
if h.commit.hexsha[:commit_short_len] == hexsha8:
187+
if h.commit.hexsha[:build_len] == hexsha8:
175188
return h.name
176189
for t in repo.tags:
177-
if t.commit.hexsha[:commit_short_len] == hexsha8:
190+
if t.commit.hexsha[:build_len] == hexsha8:
178191
return t.name
179192
return hexsha8
180193

181194

182-
def get_commit_hexsha8(name):
195+
def get_commit_hexsha8(name: str):
183196
"""Helper function to search for a commit given a human-readable name."""
184197
if repo is None:
185198
return None
186199
for h in repo.heads:
187200
if h.name == name:
188-
return h.commit.hexsha[:commit_short_len]
201+
return h.commit.hexsha[:build_len]
189202
for t in repo.tags:
190203
if t.name == name:
191-
return t.commit.hexsha[:commit_short_len]
204+
return t.commit.hexsha[:build_len]
192205
for c in repo.iter_commits("--all"):
193-
if c.hexsha[:commit_short_len] == name[:commit_short_len]:
194-
return c.hexsha[:commit_short_len]
206+
if c.hexsha[:build_len] == name[:build_len]:
207+
return c.hexsha[:build_len]
195208
return None
196209

197210

198211
hexsha8_baseline = name_baseline = None
199212

200213
# If the user specified a baseline, try to find a commit for it:
201214
if known_args.baseline is not None:
202-
if (known_args.baseline,) in builds:
215+
if known_args.baseline in builds:
203216
hexsha8_baseline = known_args.baseline
204217
if hexsha8_baseline is None:
205218
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
@@ -228,7 +241,7 @@ def get_commit_hexsha8(name):
228241

229242
# If the user has specified a compare value, try to find a corresponding commit:
230243
if known_args.compare is not None:
231-
if (known_args.compare,) in builds:
244+
if known_args.compare in builds:
232245
hexsha8_compare = known_args.compare
233246
if hexsha8_compare is None:
234247
hexsha8_compare = get_commit_hexsha8(known_args.compare)

0 commit comments

Comments
 (0)