50
50
"jax" : JAXLinker (),
51
51
"pytorch" : PytorchLinker (),
52
52
"numba" : NumbaLinker (),
53
+ "numba_vm" : NumbaLinker (vm = True ),
53
54
}
54
55
55
56
@@ -63,9 +64,8 @@ def register_linker(name, linker):
63
64
# If a string is passed as the optimizer argument in the constructor
64
65
# for Mode, it will be used as the key to retrieve the real optimizer
65
66
# in this dictionary
66
- exclude = []
67
- if not config .cxx :
68
- exclude = ["cxx_only" ]
67
+
68
+ exclude = ["cxx_only" , "BlasOpt" ]
69
69
OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
70
70
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
71
71
OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ], exclude = exclude )
@@ -351,6 +351,11 @@ def __setstate__(self, state):
351
351
optimizer = predefined_optimizers [optimizer ]
352
352
if isinstance (optimizer , RewriteDatabaseQuery ):
353
353
self .provided_optimizer = optimizer
354
+
355
+ # Force numba-required rewrites if using NumbaLinker
356
+ if isinstance (linker , NumbaLinker ):
357
+ optimizer = optimizer .including ("numba" )
358
+
354
359
self ._optimizer = optimizer
355
360
self .call_time = 0
356
361
self .fn_time = 0
@@ -448,16 +453,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
448
453
# string as the key
449
454
# Use VM_linker to allow lazy evaluation by default.
450
455
FAST_COMPILE = Mode (
451
- VMLinker (use_cloop = False , c_thunks = False ),
452
- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
456
+ NumbaLinker (vm = True ),
457
+ # TODO: Fast_compile should just use python code, CHANGE ME!
458
+ RewriteDatabaseQuery (
459
+ include = ["fast_compile" , "numba" ],
460
+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
461
+ ),
462
+ )
463
+ FAST_RUN = Mode (
464
+ NumbaLinker (vm = True ),
465
+ RewriteDatabaseQuery (
466
+ include = ["fast_run" , "numba" ],
467
+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
468
+ ),
453
469
)
454
- if config .cxx :
455
- FAST_RUN = Mode ("cvm" , "fast_run" )
456
- else :
457
- FAST_RUN = Mode (
458
- "vm" ,
459
- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
460
- )
461
470
462
471
NUMBA = Mode (
463
472
NumbaLinker (),
@@ -580,6 +589,7 @@ def register_mode(name, mode):
580
589
Add a `Mode` which can be referred to by `name` in `function`.
581
590
582
591
"""
592
+ # TODO: Remove me
583
593
if name in predefined_modes :
584
594
raise ValueError (f"Mode name already taken: { name } " )
585
595
predefined_modes [name ] = mode
0 commit comments