Skip to content

Commit 4b1a2d0

Browse files
author
Neumann, Jan
committed
benchmark: store architectures
1 parent c2958df commit 4b1a2d0

File tree

2 files changed

+83
-24
lines changed

2 files changed

+83
-24
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ __pycache__/
77
/trace_*.json
88

99
/datasets/
10+
runs/

benchmark.py

+82-24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import itertools
33
import json
4+
import pickle
45
from dataclasses import dataclass
56
from pathlib import Path
67
from typing import Any, Callable, Container, Literal
@@ -18,6 +19,7 @@
1819
from lib.engines.torch.settings import Compilation, TorchModuleSettings, TorchReduceMethod
1920
from lib.sources.neuralogic_settings import NeuralogicSettings
2021
from lib.utils import dataclass_to_shorthand, iter_empty, serialize_dataclass
22+
from lib.vectorize.model.op_network import VectorizedOpSeqNetwork
2123
from lib.vectorize.settings import VectorizeSettings
2224
from lib.vectorize.settings_presets import VectorizeSettingsPresets, iterate_vectorize_settings_presets
2325

@@ -245,12 +247,15 @@ def total(dir: Path):
245247
print(len(variants))
246248

247249

248-
@cli.command()
250+
@cli.command(context_settings={"show_default": True})
249251
@click.argument(
250252
"dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, writable=True, path_type=Path)
251253
)
252254
@click.argument("index", type=int)
253-
def run(dir: Path, index: int):
255+
@click.option("--measure/--no-measure", default=True)
256+
@click.option("--save-architecture/--no-save-architecture", default=True)
257+
@click.option("--force-cpu", default=False, is_flag=True)
258+
def run(dir: Path, index: int, measure: bool, save_architecture: bool, force_cpu: bool):
254259
torch.set_default_dtype(torch.float32)
255260

256261
variants_file = dir / "variants.txt"
@@ -270,13 +275,16 @@ def run(dir: Path, index: int):
270275
del variants
271276

272277
time = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
273-
name = f"{dataclass_to_shorthand(variant)},{time}"
278+
notimename = dataclass_to_shorthand(variant)
279+
timename = f"{notimename},{time}"
274280

275281
print(variant)
276282
print()
277-
print(name)
283+
print(timename)
278284
print()
279285

286+
device = "cpu" if force_cpu else variant.device
287+
280288
match variant:
281289
case JavaVariant(engine="java"):
282290
runnable = NeuraLogicCPURunnable()
@@ -285,7 +293,7 @@ def run(dir: Path, index: int):
285293
)
286294
case VectorizedTorchVariant(engine="torch"):
287295
runnable = NeuralogicVectorizedTorchRunnable(
288-
device=variant.device,
296+
device=device,
289297
neuralogic_settings=DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE,
290298
vectorize_settings=variant.settings,
291299
torch_settings=TorchModuleSettings(
@@ -295,31 +303,81 @@ def run(dir: Path, index: int):
295303
)
296304
dataset = DATASET_OPTIONS[variant.dataset][0](DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE)
297305
case TorchVariant(engine="pyg"):
298-
runnable = PytorchGeometricRunnable(device=variant.device)
306+
runnable = PytorchGeometricRunnable(device=device)
299307
dataset = DATASET_OPTIONS[variant.dataset][0](DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE)
300308
case _:
301309
raise ValueError(variant)
302310

303311
dataset = dataset.build()
304312

305-
out_file = dir / f"{name}.json"
306-
out = variant.serialize()
307-
308-
if variant.backward:
309-
fwd, bwd, cmb = measure_backward(runnable, dataset, times=variant.times)
310-
out["fwd"] = fwd.times_ns.tolist()
311-
out["bwd"] = bwd.times_ns.tolist()
312-
out["cmb"] = cmb.times_ns.tolist()
313-
print("Forward: ", fwd)
314-
print("Backward:", bwd)
315-
print("Combined:", cmb)
316-
else:
317-
cmb = measure_forward(runnable, dataset, times=variant.times)
318-
out["cmb"] = out["fwd"] = cmb.times_ns.tolist()
319-
print("Forward:", cmb)
320-
321-
with open(out_file, "w") as fp:
322-
json.dump(out, fp)
313+
if measure:
314+
out_file = dir / f"{timename}.json"
315+
out = variant.serialize()
316+
317+
if variant.backward:
318+
fwd, bwd, cmb = measure_backward(runnable, dataset, times=variant.times)
319+
out["fwd"] = fwd.times_ns.tolist()
320+
out["bwd"] = bwd.times_ns.tolist()
321+
out["cmb"] = cmb.times_ns.tolist()
322+
print("Forward: ", fwd)
323+
print("Backward:", bwd)
324+
print("Combined:", cmb)
325+
else:
326+
cmb = measure_forward(runnable, dataset, times=variant.times)
327+
out["cmb"] = out["fwd"] = cmb.times_ns.tolist()
328+
print("Forward:", cmb)
329+
330+
with open(out_file, "w") as fp:
331+
json.dump(out, fp)
332+
elif save_architecture:
333+
runnable.initialize(dataset)
334+
335+
if save_architecture and hasattr(runnable, "vectorized_network"):
336+
out_pkl_file = dir / f"{notimename}.pkl"
337+
print(out_pkl_file)
338+
# TODO
339+
with open(out_pkl_file, "wb") as fp:
340+
pickle.dump(runnable.vectorized_network, fp)
341+
342+
343+
@cli.command(context_settings={"show_default": True})
344+
@click.argument(
345+
"dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, writable=True, path_type=Path)
346+
)
347+
def build_architecture_map(dir: Path):
348+
variants_file = dir / "variants.txt"
349+
350+
if not variants_file.exists():
351+
raise click.ClickException(f"{variants_file.absolute()} does not exist.")
352+
353+
with open(variants_file, "r") as fp:
354+
variants = json.load(fp)
355+
356+
architectures: list[VectorizedOpSeqNetwork] = []
357+
architectures_dict: dict[VectorizedOpSeqNetwork, int] = {}
358+
variants_dict: dict[str, int] = {}
359+
360+
for v in variants:
361+
variant = Variant.deserialize(v)
362+
notimename = dataclass_to_shorthand(variant)
363+
pkl_file_path = dir / (notimename + ".pkl")
364+
if pkl_file_path.exists():
365+
with open(pkl_file_path, "rb") as fp:
366+
vectorized_network = pickle.load(fp)
367+
368+
if vectorized_network in architectures_dict:
369+
idx = architectures_dict[vectorized_network]
370+
else:
371+
idx = len(architectures)
372+
architectures.append(vectorized_network)
373+
architectures_dict[vectorized_network] = idx
374+
375+
variants_dict[notimename] = idx
376+
377+
with open(dir / "networks.pkl", "wb") as fp:
378+
pickle.dump((architectures, variants_dict), fp)
379+
380+
print(len(architectures))
323381

324382

325383
if __name__ == "__main__":

0 commit comments

Comments
 (0)