Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/nemo-evaluator/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ repository = "https://github.com/NVIDIA-NeMo/Evaluator/packages/nemo-evaluator"
# END(if-changed)

[dependency-groups]
test = ["pytest", "pytest-cov", "pytest-subtests", "pytest-httpserver", "nvidia-simple-evals"]
test = ["pytest", "pytest-asyncio", "pytest-cov", "pytest-subtests", "pytest-httpserver", "nvidia-simple-evals"]

docs = [
"sphinx",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import os
import pathlib
import threading
from typing import Optional, final
import time
from typing import Annotated, Optional, final

import requests
from pydantic import Field
Expand Down Expand Up @@ -48,10 +49,16 @@ class Params(BaseLoggingParams):
default="http://localhost:8000",
description="URL to post the number of processed samples to. Supports expansion of shell variables if present.",
)
progress_tracking_interval: int = Field(
progress_tracking_interval: Annotated[int, Field(gt=0)] = Field(
default=1,
description="How often (every how many samples) to send a progress information.",
)
progress_tracking_interval_seconds: Optional[
Annotated[float | None, Field(gt=0)]
] = Field(
default=None,
description="How often (every N seconds) to send a progress information in addition to progress_tracking_interval.",
)
request_method: str = Field(
default="PATCH",
description="Request method to use for updating the evaluation progress.",
Expand Down Expand Up @@ -83,15 +90,30 @@ def __init__(self, params: Params):
else:
self.progress_filepath = None
self._samples_processed = self._initialize_samples_processed()
self._last_updated_samples_processed = self._samples_processed
self._lock = threading.Lock()

# Get logger for this interceptor with interceptor context
self.logger = get_logger(self.__class__.__name__)

# Optional update on timer
self.progress_tracking_interval_seconds = (
params.progress_tracking_interval_seconds
)
if self.progress_tracking_interval_seconds:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question / concern: since this is asyncio inside multithreaded Flask, could you clarify how this task would behave? Is the asyncio loop executed on in the main Flask thread where the object is created, and would that keep blocking? There are suggestions on SO to have an asyncio on a separate thread to avoid issues with asyncio + threading marriage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick review! I’ve verified that asyncio.create_task works with FastAPI, but I have not verified it with Flask — this is a good callout; I'll do extra testing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test on the adapter layer and found that asyncio.create_task errors with "no running event loop", which tracks with your series of questions about the asyncio loop. Once I reworked the background timer to run as a thread and removed asyncio.sleep, the Flask server seems to be running as expected, and the timer thread does not block incoming requests. Lmk what you think!

self._timer_stopped = False
self._update_on_timer_thread = threading.Thread(
target=self._update_on_timer,
kwargs={"interval_seconds": self.progress_tracking_interval_seconds},
daemon=True,
)
self._update_on_timer_thread.start()

self.logger.info(
"Progress tracking interceptor initialized",
progress_tracking_url=self.progress_tracking_url,
progress_tracking_interval=self.progress_tracking_interval,
progress_tracking_interval_seconds=self.progress_tracking_interval_seconds,
output_dir=str(self.progress_filepath) if self.progress_filepath else None,
initial_samples_processed=self._samples_processed,
)
Expand Down Expand Up @@ -151,6 +173,34 @@ def _send_progress(self, num_samples: int) -> requests.Response:
samples_processed=num_samples,
)

def _update_on_timer(self, interval_seconds: float):
    """
    Periodically push a progress update when samples have advanced since the
    last reported update.

    This is a blocking loop and is expected to run on the daemon thread
    started in ``__init__``; it exits when ``_timer_stopped`` is set (by
    ``post_eval_hook``).

    Args:
        interval_seconds: Delay between checks. Must be positive — already
            validated by ``Params`` (``gt=0``); re-asserted here as an
            internal invariant.
    """
    assert interval_seconds > 0
    while True:
        time.sleep(interval_seconds)
        # Read shared counters under the lock; release it before doing I/O
        # so request handlers are never blocked on network/file writes.
        with self._lock:
            if self._timer_stopped:
                return
            if self._last_updated_samples_processed == self._samples_processed:
                continue
            curr_samples = self._samples_processed

        try:
            if self.progress_tracking_url is not None:
                self._send_progress(curr_samples)
            if self.progress_filepath is not None:
                self._write_progress(curr_samples)
        except Exception as e:
            # A transient send/write failure must not kill the daemon thread
            # and silently disable all future timed updates; log and retry on
            # the next tick (the high-water mark is left untouched).
            self.logger.error(
                "Timed progress update failed; will retry on next interval",
                error=str(e),
            )
            continue

        self.logger.info(
            "Progress milestone updated on time interval",
            samples_processed=curr_samples,
            interval=self.progress_tracking_interval,
        )
        with self._lock:
            # intercept_response may have reported a newer count while we were
            # sending; never move the high-water mark backwards, or the next
            # tick would re-send a duplicate (stale) update.
            self._last_updated_samples_processed = max(
                self._last_updated_samples_processed, curr_samples
            )

@final
def intercept_response(
self, ar: AdapterResponse, context: AdapterGlobalContext
Expand All @@ -177,13 +227,20 @@ def intercept_response(
samples_processed=curr_samples,
interval=self.progress_tracking_interval,
)
with self._lock:
self._last_updated_samples_processed = curr_samples

return ar

def post_eval_hook(self, context: AdapterGlobalContext) -> None:
self.logger.info(
"Post-eval hook executed", total_samples_processed=self._samples_processed
)
with self._lock:
if self.progress_tracking_interval_seconds:
self._timer_stopped = True
if self._samples_processed == self._last_updated_samples_processed:
return

if self.progress_tracking_url is not None:
self._send_progress(self._samples_processed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import threading
import time
from typing import List
from unittest.mock import patch

import pytest
import requests
from flask import Flask, request
from pydantic_core import ValidationError

from nemo_evaluator.adapters.interceptors.progress_tracking_interceptor import (
ProgressTrackingInterceptor,
Expand All @@ -30,49 +30,7 @@
AdapterRequestContext,
AdapterResponse,
)


class FakeProgressTrackingServer:
    """Test server to receive progress tracking webhooks."""

    def __init__(self, port: int = 8000, request_method="PATCH"):
        # Port the Flask app listens on; must match the interceptor's
        # progress_tracking_url in tests.
        self.port = port
        self.app = Flask(__name__)
        # JSON payloads received so far; guarded by `lock` because Flask may
        # dispatch handler calls from multiple threads.
        self.received_updates: List[dict] = []
        self.lock = threading.Lock()

        @self.app.route("/", methods=[request_method])
        def progress_webhook():
            """Receive progress updates."""
            data = request.get_json()
            with self.lock:
                self.received_updates.append(data)
            return {"status": "ok"}

    def start(self):
        """Start the server in a background thread."""
        self.thread = threading.Thread(
            target=self.app.run, kwargs={"host": "0.0.0.0", "port": self.port}
        )
        self.thread.daemon = True
        self.thread.start()
        # Give the server time to start
        # NOTE(review): a fixed sleep is a startup race on slow machines —
        # consider polling the port for readiness instead.
        time.sleep(0.5)

    def stop(self):
        """Stop the server."""
        # Flask doesn't have a clean shutdown, so we'll just let it run as daemon
        pass

    def get_updates(self) -> List[dict]:
        """Get all received updates."""
        with self.lock:
            return self.received_updates.copy()

    def clear_updates(self):
        """Clear received updates."""
        with self.lock:
            self.received_updates.clear()
from tests.unit_tests.adapters.testing_utils import FakeProgressTrackingServer


class TestProgressTrackingInterceptor:
Expand Down Expand Up @@ -255,6 +213,19 @@ def test_network_error_handling(self, mock_request):
# Verify that the request was attempted
mock_request.assert_called_once()

def test_interval_configuration_validation(self):
    """progress_tracking_interval must be strictly positive (gt=0)."""
    for invalid_interval in (0, -2):
        with pytest.raises(ValidationError):
            ProgressTrackingInterceptor.Params(
                progress_tracking_url="http://test",
                progress_tracking_interval=invalid_interval,
            )

def test_interval_configuration(self):
"""Test different interval configurations."""
# Start test server
Expand Down Expand Up @@ -367,6 +338,68 @@ def test_configured_method(self):
finally:
server.stop()

def test_interval_timer_validation(self):
    """A non-positive timer interval is rejected at Params construction."""
    with pytest.raises(ValidationError):
        ProgressTrackingInterceptor.Params(progress_tracking_interval_seconds=-1)

@pytest.mark.asyncio
async def test_interval_timer(self):
    """Exercise the timer-driven update path: updates fire on the 0.2 s timer
    (not the count-based interval of 50), only when the sample count changed,
    and stop after post_eval_hook.

    NOTE(review): the fixed 0.5 s sleeps assume the 0.2 s timer fires at
    least once in each window — may be flaky on heavily loaded CI machines.
    """
    # Start test server
    server = FakeProgressTrackingServer(port=8007)
    server.start()

    try:
        params = ProgressTrackingInterceptor.Params(
            progress_tracking_url="http://localhost:8007",
            # Count-based interval (50) is deliberately larger than the number
            # of responses sent below, so every observed update must come from
            # the timer thread.
            progress_tracking_interval=50,
            progress_tracking_interval_seconds=0.2,
        )
        interceptor = ProgressTrackingInterceptor(params)
        assert interceptor.progress_tracking_url == "http://localhost:8007"
        assert interceptor.progress_tracking_interval == 50
        assert interceptor.progress_tracking_interval_seconds == 0.2

        # Create mock response and context
        mock_response = AdapterResponse(
            r=requests.Response(),
            rctx=AdapterRequestContext(),
        )
        context = AdapterGlobalContext(output_dir="/tmp", url="http://test")

        # Verify no update until timer interval
        interceptor.intercept_response(mock_response, context)
        interceptor.intercept_response(mock_response, context)
        updates = server.get_updates()
        assert len(updates) == 0, "no updates until timer interval"

        # Verify first timer interval calls update
        await asyncio.sleep(0.5)
        updates = server.get_updates()
        assert len(updates) == 1, "only expected one update"
        assert updates[0]["samples_processed"] == 2

        # Verify subsequent timer interval calls update
        interceptor.intercept_response(mock_response, context)
        await asyncio.sleep(0.5)
        updates = server.get_updates()
        assert len(updates) == 2, "expected second update"
        assert updates[1]["samples_processed"] == 3

        # No calls to update after timer is stopped
        interceptor.post_eval_hook(context)
        interceptor.intercept_response(mock_response, context)
        assert interceptor._samples_processed == 4
        await asyncio.sleep(0.5)
        updates = server.get_updates()
        assert len(updates) == 2, (
            "expected post_eval_hook to skip posting update on no change and no updates after post_eval_hook cancels timed updates"
        )

    finally:
        server.stop()


if __name__ == "__main__":
# Simple test runner for manual testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import threading
from typing import Any, Generator
from unittest.mock import patch

Expand All @@ -32,6 +34,7 @@
EvaluationTarget,
)
from tests.unit_tests.adapters.testing_utils import (
FakeProgressTrackingServer,
create_fake_endpoint_process,
)

Expand Down Expand Up @@ -742,3 +745,79 @@ def test_adapter_server_process_raise_on_port_taken():
):
with AdapterServerProcess(evaluation):
pass


@pytest.mark.asyncio
async def test_adapter_with_progress_tracking_timer(fake_openai_endpoint, tmp_path):
    """End-to-end check that timer-driven progress updates flow through the
    adapter server process while concurrent requests are being handled,
    i.e. the timer thread does not block request handling.
    """
    # Setup progress tracking server to verify updates are non-blocking
    progress_tracking_server = FakeProgressTrackingServer(port=8011)
    progress_tracking_server.start()
    progress_tracking_url = "http://localhost:8011"
    progress_tracking_config = dict(
        name="progress_tracking",
        enabled=True,
        config={
            "progress_tracking_url": progress_tracking_url,
            # number of requests for the test are lower than interval,
            # expect all updates are from the timer thread.
            "progress_tracking_interval": 500,
            "progress_tracking_interval_seconds": 0.1,
        },
    )

    # Start adapter server
    evaluation = Evaluation(
        command="",
        framework_name="",
        pkg_name="",
        config=EvaluationConfig(output_dir=str(tmp_path)),
        target=EvaluationTarget(
            api_endpoint=ApiEndpoint(
                url="http://localhost:3300/v1/chat/completions",
                adapter_config=AdapterConfig(
                    interceptors=[
                        dict(
                            name="endpoint",
                            config={},
                        ),
                        progress_tracking_config,
                    ],
                    post_eval_hooks=[progress_tracking_config],
                ),
            ),
        ),
    )
    with AdapterServerProcess(evaluation) as adapter_server_process:
        # Wait for server to be ready
        wait_for_server("localhost", adapter_server_process.port)
        url = f"http://localhost:{adapter_server_process.port}"

        data = {
            "prompt": "This is a test prompt",
            "max_tokens": 100,
            "temperature": 0.5,
        }

        def concurrent_requests():
            # 10 sequential requests per thread; 10 threads => 100 samples.
            for _ in range(10):
                requests.post(url, json=data)

        threads = []
        for _ in range(10):
            thread = threading.Thread(target=concurrent_requests, daemon=True)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Poll with a bounded deadline instead of one fixed sleep so the test
        # is robust on slow CI machines (timer interval is 0.1 s; 5 s budget).
        expected_final = {"samples_processed": 100}
        for _ in range(50):
            updates = progress_tracking_server.get_updates()
            if updates and updates[-1] == expected_final:
                break
            await asyncio.sleep(0.1)
        updates = progress_tracking_server.get_updates()

        # There can be multiple updates depending on monitoring loop wrt to number
        # of concurrent calls and sleep, so we only test that there is at least one update.
        assert len(updates) > 0, "expected at least one update within duration"
        assert updates[-1] == expected_final, (
            "the last update should have all the samples processed recorded"
        )
Loading