Skip to content

Commit e1940f7

Browse files
committed
Make NUMBA_VM the default mode
1 parent d64faca commit e1940f7

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

pytensor/compile/mode.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"jax": JAXLinker(),
5151
"pytorch": PytorchLinker(),
5252
"numba": NumbaLinker(),
53+
"numba_vm": NumbaLinker(vm=True),
5354
}
5455

5556

@@ -63,9 +64,8 @@ def register_linker(name, linker):
6364
# If a string is passed as the optimizer argument in the constructor
6465
# for Mode, it will be used as the key to retrieve the real optimizer
6566
# in this dictionary
66-
exclude = []
67-
if not config.cxx:
68-
exclude = ["cxx_only"]
67+
68+
exclude = ["cxx_only", "BlasOpt"]
6969
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
7070
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
7171
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
@@ -351,6 +351,11 @@ def __setstate__(self, state):
351351
optimizer = predefined_optimizers[optimizer]
352352
if isinstance(optimizer, RewriteDatabaseQuery):
353353
self.provided_optimizer = optimizer
354+
355+
# Force numba-required rewrites if using NumbaLinker
356+
if isinstance(linker, NumbaLinker):
357+
optimizer = optimizer.including("numba")
358+
354359
self._optimizer = optimizer
355360
self.call_time = 0
356361
self.fn_time = 0
@@ -448,16 +453,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
448453
# string as the key
449454
# Use VM_linker to allow lazy evaluation by default.
450455
FAST_COMPILE = Mode(
451-
VMLinker(use_cloop=False, c_thunks=False),
452-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
456+
NumbaLinker(vm=True),
457+
# TODO: Fast_compile should just use python code, CHANGE ME!
458+
RewriteDatabaseQuery(
459+
include=["fast_compile", "numba"],
460+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
461+
),
462+
)
463+
FAST_RUN = Mode(
464+
NumbaLinker(vm=True),
465+
RewriteDatabaseQuery(
466+
include=["fast_run", "numba"],
467+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
468+
),
453469
)
454-
if config.cxx:
455-
FAST_RUN = Mode("cvm", "fast_run")
456-
else:
457-
FAST_RUN = Mode(
458-
"vm",
459-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
460-
)
461470

462471
NUMBA = Mode(
463472
NumbaLinker(),
@@ -580,6 +589,7 @@ def register_mode(name, mode):
580589
Add a `Mode` which can be referred to by `name` in `function`.
581590
582591
"""
592+
# TODO: Remove me
583593
if name in predefined_modes:
584594
raise ValueError(f"Mode name already taken: {name}")
585595
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,22 @@ def add_compile_configvars():
370370

371371
if rc == 0 and config.cxx != "":
372372
# Keep the default linker the same as the one for the mode FAST_RUN
373-
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
373+
linker_options = [
374+
"cvm",
375+
"c|py",
376+
"py",
377+
"c",
378+
"c|py_nogc",
379+
"vm",
380+
"vm_nogc",
381+
"cvm_nogc",
382+
"jax",
383+
"numba",
384+
]
374385
else:
375386
# g++ is not present or the user disabled it,
376387
# linker should default to python only.
377-
linker_options = ["py", "vm_nogc"]
388+
linker_options = ["py", "vm", "vm_nogc", "jax", "numba"]
378389
if type(config).cxx.is_default:
379390
# If the user provided an empty value for cxx, do not warn.
380391
_logger.warning(
@@ -388,7 +399,7 @@ def add_compile_configvars():
388399
"linker",
389400
"Default linker used if the pytensor flags mode is Mode",
390401
# Not mutable because the default mode is cached after the first use.
391-
EnumStr("cvm", linker_options, mutable=False),
402+
EnumStr("numba_vm", linker_options, mutable=False),
392403
in_c_key=False,
393404
)
394405

0 commit comments

Comments
 (0)