@@ -1,6 +1,7 @@
 import datetime
 import itertools
 import json
+import pickle
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Container, Literal
@@ -18,6 +19,7 @@
 from lib.engines.torch.settings import Compilation, TorchModuleSettings, TorchReduceMethod
 from lib.sources.neuralogic_settings import NeuralogicSettings
 from lib.utils import dataclass_to_shorthand, iter_empty, serialize_dataclass
+from lib.vectorize.model.op_network import VectorizedOpSeqNetwork
 from lib.vectorize.settings import VectorizeSettings
 from lib.vectorize.settings_presets import VectorizeSettingsPresets, iterate_vectorize_settings_presets
 
@@ -245,12 +247,15 @@ def total(dir: Path):
     print(len(variants))
 
 
-@cli.command()
+@cli.command(context_settings={"show_default": True})
 @click.argument(
     "dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, writable=True, path_type=Path)
 )
 @click.argument("index", type=int)
-def run(dir: Path, index: int):
+@click.option("--measure/--no-measure", default=True)
+@click.option("--save-architecture/--no-save-architecture", default=True)
+@click.option("--force-cpu", default=False, is_flag=True)
+def run(dir: Path, index: int, measure: bool, save_architecture: bool, force_cpu: bool):
     torch.set_default_dtype(torch.float32)
 
     variants_file = dir / "variants.txt"
@@ -270,13 +275,16 @@ def run(dir: Path, index: int):
     del variants
 
     time = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
-    name = f"{dataclass_to_shorthand(variant)},{time}"
+    notimename = dataclass_to_shorthand(variant)
+    timename = f"{notimename},{time}"
 
     print(variant)
     print()
-    print(name)
+    print(timename)
     print()
 
+    device = "cpu" if force_cpu else variant.device
+
     match variant:
         case JavaVariant(engine="java"):
             runnable = NeuraLogicCPURunnable()
@@ -285,7 +293,7 @@ def run(dir: Path, index: int):
             )
         case VectorizedTorchVariant(engine="torch"):
             runnable = NeuralogicVectorizedTorchRunnable(
-                device=variant.device,
+                device=device,
                 neuralogic_settings=DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE,
                 vectorize_settings=variant.settings,
                 torch_settings=TorchModuleSettings(
@@ -295,31 +303,81 @@ def run(dir: Path, index: int):
             )
             dataset = DATASET_OPTIONS[variant.dataset][0](DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE)
         case TorchVariant(engine="pyg"):
-            runnable = PytorchGeometricRunnable(device=variant.device)
+            runnable = PytorchGeometricRunnable(device=device)
             dataset = DATASET_OPTIONS[variant.dataset][0](DEFAULT_NEURALOGIC_SETTINGS_VECTORIZE)
         case _:
             raise ValueError(variant)
 
     dataset = dataset.build()
 
-    out_file = dir / f"{name}.json"
-    out = variant.serialize()
-
-    if variant.backward:
-        fwd, bwd, cmb = measure_backward(runnable, dataset, times=variant.times)
-        out["fwd"] = fwd.times_ns.tolist()
-        out["bwd"] = bwd.times_ns.tolist()
-        out["cmb"] = cmb.times_ns.tolist()
-        print("Forward: ", fwd)
-        print("Backward:", bwd)
-        print("Combined:", cmb)
-    else:
-        cmb = measure_forward(runnable, dataset, times=variant.times)
-        out["cmb"] = out["fwd"] = cmb.times_ns.tolist()
-        print("Forward:", cmb)
-
-    with open(out_file, "w") as fp:
-        json.dump(out, fp)
+    if measure:
+        out_file = dir / f"{timename}.json"
+        out = variant.serialize()
+
+        if variant.backward:
+            fwd, bwd, cmb = measure_backward(runnable, dataset, times=variant.times)
+            out["fwd"] = fwd.times_ns.tolist()
+            out["bwd"] = bwd.times_ns.tolist()
+            out["cmb"] = cmb.times_ns.tolist()
+            print("Forward: ", fwd)
+            print("Backward:", bwd)
+            print("Combined:", cmb)
+        else:
+            cmb = measure_forward(runnable, dataset, times=variant.times)
+            out["cmb"] = out["fwd"] = cmb.times_ns.tolist()
+            print("Forward:", cmb)
+
+        with open(out_file, "w") as fp:
+            json.dump(out, fp)
+    elif save_architecture:
+        runnable.initialize(dataset)
+
+    if save_architecture and hasattr(runnable, "vectorized_network"):
+        out_pkl_file = dir / f"{notimename}.pkl"
+        print(out_pkl_file)
+        # TODO
+        with open(out_pkl_file, "wb") as fp:
+            pickle.dump(runnable.vectorized_network, fp)
+
+
+@cli.command(context_settings={"show_default": True})
+@click.argument(
+    "dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, writable=True, path_type=Path)
+)
+def build_architecture_map(dir: Path):
+    variants_file = dir / "variants.txt"
+
+    if not variants_file.exists():
+        raise click.ClickException(f"{variants_file.absolute()} does not exist.")
+
+    with open(variants_file, "r") as fp:
+        variants = json.load(fp)
+
+    architectures: list[VectorizedOpSeqNetwork] = []
+    architectures_dict: dict[VectorizedOpSeqNetwork, int] = {}
+    variants_dict: dict[str, int] = {}
+
+    for v in variants:
+        variant = Variant.deserialize(v)
+        notimename = dataclass_to_shorthand(variant)
+        pkl_file_path = dir / (notimename + ".pkl")
+        if pkl_file_path.exists():
+            with open(pkl_file_path, "rb") as fp:
+                vectorized_network = pickle.load(fp)
+
+            if vectorized_network in architectures_dict:
+                idx = architectures_dict[vectorized_network]
+            else:
+                idx = len(architectures)
+                architectures.append(vectorized_network)
+                architectures_dict[vectorized_network] = idx
+
+            variants_dict[notimename] = idx
+
+    with open(dir / "networks.pkl", "wb") as fp:
+        pickle.dump((architectures, variants_dict), fp)
+
+    print(len(architectures))
 
 
 if __name__ == "__main__":