Cache numba stuff #1326

Draft: wants to merge 15 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -193,7 +193,7 @@ jobs:
else
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
pip install pytest-sphinx
33 changes: 21 additions & 12 deletions pytensor/compile/mode.py
@@ -63,9 +63,8 @@ def register_linker(name, linker):
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]

exclude = ["cxx_only", "BlasOpt"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Even if there are multiple calls to the merge optimizer, this shouldn't
# impact performance.
@@ -346,6 +345,11 @@ def __setstate__(self, state):
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, RewriteDatabaseQuery):
self.provided_optimizer = optimizer

# Force numba-required rewrites if using NumbaLinker
if isinstance(linker, NumbaLinker):
optimizer = optimizer.including("numba")

self._optimizer = optimizer
self.call_time = 0
self.fn_time = 0
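A small illustration of the effect (a sketch, not part of the diff; the import path for RewriteDatabaseQuery is assumed): whatever query the caller passes, a NumbaLinker-backed Mode now ends up with the "numba" tag included, exactly as the code above does internally.

    # Illustrative only: the same augmentation the code above applies when the
    # linker is a NumbaLinker.
    from pytensor.graph.rewriting.db import RewriteDatabaseQuery

    base_query = RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
    numba_query = base_query.including("numba")  # "numba" rewrites forced on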
@@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
NumbaLinker(),
# TODO: Fast_compile should just use python code, CHANGE ME!
RewriteDatabaseQuery(
include=["fast_compile", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
FAST_RUN = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

NUMBA = Mode(
NumbaLinker(),
@@ -565,6 +573,7 @@ def register_mode(name, mode):
Add a `Mode` which can be referred to by `name` in `function`.

"""
# TODO: Remove me
if name in predefined_modes:
raise ValueError(f"Mode name already taken: {name}")
predefined_modes[name] = mode
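Taken together, the mode changes above make the numba backend the default compilation path. A minimal usage sketch, assuming this branch (the graph itself is arbitrary):

    # With NumbaLinker backing FAST_RUN, a plain pytensor.function call now
    # compiles through numba rather than the C/VM linkers.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.exp(x) + 1.0
    f = pytensor.function([x], y)   # default mode (FAST_RUN) -> NumbaLinker
    print(f(np.arange(3.0)))        # -> approximately [2.0, 3.718, 8.389]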
16 changes: 13 additions & 3 deletions pytensor/configdefaults.py
@@ -370,11 +370,21 @@ def add_compile_configvars():

if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
linker_options = [
"cvm",
"c|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"jax",
]
else:
# g++ is not present or the user disabled it,
# linker should default to python only.
linker_options = ["py", "vm_nogc"]
linker_options = ["py", "vm", "vm_nogc", "jax"]
if type(config).cxx.is_default:
# If the user provided an empty value for cxx, do not warn.
_logger.warning(
@@ -388,7 +398,7 @@
"linker",
"Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
EnumStr("numba", linker_options, mutable=False),
in_c_key=False,
)

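Since the default linker flag now resolves to "numba", the old C-backed default is still reachable by setting the flag explicitly. A small sketch, assuming the standard PYTENSOR_FLAGS mechanism is unchanged by this branch:

    # Reading the (now numba-by-default) linker flag; override it before import,
    # e.g. PYTENSOR_FLAGS="linker=cvm", to get the previous C-backed default.
    import pytensor

    print(pytensor.config.linker)   # "numba" unless overridden via PYTENSOR_FLAGS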
3 changes: 3 additions & 0 deletions pytensor/link/jax/linker.py
@@ -87,3 +87,6 @@ def create_thunk_inputs(self, storage_map):
thunk_inputs.append(sinput)

return thunk_inputs

def __repr__(self):
return "JAXLinker()"
9 changes: 6 additions & 3 deletions pytensor/link/numba/dispatch/basic.py
@@ -40,6 +40,7 @@
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.typed_list import TypedListType


def global_numba_func(func):
@@ -135,6 +136,8 @@ def get_numba_type(
return CSCMatrixType(numba_dtype)

raise NotImplementedError()
elif isinstance(pytensor_type, TypedListType):
return numba.types.List(get_numba_type(pytensor_type.ttype))
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
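A quick illustration of the new branch (element type chosen arbitrarily, constructed here directly with numba types rather than through get_numba_type):

    # Rough equivalent of what the TypedListType branch returns when the list's
    # ttype maps to a 1-D float64 array.
    import numba

    elem_type = numba.types.float64[::1]     # contiguous 1-D float64 array
    list_type = numba.types.List(elem_type)  # analogous to the branch above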

@@ -481,11 +484,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}"
for i, (shape_input, shape_input_names) in enumerate(
f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'"
for i, (node_dim_input, eval_dim_name) in enumerate(
zip(shape_inputs, shape_input_names, strict=True)
)
if shape_input is not NoneConst
if node_dim_input is not NoneConst
]

func = dedent(
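For concreteness, here is roughly what one generated condition looks like once the template above is filled in for dimension 1, with the required extent held in shape_1 (a hand-expanded sketch; the enclosing function name and signature are simplified, not copied from the branch):

    # Hand-expanded form of one func_conditions entry for i == 1; shape_1 is the
    # 0-d array carrying the required extent, hence the .item() in the message.
    def specify_shape(x, shape_1):
        assert x.shape[1] == shape_1, (
            f"SpecifyShape: dim 1 of input has shape {x.shape[1]}, "
            f"expected {shape_1.item()}."
        )
        return x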
8 changes: 7 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
@@ -264,6 +264,12 @@ def axis_apply_fn(x):

@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
# op = getattr(np, str(op.scalar_op).lower())
# @numba_njit
# def elemwise_is_numpy(x):
# return op(x)
# return elemwise_is_numpy

scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
scalar_node = op.scalar_op.make_node(*scalar_inputs)

@@ -276,7 +282,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):

nin = len(node.inputs)
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
core_op_fn = store_core_outputs(scalar_op_fn, op.scalar_op, nin=nin, nout=nout)

input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
16 changes: 16 additions & 0 deletions pytensor/link/numba/dispatch/scalar.py
@@ -23,6 +23,7 @@
Composite,
Identity,
Mul,
Pow,
Reciprocal,
ScalarOp,
Second,
@@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs):
return numba_basic.global_numba_func(switch)


@numba_funcify.register(Pow)
def numba_funcify_Pow(op, node, **kwargs):
pow_dtype = node.inputs[1].type.dtype

def pow(x, y):
return x**y

# Work-around https://github.com/numba/numba/issues/9554
# fast-math causes kernel crashes
patch_kwargs = {}
if pow_dtype.startswith("int"):
patch_kwargs["fastmath"] = False
return numba_basic.numba_njit(**patch_kwargs)(pow)


def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
4 changes: 2 additions & 2 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
f"{item_name} = to_scalar({shape_name})"
f"{item_name} = {shape_name}.item()"
for item_name, shape_name in zip(
shape_var_item_names, shape_var_names, strict=True
)
@@ -86,7 +86,7 @@

alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
val_np = val
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
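Spelled out for a rank-2 case, the changed lines generate something along these lines (a sketch assembled from the template fragments above; the allocation and fill at the end are paraphrased, not copied from the branch):

    import numpy as np

    # Sketch of the generated alloc for two shape inputs. val now arrives as an
    # ndarray, so the old np.asarray(val) copy is dropped, and the 0-d shape
    # arrays are unpacked with .item() instead of the to_scalar helper.
    def alloc(val, shape_0, shape_1):
        val_np = val
        shape_0_item = shape_0.item()
        shape_1_item = shape_1.item()
        scalar_shape = (shape_0_item, shape_1_item)
        res = np.empty(scalar_shape, dtype=val_np.dtype)  # paraphrased tail
        res[...] = val_np                                  # broadcast-fill with val
        return res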
20 changes: 15 additions & 5 deletions pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -3,6 +3,7 @@
import base64
import pickle
from collections.abc import Callable, Sequence
from hashlib import sha256
from textwrap import indent
from typing import Any, cast

@@ -15,15 +16,19 @@
from numba.core.types.misc import NoneType
from numba.np import arrayobj

from pytensor.graph.op import HasInnerGraph
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.utils import compile_function_src
from pytensor.link.numba.super_utils import compile_function_src2
from pytensor.scalar import ScalarOp


def encode_literals(literals: Sequence) -> str:
return base64.encodebytes(pickle.dumps(literals)).decode()


def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
def store_core_outputs(
core_op_fn: Callable, core_op: ScalarOp, nin: int, nout: int
) -> Callable:
"""Create a Numba function that wraps a core function and stores its vectorized outputs.

@njit
@@ -52,9 +57,14 @@ def store_core_outputs({inp_signature}, {out_signature}):
{indent(store_outputs, " " * 4)}
"""
global_env = {"core_op_fn": core_op_fn}
func = compile_function_src(
func_src, "store_core_outputs", {**globals(), **global_env}
)
# func = compile_function_src(
# func_src, "store_core_outputs", {**globals(), **global_env},
# )
if isinstance(core_op, HasInnerGraph):
key = sha256(core_op.c_code_template.encode()).hexdigest()
else:
key = str(core_op)
func = compile_function_src2(key, func_src, "store_core_outputs", global_env)
return cast(Callable, numba_basic.numba_njit(func))
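The caching itself lives in the new compile_function_src2 helper (in pytensor/link/numba/super_utils.py, not shown in this diff). Keying an inner-graph op by a sha256 of its c_code_template, and other ops by str(op), presumably gives a content-based identity that stays stable across calls, which is what makes the generated wrappers cacheable. A hypothetical sketch of a key-based source cache of that shape, purely for orientation; names, signature, and behaviour here are assumptions, not the branch's actual implementation:

    from collections.abc import Callable
    from typing import Any

    # Hypothetical helper: compile a generated source string once per cache key
    # and reuse the resulting function object on later calls with the same key.
    _SRC_CACHE: dict[tuple[str, str], Callable] = {}

    def compile_function_src_cached(
        key: str, src: str, function_name: str, global_env: dict[str, Any]
    ) -> Callable:
        cache_key = (key, function_name)
        if cache_key not in _SRC_CACHE:
            namespace: dict[str, Any] = dict(global_env)
            exec(compile(src, f"<generated {function_name}>", "exec"), namespace)
            _SRC_CACHE[cache_key] = namespace[function_name]
        return _SRC_CACHE[cache_key]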


3 changes: 3 additions & 0 deletions pytensor/link/numba/linker.py
@@ -35,3 +35,6 @@ def create_thunk_inputs(self, storage_map):
thunk_inputs.append(sinput)

return thunk_inputs

def __repr__(self):
return "NumbaLinker()"