diff --git a/distributed/diagnostics/tests/test_cudf_diagnostics.py b/distributed/diagnostics/tests/test_cudf_diagnostics.py
index fd252590ab3..d1be17eb41b 100644
--- a/distributed/diagnostics/tests/test_cudf_diagnostics.py
+++ b/distributed/diagnostics/tests/test_cudf_diagnostics.py
@@ -1,25 +1,36 @@
 from __future__ import annotations
 
-import asyncio
-import os
+import sys
 
 import pytest
 
-from distributed.utils_test import gen_cluster
+import dask
+from dask.distributed import worker
 
-pytestmark = [
-    pytest.mark.gpu,
-    pytest.mark.skipif(
-        os.environ.get("CUDF_SPILL", "off") != "on"
-        or os.environ.get("CUDF_SPILL_STATS", "0") != "1"
-        or os.environ.get("DASK_DISTRIBUTED__DIAGNOSTICS__CUDF", "0") != "1",
-        reason="cuDF spill stats monitoring must be enabled manually",
-    ),
-]
+from distributed.utils_test import async_poll_for, gen_cluster
 
 cudf = pytest.importorskip("cudf")
 
 
+@pytest.fixture
+def cudf_spill():
+    """
+    Configures cuDF options to enable spilling.
+
+    Restores the settings to their original values after the test.
+    """
+    spill = cudf.get_option("spill")
+    spill_stats = cudf.get_option("spill_stats")
+
+    cudf.set_option("spill", True)
+    cudf.set_option("spill_stats", 1)
+
+    yield
+
+    cudf.set_option("spill", spill)
+    cudf.set_option("spill_stats", spill_stats)
+
+
 def force_spill():
     from cudf.core.buffer.spill_manager import get_global_manager
 
@@ -37,7 +48,13 @@ def force_spill():
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1)],
+    # whether worker.cudf_metric is in DEFAULT_METRICS depends on the value
+    # of distributed.diagnostics.cudf when distributed.worker is imported.
+    worker_kwargs={
+        "metrics": {**worker.DEFAULT_METRICS, "cudf": worker.cudf_metric},
+    },
 )
+@pytest.mark.usefixtures("cudf_spill")
 async def test_cudf_metrics(c, s, *workers):
     w = list(s.workers.values())[0]
     assert "cudf" in w.metrics
@@ -45,7 +62,20 @@ async def test_cudf_metrics(c, s, *workers):
     spill_totals = (await c.run(force_spill, workers=[w.address]))[w.address]
     assert spill_totals > 0
 
-    # We have to wait for the worker's metrics to update.
-    # TODO: avoid sleep, is it possible to wait on the next update of metrics?
-    await asyncio.sleep(1)
+    await async_poll_for(lambda: w.metrics["cudf"]["cudf-spilled"] > 0, timeout=2)
     assert w.metrics["cudf"]["cudf-spilled"] == spill_totals
+
+
+def test_cudf_default_metrics(monkeypatch):
+    import distributed
+
+    # Re-import distributed.worker so DEFAULT_METRICS is rebuilt under the
+    # patched config. Both references to the module are registered with
+    # monkeypatch so the original module object is restored at teardown.
+    monkeypatch.setattr(distributed, "worker", distributed.worker)
+    monkeypatch.delitem(sys.modules, "distributed.worker")
+
+    with dask.config.set(**{"distributed.diagnostics.cudf": 1}):
+        import distributed.worker
+
+        assert "cudf" in distributed.worker.DEFAULT_METRICS
diff --git a/distributed/worker.py b/distributed/worker.py
index c8a4bbd574e..8a57f563005 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3224,21 +3224,17 @@ async def rmm_metric(worker):
     DEFAULT_METRICS["rmm"] = rmm_metric
     del _rmm
 
-# avoid importing cuDF unless explicitly enabled
-if dask.config.get("distributed.diagnostics.cudf"):
-    try:
-        import cudf as _cudf  # noqa: F401
-    except Exception:
-        pass
-    else:
-        from distributed.diagnostics import cudf
 
-        async def cudf_metric(worker):
-            result = await offload(cudf.real_time)
-            return result
+async def cudf_metric(worker):
+    # avoid importing optional cudf at the top-level
+    from distributed.diagnostics import cudf
 
-        DEFAULT_METRICS["cudf"] = cudf_metric
-        del _cudf
+    result = await offload(cudf.real_time)
+    return result
+
+
+if dask.config.get("distributed.diagnostics.cudf"):
+    DEFAULT_METRICS["cudf"] = cudf_metric
 
 
 def print(