diff --git a/BackendBench/suite/base.py b/BackendBench/suite/base.py index 1f5fe635..eaeec0df 100644 --- a/BackendBench/suite/base.py +++ b/BackendBench/suite/base.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. +import importlib + + class Test: def __init__(self, *args, **kwargs): self._args = args @@ -25,6 +28,23 @@ def __init__(self, op, correctness_tests, performance_tests): self.correctness_tests = correctness_tests self.performance_tests = performance_tests + def __getstate__(self): + # Custom serialization to handle callable op + state = self.__dict__.copy() + if callable(state.get("op")): + op = state.pop("op") + state["op_name"] = op.__name__ + state["op_module"] = op.__module__ + return state + + def __setstate__(self, state): + if "op_name" in state and "op_module" in state: + op_name = state.pop("op_name") + op_module = state.pop("op_module") + module = importlib.import_module(op_module) + state["op"] = getattr(module, op_name) + self.__dict__.update(state) + class TestSuite: def __init__(self, name, optests):