diff --git a/sklearnex/tests/test_common.py b/sklearnex/tests/test_common.py
index c6a6bd06c9..b20874f4cf 100644
--- a/sklearnex/tests/test_common.py
+++ b/sklearnex/tests/test_common.py
@@ -15,12 +15,15 @@
 # ==============================================================================
 
 import importlib.util
+import io
 import os
 import pathlib
 import pkgutil
 import re
 import sys
 import trace
+from contextlib import redirect_stdout
+from multiprocessing import Pipe, Process, get_context
 
 import pytest
 from sklearn.utils import all_estimators
@@ -225,23 +228,137 @@ def _commonpath(inp):
 _TRACE_BLOCK_LIST = _whitelist_to_blacklist()
 
 
+def sklearnex_trace(estimator_name, method_name):
+    """Generate a trace of all function calls in calling estimator.method.
+
+    Parameters
+    ----------
+    estimator_name : str
+        name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
+
+    method_name : str
+        name of estimator method which is to be traced and stored
+
+    Returns
+    -------
+    text: str
+        Captured stdout of a python Trace call. The output is more verbose than
+        the default because trace._modname is temporarily monkeypatched with
+        _fullpath while the trace runs.
+    """
+    # get estimator
+    est = (
+        PATCHED_MODELS[estimator_name]()
+        if estimator_name in PATCHED_MODELS
+        else SPECIAL_INSTANCES[estimator_name]
+    )
+
+    # get dataset
+    X, y = gen_dataset(est)[0]
+    # fit dataset if method does not contain 'fit'
+    if "fit" not in method_name:
+        est.fit(X, y)
+
+    # monkeypatch new modname for clearer info
+    orig_modname = trace._modname
+    try:
+        # initialize tracer to have a more verbose module naming
+        # this impacts ignoremods, but it is not used.
+        trace._modname = _fullpath
+        tracer = trace.Trace(
+            count=0,
+            trace=1,
+            ignoredirs=_TRACE_BLOCK_LIST,
+        )
+        # call trace on method with dataset
+        f = io.StringIO()
+        with redirect_stdout(f):
+            tracer.runfunc(call_method, est, method_name, X, y)
+        return f.getvalue()
+    finally:
+        trace._modname = orig_modname
+
+
+def _trace_daemon(pipe):
+    """Serve sklearnex_trace requests sent by the parent process.
+
+    Information is exchanged using a multiprocessing.Pipe."""
+    # a sent value with inherent conversion to False will break
+    # the while loop and complete the function
+    while key := pipe.recv():
+        try:
+            text = sklearnex_trace(*key)
+        except Exception:
+            # report the failure as an empty trace and keep the process running;
+            # exits and interrupts are not caught so the child can still stop
+            text = ""
+        pipe.send(text)
+
+
+class _FakePipe:
+    """Minimalistic representation of a multiprocessing.Pipe for test development.
+    This allows for running sklearnex_trace in the parent process"""
+
+    _text = ""
+
+    def send(self, key):
+        self._text = sklearnex_trace(*key)
+
+    def recv(self):
+        return self._text
+
+
+@pytest.fixture(scope="module")
+def isolated_trace():
+    """Generates a separate python process for isolated sklearnex traces.
+
+    It is a module scope fixture due to the overhead of importing all the
+    various dependencies and is done once before all the various tests.
+    Each test will first check a cached value, if not existent it will have
+    the waiting child process generate the trace and return the text for
+    caching on its behalf. The isolated process is stopped at test teardown.
+
+    Yields
+    ------
+    pipe_parent: multiprocessing.connection.Connection
+        one end of a duplex pipe to be used by other pytest fixtures for
+        communicating with the special isolated tracing python instance
+        for sklearnex estimators.
+    """
+    # yield _FakePipe()
+    # force use of 'spawn' to guarantee a clean python environment
+    # from possible coverage arc tracing
+    ctx = get_context("spawn")
+    pipe_parent, pipe_child = ctx.Pipe()
+    p = ctx.Process(target=_trace_daemon, args=(pipe_child,), daemon=True)
+    p.start()
+    try:
+        yield pipe_parent
+    finally:
+        # teardown only runs once the process exists, so a startup failure
+        # is not masked; passing False terminates _trace_daemon's loop
+        pipe_parent.send(False)
+        pipe_parent.close()
+        pipe_child.close()
+        p.join()
+        p.close()
+
+
 @pytest.fixture
-def estimator_trace(estimator, method, cache, capsys, monkeypatch):
-    """Generate a trace of all function calls in calling estimator.method with cache.
+def estimator_trace(estimator, method, cache, isolated_trace):
+    """Create cache of all function calls in calling estimator.method.
 
     Parameters
     ----------
     estimator : str
-        name of estimator which is a key from PATCHED_MODELS or
+        name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
 
     method : str
         name of estimator method which is to be traced and stored
 
     cache: pytest.fixture (standard)
 
-    capsys: pytest.fixture (standard)
-
-    monkeypatch: pytest.fixture (standard)
+    isolated_trace: pytest.fixture (test_common.py)
 
     Returns
     -------
@@ -256,31 +373,15 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
     key = "-".join((str(estimator), method))
     flag = cache.get("key", "") != key
     if flag:
-        # get estimator
-        try:
-            est = PATCHED_MODELS[estimator]()
-        except KeyError:
-            est = SPECIAL_INSTANCES[estimator]
-
-        # get dataset
-        X, y = gen_dataset(est)[0]
-        # fit dataset if method does not contain 'fit'
-        if "fit" not in method:
-            est.fit(X, y)
-        # initialize tracer to have a more verbose module naming
-        # this impacts ignoremods, but it is not used.
- monkeypatch.setattr(trace, "_modname", _fullpath) - tracer = trace.Trace( - count=0, - trace=1, - ignoredirs=_TRACE_BLOCK_LIST, - ) - # call trace on method with dataset - tracer.runfunc(call_method, est, method, X, y) + isolated_trace.send((estimator, method)) + text = isolated_trace.recv() + # if tracing does not function in isolated_trace, run it in parent process and error + if text == "": + sklearnex_trace(estimator, method) + # guarantee failure if intermittent + assert text, f"sklearnex_trace failure for {estimator}.{method}" - # collect trace for analysis - text = capsys.readouterr().out for modulename, file in _TRACE_ALLOW_DICT.items(): text = text.replace(file, modulename) regex_func = (