Skip to content

Commit

Permalink
[testing, CI] fix coverage statistics issue caused by ```test_common.…
Browse files Browse the repository at this point in the history
…py``` tracer patching (#2237)

* attempt 1 at fixing the issue

* probe error'

* fix basic mistake

* sneaky skips

* Update test_common.py

* another mistake

* Revert "another mistake"

This reverts commit 5c602da.

* switch to fixture

* fix and optimize

* attempt with multiprocessing

* fix errors

* add comments

* add missing text

* Update test_common.py

* Apply suggestions from code review

Co-authored-by: Andreas Huber <[email protected]>

* Update sklearnex/tests/test_common.py

Co-authored-by: Andreas Huber <[email protected]>

* Update test_common.py

* switch based on recommendation

* Update test_common.py

* revert to test

* Introduce infrastructure to run on main process

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

* Update test_common.py

---------

Co-authored-by: Andreas Huber <[email protected]>
  • Loading branch information
icfaust and ahuber21 authored Jan 15, 2025
1 parent c76273d commit 6742689
Showing 1 changed file with 130 additions and 29 deletions.
159 changes: 130 additions & 29 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
# ==============================================================================

import importlib.util
import io
import os
import pathlib
import pkgutil
import re
import sys
import trace
from contextlib import redirect_stdout
from multiprocessing import Pipe, Process, get_context

import pytest
from sklearn.utils import all_estimators
Expand Down Expand Up @@ -225,23 +228,137 @@ def _commonpath(inp):
_TRACE_BLOCK_LIST = _whitelist_to_blacklist()


def sklearnex_trace(estimator_name, method_name):
"""Generate a trace of all function calls in calling estimator.method.
Parameters
----------
estimator_name : str
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
method_name : str
name of estimator method which is to be traced and stored
Returns
-------
text: str
Returns a string output (captured stdout of a python Trace call). It is a
modified version to be more informative, completed by a monkeypatching
of trace._modname.
"""
# get estimator
est = (
PATCHED_MODELS[estimator_name]()
if estimator_name in PATCHED_MODELS
else SPECIAL_INSTANCES[estimator_name]
)

# get dataset
X, y = gen_dataset(est)[0]
# fit dataset if method does not contain 'fit'
if "fit" not in method_name:
est.fit(X, y)

# monkeypatch new modname for clearer info
orig_modname = trace._modname
try:
# initialize tracer to have a more verbose module naming
# this impacts ignoremods, but it is not used.
trace._modname = _fullpath
tracer = trace.Trace(
count=0,
trace=1,
ignoredirs=_TRACE_BLOCK_LIST,
)
# call trace on method with dataset
f = io.StringIO()
with redirect_stdout(f):
tracer.runfunc(call_method, est, method_name, X, y)
return f.getvalue()
finally:
trace._modname = orig_modname


def _trace_daemon(pipe):
"""function interface for the other process. Information
exchanged using a multiprocess.Pipe"""
# a sent value with inherent conversion to False will break
# the while loop and complete the function
while key := pipe.recv():
try:
text = sklearnex_trace(*key)
except:
# catch all exceptions and pass back,
# this way the process still runs
text = ""
finally:
pipe.send(text)


class _FakePipe:
"""Minimalistic representation of a multiprocessing.Pipe for test development.
This allows for running sklearnex_trace in the parent process"""

_text = ""

def send(self, key):
self._text = sklearnex_trace(*key)

def recv(self):
return self._text


@pytest.fixture(scope="module")
def isolated_trace():
"""Generates a separate python process for isolated sklearnex traces.
It is a module scope fixture due to the overhead of importing all the
various dependencies and is done once before all the various tests.
Each test will first check a cached value, if not existent it will have
the waiting child process generate the trace and return the text for
caching on its behalf. The isolated process is stopped at test teardown.
Yields
-------
pipe_parent: multiprocessing.Connection
one end of a duplex pipe to be used by other pytest fixtures for
communicating with the special isolated tracing python instance
for sklearnex estimators.
"""
# yield _FakePipe()
try:
# force use of 'spawn' to guarantee a clean python environment
# from possible coverage arc tracing
ctx = get_context("spawn")
pipe_parent, pipe_child = ctx.Pipe()
p = ctx.Process(target=_trace_daemon, args=(pipe_child,), daemon=True)
p.start()
yield pipe_parent
finally:
# guarantee closing of the process via a try-catch-finally
# passing False terminates _trace_daemon's loop
pipe_parent.send(False)
pipe_parent.close()
pipe_child.close()
p.join()
p.close()


@pytest.fixture
def estimator_trace(estimator, method, cache, capsys, monkeypatch):
"""Generate a trace of all function calls in calling estimator.method with cache.
def estimator_trace(estimator, method, cache, isolated_trace):
"""Create cache of all function calls in calling estimator.method.
Parameters
----------
estimator : str
name of estimator which is a key from PATCHED_MODELS or
name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES
method : str
name of estimator method which is to be traced and stored
cache: pytest.fixture (standard)
capsys: pytest.fixture (standard)
monkeypatch: pytest.fixture (standard)
isolated_trace: pytest.fixture (test_common.py)
Returns
-------
Expand All @@ -256,31 +373,15 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch):
key = "-".join((str(estimator), method))
flag = cache.get("key", "") != key
if flag:
# get estimator
try:
est = PATCHED_MODELS[estimator]()
except KeyError:
est = SPECIAL_INSTANCES[estimator]

# get dataset
X, y = gen_dataset(est)[0]
# fit dataset if method does not contain 'fit'
if "fit" not in method:
est.fit(X, y)

# initialize tracer to have a more verbose module naming
# this impacts ignoremods, but it is not used.
monkeypatch.setattr(trace, "_modname", _fullpath)
tracer = trace.Trace(
count=0,
trace=1,
ignoredirs=_TRACE_BLOCK_LIST,
)
# call trace on method with dataset
tracer.runfunc(call_method, est, method, X, y)
isolated_trace.send((estimator, method))
text = isolated_trace.recv()
# if tracing does not function in isolated_trace, run it in parent process and error
if text == "":
sklearnex_trace(estimator, method)
# guarantee failure if intermittent
assert text, f"sklearnex_trace failure for {estimator}.{method}"

# collect trace for analysis
text = capsys.readouterr().out
for modulename, file in _TRACE_ALLOW_DICT.items():
text = text.replace(file, modulename)
regex_func = (
Expand Down

0 comments on commit 6742689

Please sign in to comment.