From 1bb444312f98eabc84e3b186cfc529f387b13d6b Mon Sep 17 00:00:00 2001
From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com>
Date: Wed, 17 Sep 2025 10:45:28 -0500
Subject: [PATCH 1/7] Initial scaffold of mock OpenAI-compatible server

---
 .../benchmark/mock_llm_server/__init__.py     |  14 +
 .../mock_llm_server/example_usage.py          | 206 +++++++
 .../mock_llm_server/mock_llm_server.py        | 406 +++++++++++++
 .../benchmark/mock_llm_server/run_server.py   |  79 +++
 tests/benchmark/test_mock_llm_server.py       | 531 ++++++++++++++++++
 5 files changed, 1236 insertions(+)
 create mode 100644 nemoguardrails/benchmark/mock_llm_server/__init__.py
 create mode 100644 nemoguardrails/benchmark/mock_llm_server/example_usage.py
 create mode 100644 nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py
 create mode 100644 nemoguardrails/benchmark/mock_llm_server/run_server.py
 create mode 100644 tests/benchmark/test_mock_llm_server.py

diff --git a/nemoguardrails/benchmark/mock_llm_server/__init__.py b/nemoguardrails/benchmark/mock_llm_server/__init__.py
new file mode 100644
index 000000000..9ba9d4310
--- /dev/null
+++ b/nemoguardrails/benchmark/mock_llm_server/__init__.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemoguardrails/benchmark/mock_llm_server/example_usage.py b/nemoguardrails/benchmark/mock_llm_server/example_usage.py
new file mode 100644
index 000000000..278ab8d94
--- /dev/null
+++ b/nemoguardrails/benchmark/mock_llm_server/example_usage.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example usage of the Mock LLM Server.
+
+This script demonstrates how to interact with the running mock server
+using standard HTTP requests and the OpenAI Python client.
+""" + +import json +import time + +import requests + + +def test_with_requests(): + """Test the server using the requests library.""" + base_url = "http://localhost:8000" + + print("Testing Mock LLM Server with requests library...") + print("=" * 50) + + # Test health endpoint + try: + response = requests.get(f"{base_url}/health", timeout=5) + print(f"Health check: {response.status_code} - {response.json()}") + except requests.RequestException as e: + print(f"Health check failed: {e}") + print("Make sure the server is running: python run_server.py") + return + + # Test models endpoint + try: + response = requests.get(f"{base_url}/v1/models", timeout=5) + print(f"\\nModels: {response.status_code}") + models_data = response.json() + for model in models_data["data"]: + print(f" - {model['id']}") + except requests.RequestException as e: + print(f"Models request failed: {e}") + + # Test chat completion + try: + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + } + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nChat completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['message']['content']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Chat completion failed: {e}") + + # Test text completion + try: + completion_payload = { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + } + response = requests.post( + f"{base_url}/v1/completions", + json=completion_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nText completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['text']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Text completion failed: {e}") + + +def test_with_openai_client(): + """Test the server using the OpenAI Python client.""" + try: + import openai + except ImportError: + print("\\nOpenAI client not available. 
Install with: pip install openai") + return + + print("\\n" + "=" * 50) + print("Testing with OpenAI client library...") + print("=" * 50) + + # Configure client to use local server + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="dummy-key", # Server doesn't validate, but client requires it + ) + + try: + # Test chat completion + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello from OpenAI client!"}], + ) + print(f"Chat completion response: {response.choices[0].message.content}") + print( + f"Usage: prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}" + ) + + # Test text completion (if supported by client version) + try: + response = client.completions.create( + model="text-davinci-003", prompt="OpenAI client test: ", max_tokens=10 + ) + print(f"Text completion response: {response.choices[0].text}") + except Exception as e: + print(f"Text completion not supported in this OpenAI client version: {e}") + + except Exception as e: + print(f"OpenAI client test failed: {e}") + + +def benchmark_performance(): + """Simple performance benchmark.""" + print("\\n" + "=" * 50) + print("Performance Benchmark") + print("=" * 50) + + base_url = "http://localhost:8000" + num_requests = 10 + + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Benchmark test"}], + "max_tokens": 20, + } + + print(f"Making {num_requests} chat completion requests...") + + start_time = time.time() + successful_requests = 0 + + for i in range(num_requests): + try: + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + successful_requests += 1 + except requests.RequestException: + pass + + end_time = time.time() + total_time = end_time - start_time + + print(f"Results:") + print(f" Total requests: {num_requests}") + print(f" Successful requests: {successful_requests}") + print(f" Total time: {total_time:.2f} seconds") + print(f" Average time per request: {total_time/num_requests:.3f} seconds") + print(f" Requests per second: {num_requests/total_time:.2f}") + + +def main(): + """Main function to run all tests.""" + print("Mock LLM Server Example Usage") + print("=" * 50) + print("Make sure the server is running before running this script:") + print(" python run_server.py") + print() + + # Test with requests + test_with_requests() + + # Test with OpenAI client + test_with_openai_client() + + # Simple benchmark + benchmark_performance() + + print("\\nExample completed!") + + +if __name__ == "__main__": + main() diff --git a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py new file mode 100644 index 000000000..28725e724 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mock LLM FastAPI Server with OpenAI-compatible interface. + +This server provides dummy implementations of OpenAI API endpoints for testing +and benchmarking purposes. +""" + +import time +import uuid +from typing import Any, Dict, List, Optional, Union + +import uvicorn +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field + +app = FastAPI( + title="Mock LLM Server", + description="OpenAI-compatible mock LLM server for testing and benchmarking", + version="1.0.0", +) + + +# Pydantic Models for Request/Response validation + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + name: Optional[str] = Field(None, description="The name of the author") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: List[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, List[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[Dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, List[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: 
Optional[Union[str, List[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[Dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[Dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: List[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: List[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: List[Model] = Field(..., description="List of available models") + + +# Dummy data and helper functions + +DUMMY_MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { 
+ "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "openai", + }, + { + "id": "text-davinci-003", + "object": "model", + "created": 1669599635, + "owned_by": "openai", + }, +] + +DUMMY_CHAT_RESPONSES = [ + "This is a mock response from the LLM server.", + "I'm a dummy AI assistant created for testing purposes.", + "This response is generated by a mock OpenAI-compatible server.", + "Hello! I'm responding with dummy data for benchmarking.", + "This is a simulated conversation response for testing.", +] + +DUMMY_COMPLETION_RESPONSES = [ + " This is a dummy text completion.", + " Here's some mock generated text.", + " This is a sample completion response.", + " Mock completion text for testing purposes.", + " Dummy text generated by the mock server.", +] + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_dummy_chat_response() -> str: + """Get a dummy chat response.""" + import random + + return random.choice(DUMMY_CHAT_RESPONSES) + + +def get_dummy_completion_response() -> str: + """Get a dummy completion response.""" + import random + + return random.choice(DUMMY_COMPLETION_RESPONSES) + + +# API Endpoints + + +@app.get("/") +async def root(): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "1.0.0", + "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(): + """List available models.""" + return ModelsResponse( + object="list", data=[Model(**model) for model in DUMMY_MODELS] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions(request: ChatCompletionRequest): + """Create a chat completion.""" + # Validate model exists + available_models = [model["id"] for model in DUMMY_MODELS] + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. 
Available models: {available_models}", + ) + + # Generate dummy response + response_content = get_dummy_chat_response() + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content, name=None), + finish_reason="stop", + ) + choices.append(choice) + + return ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions(request: CompletionRequest): + """Create a text completion.""" + # Validate model exists + available_models = [model["id"] for model in DUMMY_MODELS] + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. Available models: {available_models}", + ) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_dummy_completion_response() + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + return CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "timestamp": int(time.time())} + + +if __name__ == "__main__": + uvicorn.run( + "mock_llm_server:app", host="0.0.0.0", port=8000, reload=True, log_level="info" + ) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py new file mode 100644 index 000000000..66f281932 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Startup script for the Mock LLM Server. + +This script starts the FastAPI server with configurable host and port settings. +""" + +import argparse +import os +import sys + +import uvicorn + +# Add the current directory to Python path to import the server module +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + + +def main(): + parser = argparse.ArgumentParser(description="Run the Mock LLM Server") + parser.add_argument( + "--host", + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + parser.add_argument( + "--reload", action="store_true", help="Enable auto-reload for development" + ) + parser.add_argument( + "--log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Log level (default: info)", + ) + + args = parser.parse_args() + + print(f"Starting Mock LLM Server on {args.host}:{args.port}") + print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") + print(f"Health check at: http://{args.host}:{args.port}/health") + print("Press Ctrl+C to stop the server") + + try: + uvicorn.run( + "mock_llm_server:app", + host=args.host, + port=args.port, + reload=args.reload, + log_level=args.log_level, + ) + except KeyboardInterrupt: + print("\nServer stopped by user") + except Exception as e: # pylint: disable=broad-except + print(f"Error starting server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py new file mode 100644 index 000000000..b74f9633d --- /dev/null +++ b/tests/benchmark/test_mock_llm_server.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the Mock LLM FastAPI Server. + +This module contains comprehensive tests for all endpoints and edge cases +of the OpenAI-compatible mock LLM server. 
+""" + +import json +import time +from typing import Any, Dict, List +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +# Import the server and its components +from mock_llm_server.mock_llm_server import ( + DUMMY_MODELS, + app, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + + +class TestMockLLMServer: + """Test class for the Mock LLM Server.""" + + @pytest.fixture + def client(self): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + @pytest.fixture + def valid_chat_request(self): + """Sample valid chat completion request.""" + return { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "temperature": 0.7, + } + + @pytest.fixture + def valid_completion_request(self): + """Sample valid text completion request.""" + return { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + "temperature": 0.8, + } + + # Root endpoint tests + def test_root_endpoint(self, client): + """Test the root endpoint returns correct information.""" + response = client.get("/") + assert response.status_code == 200 + + data = response.json() + assert data["message"] == "Mock LLM Server" + assert data["version"] == "1.0.0" + assert "description" in data + assert "/v1/models" in data["endpoints"] + assert "/v1/chat/completions" in data["endpoints"] + assert "/v1/completions" in data["endpoints"] + + # Health check tests + def test_health_check(self, client): + """Test the health check endpoint.""" + response = client.get("/health") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], int) + + # Models endpoint tests + def test_list_models(self, client): + """Test the models listing endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "list" + assert isinstance(data["data"], list) + assert len(data["data"]) == len(DUMMY_MODELS) + + # Check first model structure + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + + def test_models_contain_expected_models(self, client): + """Test that all expected models are returned.""" + response = client.get("/v1/models") + data = response.json() + + model_ids = [model["id"] for model in data["data"]] + expected_ids = [model["id"] for model in DUMMY_MODELS] + + assert set(model_ids) == set(expected_ids) + + # Chat completions tests + def test_chat_completions_success(self, client, valid_chat_request): + """Test successful chat completion request.""" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == valid_chat_request["model"] + assert "id" in data + assert "created" in data + assert isinstance(data["created"], int) + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "message" in choice + assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + + # Check usage + assert "usage" in data + usage = 
data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_chat_completions_multiple_choices(self, client, valid_chat_request): + """Test chat completion with multiple choices.""" + valid_chat_request["n"] = 3 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 3 + + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + assert choice["finish_reason"] == "stop" + + def test_chat_completions_invalid_model(self, client, valid_chat_request): + """Test chat completion with invalid model.""" + valid_chat_request["model"] = "invalid-model" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 400 + + data = response.json() + assert "detail" in data + assert "invalid-model" in data["detail"] + assert "not found" in data["detail"] + + def test_chat_completions_empty_messages(self, client): + """Test chat completion with empty messages.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [], + } + response = client.post("/v1/chat/completions", json=request_data) + # Note: The server currently accepts empty messages and processes them + # This may be acceptable behavior for a mock server + assert response.status_code in [ + 200, + 422, + ] # Allow both success and validation error + + def test_chat_completions_invalid_message_format(self, client): + """Test chat completion with invalid message format.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"invalid": "format"}], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 422 # Validation error + + def test_chat_completions_parameter_validation(self, client, valid_chat_request): + """Test parameter validation for chat completions.""" + # Test max_tokens validation + valid_chat_request["max_tokens"] = 0 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + # Test temperature validation + valid_chat_request["max_tokens"] = 50 + valid_chat_request["temperature"] = 3.0 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + # Test n validation + valid_chat_request["temperature"] = 0.7 + valid_chat_request["n"] = 200 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + def test_chat_completions_optional_parameters(self, client): + """Test chat completion with various optional parameters.""" + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Test message"}], + "max_tokens": 100, + "temperature": 0.5, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.2, + "stop": ["\\n"], + "user": "test-user", + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + # Text completions tests + def test_completions_success(self, client, valid_completion_request): + """Test successful text completion request.""" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + assert data["model"] == 
valid_completion_request["model"] + assert "id" in data + assert "created" in data + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "text" in choice + assert isinstance(choice["text"], str) + + # Check usage + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + + def test_completions_list_prompt(self, client): + """Test text completion with list prompt.""" + request_data = { + "model": "text-davinci-003", + "prompt": ["First prompt", "Second prompt"], + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + + def test_completions_invalid_model(self, client, valid_completion_request): + """Test text completion with invalid model.""" + valid_completion_request["model"] = "non-existent-model" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 400 + + def test_completions_multiple_choices(self, client, valid_completion_request): + """Test text completion with multiple choices.""" + valid_completion_request["n"] = 2 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 2 + + def test_completions_parameter_validation(self, client, valid_completion_request): + """Test parameter validation for text completions.""" + # Test max_tokens validation + valid_completion_request["max_tokens"] = -1 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + # Test temperature validation + valid_completion_request["max_tokens"] = 10 + valid_completion_request["temperature"] = -1.0 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + def test_completions_optional_parameters(self, client): + """Test text completion with various optional parameters.""" + request_data = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + "max_tokens": 50, + "temperature": 0.8, + "top_p": 0.95, + "n": 1, + "logprobs": 1, + "echo": True, + "stop": ["\\n", "."], + "presence_penalty": -0.5, + "frequency_penalty": 0.3, + "best_of": 2, + "user": "test-user-2", + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + # Helper function tests + def test_generate_id_default(self): + """Test ID generation with default prefix.""" + id1 = generate_id() + id2 = generate_id() + + assert id1.startswith("chatcmpl-") + assert id2.startswith("chatcmpl-") + assert id1 != id2 # Should be unique + assert len(id1) == len("chatcmpl-") + 8 # prefix + 8 hex chars + + def test_generate_id_custom_prefix(self): + """Test ID generation with custom prefix.""" + custom_id = generate_id("cmpl") + assert custom_id.startswith("cmpl-") + assert len(custom_id) == len("cmpl-") + 8 + + def test_calculate_tokens(self): + """Test token calculation function.""" + # Test basic calculation + assert calculate_tokens("") == 1 # Minimum 1 token + assert calculate_tokens("a") == 1 + assert calculate_tokens("abcd") == 1 + assert calculate_tokens("abcde") == 1 # 5 chars = 1 token (rounded down) + assert calculate_tokens("abcdefgh") == 2 # 8 chars = 2 
tokens + + # Test longer text + long_text = "This is a longer text with multiple words and characters." + expected_tokens = max(1, len(long_text) // 4) + assert calculate_tokens(long_text) == expected_tokens + + def test_get_dummy_responses(self): + """Test dummy response generation functions.""" + chat_response = get_dummy_chat_response() + assert isinstance(chat_response, str) + assert len(chat_response) > 0 + + completion_response = get_dummy_completion_response() + assert isinstance(completion_response, str) + assert len(completion_response) > 0 + + # Edge cases and error handling + def test_missing_required_fields_chat(self, client): + """Test chat completion with missing required fields.""" + # Missing model + response = client.post("/v1/chat/completions", json={"messages": []}) + assert response.status_code == 422 + + # Missing messages + response = client.post("/v1/chat/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_missing_required_fields_completion(self, client): + """Test text completion with missing required fields.""" + # Missing model + response = client.post("/v1/completions", json={"prompt": "test"}) + assert response.status_code == 422 + + # Missing prompt + response = client.post("/v1/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_invalid_json(self, client): + """Test endpoints with invalid JSON.""" + response = client.post( + "/v1/chat/completions", + content="invalid json", + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + def test_empty_request_body(self, client): + """Test endpoints with empty request body.""" + response = client.post("/v1/chat/completions", json={}) + assert response.status_code == 422 + + response = client.post("/v1/completions", json={}) + assert response.status_code == 422 + + # Content validation tests + def test_chat_message_content_types(self, client): + """Test chat completion with different message content types.""" + # Test with multiple messages + request_data = { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + def test_response_structure_consistency(self, client, valid_chat_request): + """Test that response structure is consistent across calls.""" + response1 = client.post("/v1/chat/completions", json=valid_chat_request) + response2 = client.post("/v1/chat/completions", json=valid_chat_request) + + assert response1.status_code == 200 + assert response2.status_code == 200 + + data1 = response1.json() + data2 = response2.json() + + # Structure should be the same + assert set(data1.keys()) == set(data2.keys()) + assert data1["object"] == data2["object"] + assert data1["model"] == data2["model"] + + # IDs should be different + assert data1["id"] != data2["id"] + + def test_concurrent_requests(self, client, valid_chat_request): + """Test handling of concurrent requests.""" + import threading + import time + + results = [] + + def make_request(): + response = client.post("/v1/chat/completions", json=valid_chat_request) + results.append(response.status_code) + + # Create multiple threads + threads = [] + for _ in range(5): + thread = threading.Thread(target=make_request) + threads.append(thread) 
+ thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All requests should be successful + assert all(status == 200 for status in results) + assert len(results) == 5 + + # Performance and load tests + def test_response_time_reasonable(self, client, valid_chat_request): + """Test that response times are reasonable.""" + start_time = time.time() + response = client.post("/v1/chat/completions", json=valid_chat_request) + end_time = time.time() + + assert response.status_code == 200 + assert (end_time - start_time) < 1.0 # Should respond within 1 second + + def test_large_prompt_handling(self, client): + """Test handling of large prompts.""" + large_prompt = "A" * 10000 # 10K characters + request_data = { + "model": "text-davinci-003", + "prompt": large_prompt, + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + # Token calculation should handle large text + assert data["usage"]["prompt_tokens"] > 1000 + + # Mock and patch tests + @patch("mock_llm_server.mock_llm_server.get_dummy_chat_response") + def test_chat_response_mocking(self, mock_response, client, valid_chat_request): + """Test mocking of chat response generation.""" + mock_response.return_value = "Mocked response for testing" + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["message"]["content"] == "Mocked response for testing" + mock_response.assert_called_once() + + @patch("time.time") + def test_timestamp_consistency(self, mock_time, client, valid_chat_request): + """Test that timestamps are generated correctly.""" + mock_time.return_value = 1234567890 + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["created"] == 1234567890 + + # Documentation and OpenAPI tests + def test_openapi_docs_available(self, client): + """Test that OpenAPI documentation is available.""" + response = client.get("/docs") + assert response.status_code == 200 + + response = client.get("/openapi.json") + assert response.status_code == 200 + + openapi_data = response.json() + assert "openapi" in openapi_data + assert "paths" in openapi_data + assert "/v1/models" in openapi_data["paths"] + assert "/v1/chat/completions" in openapi_data["paths"] + assert "/v1/completions" in openapi_data["paths"] From d9b73bee71e053b5254a8d0be107ad1dfa375417 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:16:48 -0500 Subject: [PATCH 2/7] Refactor mock LLM, fix tests --- .../benchmark/mock_llm_server/api.py | 173 ++++++++ .../mock_llm_server/mock_llm_server.py | 406 ------------------ .../benchmark/mock_llm_server/models.py | 191 ++++++++ .../mock_llm_server/response_data.py | 79 ++++ .../benchmark/mock_llm_server/run_server.py | 7 +- tests/benchmark/test_mock_llm_server.py | 45 +- 6 files changed, 484 insertions(+), 417 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/api.py delete mode 100644 nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/models.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/response_data.py diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py new file mode 
100644 index 000000000..bca45b1df --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import Union + +from fastapi import FastAPI, HTTPException + +from nemoguardrails.benchmark.mock_llm_server.models import ( + ChatCompletionChoice, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionChoice, + CompletionRequest, + CompletionResponse, + Message, + Model, + ModelsResponse, + Usage, +) +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_MODELS, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + + +def _validate_request_model( + request: Union[CompletionRequest, ChatCompletionRequest], +) -> None: + """Check the Completion or Chat Completion `model` field is in our supported model list""" + available_models = set([model["id"] for model in DUMMY_MODELS]) + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. Available models: {available_models}", + ) + + +app = FastAPI( + title="Mock LLM Server", + description="OpenAI-compatible mock LLM server for testing and benchmarking", + version="0.0.1", +) + + +@app.get("/") +async def root(): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "0.0.1", + "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(): + """List available models.""" + return ModelsResponse( + object="list", data=[Model(**model) for model in DUMMY_MODELS] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResponse: + """Create a chat completion.""" + # Validate model exists + _validate_request_model(request) + + # Generate dummy response + response_content = get_dummy_chat_response() + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content), + finish_reason="stop", + ) + choices.append(choice) + + response = ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + 
total_tokens=prompt_tokens + completion_tokens, + ), + ) + + return response + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions(request: CompletionRequest) -> CompletionResponse: + """Create a text completion.""" + + # Validate model exists + _validate_request_model(request) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_dummy_completion_response() + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + response = CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "timestamp": int(time.time())} diff --git a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py b/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py deleted file mode 100644 index 28725e724..000000000 --- a/nemoguardrails/benchmark/mock_llm_server/mock_llm_server.py +++ /dev/null @@ -1,406 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Mock LLM FastAPI Server with OpenAI-compatible interface. - -This server provides dummy implementations of OpenAI API endpoints for testing -and benchmarking purposes. 
-""" - -import time -import uuid -from typing import Any, Dict, List, Optional, Union - -import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field - -app = FastAPI( - title="Mock LLM Server", - description="OpenAI-compatible mock LLM server for testing and benchmarking", - version="1.0.0", -) - - -# Pydantic Models for Request/Response validation - - -class Message(BaseModel): - """Chat message model.""" - - role: str = Field(..., description="The role of the message author") - content: str = Field(..., description="The content of the message") - name: Optional[str] = Field(None, description="The name of the author") - - -class ChatCompletionRequest(BaseModel): - """Chat completion request model.""" - - model: str = Field(..., description="ID of the model to use") - messages: List[Message] = Field( - ..., description="List of messages comprising the conversation" - ) - max_tokens: Optional[int] = Field( - None, description="Maximum number of tokens to generate", ge=1 - ) - temperature: Optional[float] = Field( - 1.0, description="Sampling temperature", ge=0.0, le=2.0 - ) - top_p: Optional[float] = Field( - 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 - ) - n: Optional[int] = Field( - 1, description="Number of completions to generate", ge=1, le=128 - ) - stream: Optional[bool] = Field( - False, description="Whether to stream back partial progress" - ) - stop: Optional[Union[str, List[str]]] = Field( - None, description="Sequences where the API will stop generating" - ) - presence_penalty: Optional[float] = Field( - 0.0, description="Presence penalty", ge=-2.0, le=2.0 - ) - frequency_penalty: Optional[float] = Field( - 0.0, description="Frequency penalty", ge=-2.0, le=2.0 - ) - logit_bias: Optional[Dict[str, float]] = Field( - None, description="Modify likelihood of specified tokens" - ) - user: Optional[str] = Field( - None, description="Unique identifier representing your end-user" - ) - - -class CompletionRequest(BaseModel): - """Text completion request model.""" - - model: str = Field(..., description="ID of the model to use") - prompt: Union[str, List[str]] = Field( - ..., description="The prompt(s) to generate completions for" - ) - max_tokens: Optional[int] = Field( - 16, description="Maximum number of tokens to generate", ge=1 - ) - temperature: Optional[float] = Field( - 1.0, description="Sampling temperature", ge=0.0, le=2.0 - ) - top_p: Optional[float] = Field( - 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 - ) - n: Optional[int] = Field( - 1, description="Number of completions to generate", ge=1, le=128 - ) - stream: Optional[bool] = Field( - False, description="Whether to stream back partial progress" - ) - logprobs: Optional[int] = Field( - None, description="Include log probabilities", ge=0, le=5 - ) - echo: Optional[bool] = Field( - False, description="Echo back the prompt in addition to completion" - ) - stop: Optional[Union[str, List[str]]] = Field( - None, description="Sequences where the API will stop generating" - ) - presence_penalty: Optional[float] = Field( - 0.0, description="Presence penalty", ge=-2.0, le=2.0 - ) - frequency_penalty: Optional[float] = Field( - 0.0, description="Frequency penalty", ge=-2.0, le=2.0 - ) - best_of: Optional[int] = Field( - 1, description="Number of completions to generate server-side", ge=1 - ) - logit_bias: Optional[Dict[str, float]] = Field( - None, description="Modify likelihood of specified tokens" - ) - user: Optional[str] = Field( - None, description="Unique 
identifier representing your end-user" - ) - - -class Usage(BaseModel): - """Token usage information.""" - - prompt_tokens: int = Field(..., description="Number of tokens in the prompt") - completion_tokens: int = Field( - ..., description="Number of tokens in the completion" - ) - total_tokens: int = Field(..., description="Total number of tokens used") - - -class ChatCompletionChoice(BaseModel): - """Chat completion choice.""" - - index: int = Field(..., description="The index of this choice") - message: Message = Field(..., description="The generated message") - finish_reason: str = Field( - ..., description="The reason the model stopped generating" - ) - - -class CompletionChoice(BaseModel): - """Text completion choice.""" - - text: str = Field(..., description="The generated text") - index: int = Field(..., description="The index of this choice") - logprobs: Optional[Dict[str, Any]] = Field( - None, description="Log probability information" - ) - finish_reason: str = Field( - ..., description="The reason the model stopped generating" - ) - - -class ChatCompletionResponse(BaseModel): - """Chat completion response.""" - - id: str = Field(..., description="Unique identifier for the completion") - object: str = Field("chat.completion", description="Object type") - created: int = Field( - ..., description="Unix timestamp when the completion was created" - ) - model: str = Field(..., description="The model used for completion") - choices: List[ChatCompletionChoice] = Field( - ..., description="List of completion choices" - ) - usage: Usage = Field(..., description="Token usage information") - - -class CompletionResponse(BaseModel): - """Text completion response.""" - - id: str = Field(..., description="Unique identifier for the completion") - object: str = Field("text_completion", description="Object type") - created: int = Field( - ..., description="Unix timestamp when the completion was created" - ) - model: str = Field(..., description="The model used for completion") - choices: List[CompletionChoice] = Field( - ..., description="List of completion choices" - ) - usage: Usage = Field(..., description="Token usage information") - - -class Model(BaseModel): - """Model information.""" - - id: str = Field(..., description="Model identifier") - object: str = Field("model", description="Object type") - created: int = Field(..., description="Unix timestamp when the model was created") - owned_by: str = Field(..., description="Organization that owns the model") - - -class ModelsResponse(BaseModel): - """Models list response.""" - - object: str = Field("list", description="Object type") - data: List[Model] = Field(..., description="List of available models") - - -# Dummy data and helper functions - -DUMMY_MODELS = [ - { - "id": "gpt-3.5-turbo", - "object": "model", - "created": 1677610602, - "owned_by": "openai", - }, - {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, - { - "id": "gpt-4-turbo", - "object": "model", - "created": 1712361441, - "owned_by": "openai", - }, - { - "id": "text-davinci-003", - "object": "model", - "created": 1669599635, - "owned_by": "openai", - }, -] - -DUMMY_CHAT_RESPONSES = [ - "This is a mock response from the LLM server.", - "I'm a dummy AI assistant created for testing purposes.", - "This response is generated by a mock OpenAI-compatible server.", - "Hello! 
I'm responding with dummy data for benchmarking.", - "This is a simulated conversation response for testing.", -] - -DUMMY_COMPLETION_RESPONSES = [ - " This is a dummy text completion.", - " Here's some mock generated text.", - " This is a sample completion response.", - " Mock completion text for testing purposes.", - " Dummy text generated by the mock server.", -] - - -def generate_id(prefix: str = "chatcmpl") -> str: - """Generate a unique ID for completions.""" - return f"{prefix}-{uuid.uuid4().hex[:8]}" - - -def calculate_tokens(text: str) -> int: - """Rough token calculation (approximately 4 characters per token).""" - return max(1, len(text) // 4) - - -def get_dummy_chat_response() -> str: - """Get a dummy chat response.""" - import random - - return random.choice(DUMMY_CHAT_RESPONSES) - - -def get_dummy_completion_response() -> str: - """Get a dummy completion response.""" - import random - - return random.choice(DUMMY_COMPLETION_RESPONSES) - - -# API Endpoints - - -@app.get("/") -async def root(): - """Root endpoint with basic server information.""" - return { - "message": "Mock LLM Server", - "version": "1.0.0", - "description": "OpenAI-compatible mock LLM server for testing and benchmarking", - "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], - } - - -@app.get("/v1/models", response_model=ModelsResponse) -async def list_models(): - """List available models.""" - return ModelsResponse( - object="list", data=[Model(**model) for model in DUMMY_MODELS] - ) - - -@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions(request: ChatCompletionRequest): - """Create a chat completion.""" - # Validate model exists - available_models = [model["id"] for model in DUMMY_MODELS] - if request.model not in available_models: - raise HTTPException( - status_code=400, - detail=f"Model '{request.model}' not found. Available models: {available_models}", - ) - - # Generate dummy response - response_content = get_dummy_chat_response() - - # Calculate token usage - prompt_text = " ".join([msg.content for msg in request.messages]) - prompt_tokens = calculate_tokens(prompt_text) - completion_tokens = calculate_tokens(response_content) - - # Create response - completion_id = generate_id("chatcmpl") - created_timestamp = int(time.time()) - - choices = [] - for i in range(request.n or 1): - choice = ChatCompletionChoice( - index=i, - message=Message(role="assistant", content=response_content, name=None), - finish_reason="stop", - ) - choices.append(choice) - - return ChatCompletionResponse( - id=completion_id, - object="chat.completion", - created=created_timestamp, - model=request.model, - choices=choices, - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - - -@app.post("/v1/completions", response_model=CompletionResponse) -async def completions(request: CompletionRequest): - """Create a text completion.""" - # Validate model exists - available_models = [model["id"] for model in DUMMY_MODELS] - if request.model not in available_models: - raise HTTPException( - status_code=400, - detail=f"Model '{request.model}' not found. 
Available models: {available_models}", - ) - - # Handle prompt (can be string or list) - if isinstance(request.prompt, list): - prompt_text = " ".join(request.prompt) - else: - prompt_text = request.prompt - - # Generate dummy response - response_text = get_dummy_completion_response() - - # Calculate token usage - prompt_tokens = calculate_tokens(prompt_text) - completion_tokens = calculate_tokens(response_text) - - # Create response - completion_id = generate_id("cmpl") - created_timestamp = int(time.time()) - - choices = [] - for i in range(request.n or 1): - choice = CompletionChoice( - text=response_text, index=i, logprobs=None, finish_reason="stop" - ) - choices.append(choice) - - return CompletionResponse( - id=completion_id, - object="text_completion", - created=created_timestamp, - model=request.model, - choices=choices, - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy", "timestamp": int(time.time())} - - -if __name__ == "__main__": - uvicorn.run( - "mock_llm_server:app", host="0.0.0.0", port=8000, reload=True, log_level="info" - ) diff --git a/nemoguardrails/benchmark/mock_llm_server/models.py b/nemoguardrails/benchmark/mock_llm_server/models.py new file mode 100644 index 000000000..8634c46a6 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/models.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
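
The models.py file added below mirrors the OpenAI chat and text completion schema as pydantic models. As a quick orientation, here is a minimal validation sketch; the import path is the one this patch creates, and the behaviour shown (nested dicts coerced into Message, constraint violations raised as ValidationError) is standard pydantic rather than anything specific to this server.

    # Minimal sketch: how the request models below validate a payload.
    from pydantic import ValidationError

    from nemoguardrails.benchmark.mock_llm_server.models import (
        ChatCompletionRequest,
        Message,
    )

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 50,
    }
    request = ChatCompletionRequest(**payload)
    assert isinstance(request.messages[0], Message)  # dicts are coerced to Message

    try:
        ChatCompletionRequest(model="gpt-4", messages=[], max_tokens=0)  # ge=1 fails
    except ValidationError as err:
        print(len(err.errors()), "validation error(s)")
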
+ +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: list[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, list[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class 
ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: list[Model] = Field(..., description="List of available models") diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py new file mode 100644 index 000000000..7e3c7e760 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
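
The response_data.py module below holds the canned model list, the dummy responses, and two small helpers. Note that calculate_tokens is an intentional approximation (roughly four characters per token, never fewer than one) rather than a real tokenizer, which is fine for benchmarking but means usage numbers will not match a production API. A short sketch of the expected behaviour, assuming the module path introduced in this patch:

    from nemoguardrails.benchmark.mock_llm_server.response_data import (
        calculate_tokens,
        generate_id,
    )

    prompt = "Hello, how are you?"          # 19 characters
    assert calculate_tokens(prompt) == 4    # 19 // 4, floored
    assert calculate_tokens("") == 1        # never reports zero tokens

    completion_id = generate_id("chatcmpl")
    assert completion_id.startswith("chatcmpl-")
    assert len(completion_id) == len("chatcmpl-") + 8  # eight hex chars from uuid4
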
+ + +import uuid + +DUMMY_MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "openai", + }, + { + "id": "text-davinci-003", + "object": "model", + "created": 1669599635, + "owned_by": "openai", + }, +] + +DUMMY_CHAT_RESPONSES = [ + "This is a mock response from the LLM server.", + "I'm a dummy AI assistant created for testing purposes.", + "This response is generated by a mock OpenAI-compatible server.", + "Hello! I'm responding with dummy data for benchmarking.", + "This is a simulated conversation response for testing.", +] + +DUMMY_COMPLETION_RESPONSES = [ + " This is a dummy text completion.", + " Here's some mock generated text.", + " This is a sample completion response.", + " Mock completion text for testing purposes.", + " Dummy text generated by the mock server.", +] + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_dummy_chat_response() -> str: + """Get a dummy chat response.""" + import random + + return random.choice(DUMMY_CHAT_RESPONSES) + + +def get_dummy_completion_response() -> str: + """Get a dummy completion response.""" + import random + + return random.choice(DUMMY_COMPLETION_RESPONSES) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index 66f281932..83e68c049 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -25,9 +25,10 @@ import sys import uvicorn +from api import app -# Add the current directory to Python path to import the server module -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +# # Add the current directory to Python path to import the server module +# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) def main(): @@ -62,7 +63,7 @@ def main(): try: uvicorn.run( - "mock_llm_server:app", + app=app, host=args.host, port=args.port, reload=args.reload, diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index b74f9633d..ad24e5303 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -28,16 +28,27 @@ import pytest from fastapi.testclient import TestClient -# Import the server and its components -from mock_llm_server.mock_llm_server import ( +from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_CHAT_RESPONSES, DUMMY_MODELS, - app, calculate_tokens, generate_id, get_dummy_chat_response, get_dummy_completion_response, ) +# +# # Import the server and its components +# from mock_llm_server.mock_llm_server import ( +# DUMMY_MODELS, +# app, +# calculate_tokens, +# generate_id, +# get_dummy_chat_response, +# get_dummy_completion_response, +# ) + class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -75,7 +86,7 @@ def test_root_endpoint(self, client): data = response.json() assert data["message"] == "Mock LLM Server" - assert data["version"] == "1.0.0" + assert data["version"] == "0.0.1" assert "description" in data 
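
A note on the run_server.py hunk above, which replaces the "mock_llm_server:app" import string with the imported app object: recent uvicorn releases only support --reload (and multiple workers) when given an import string, so a reload run with an app instance is expected to be rejected. Below is a hedged sketch of how both invocation styles could be kept, with the import string path assumed from this patch's package layout; verify against the pinned uvicorn version before relying on it.

    import uvicorn

    def serve(host: str = "0.0.0.0", port: int = 8000, reload: bool = False) -> None:
        if reload:
            # An import string keeps --reload working: uvicorn re-imports the
            # module in a fresh process, so module-level setup must be reload-safe.
            uvicorn.run(
                "nemoguardrails.benchmark.mock_llm_server.api:app",
                host=host,
                port=port,
                reload=True,
            )
        else:
            # A plain single-process run can take the application object directly.
            from nemoguardrails.benchmark.mock_llm_server.api import app

            uvicorn.run(app, host=host, port=port)
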
assert "/v1/models" in data["endpoints"] assert "/v1/chat/completions" in data["endpoints"] @@ -491,16 +502,34 @@ def test_large_prompt_handling(self, client): assert data["usage"]["prompt_tokens"] > 1000 # Mock and patch tests - @patch("mock_llm_server.mock_llm_server.get_dummy_chat_response") - def test_chat_response_mocking(self, mock_response, client, valid_chat_request): + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_chat_response") + def test_chat_completion_response_mocking( + self, mock_response, client, valid_chat_request + ): """Test mocking of chat response generation.""" - mock_response.return_value = "Mocked response for testing" + expected_response = "Mocked response for testing chat completions" + mock_response.return_value = expected_response response = client.post("/v1/chat/completions", json=valid_chat_request) assert response.status_code == 200 data = response.json() - assert data["choices"][0]["message"]["content"] == "Mocked response for testing" + assert data["choices"][0]["message"]["content"] == expected_response + mock_response.assert_called_once() + + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_completion_response") + def test_completion_response_mocking( + self, mock_response, client, valid_completion_request + ): + """Test mocking of chat response generation.""" + expected_response = "Mocked response to check completion responses" + mock_response.return_value = expected_response + + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["text"] == expected_response mock_response.assert_called_once() @patch("time.time") From 9021b81ab121fe24112903d6676e5b7c981173d4 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:59:41 -0500 Subject: [PATCH 3/7] Added tests to load YAML config. 
Still debugging dependency-injection of this into endpoints --- .../benchmark/mock_llm_server/api.py | 12 +++- .../benchmark/mock_llm_server/config.py | 72 +++++++++++++++++++ ...llama-3.1-nemoguard-8b-content-safety.yaml | 12 ++++ .../benchmark/mock_llm_server/run_server.py | 13 +++- tests/benchmark/mock_model_config.yaml | 3 + tests/benchmark/test_mock_llm_server.py | 45 +++++++++--- 6 files changed, 140 insertions(+), 17 deletions(-) create mode 100644 nemoguardrails/benchmark/mock_llm_server/config.py create mode 100644 nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml create mode 100644 tests/benchmark/mock_model_config.yaml diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index bca45b1df..ca92ab193 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -15,10 +15,11 @@ import time -from typing import Union +from typing import Annotated, Union -from fastapi import FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config from nemoguardrails.benchmark.mock_llm_server.models import ( ChatCompletionChoice, ChatCompletionRequest, @@ -59,14 +60,19 @@ def _validate_request_model( ) +ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] + + @app.get("/") -async def root(): +async def root(current_config: ModelConfigDep): """Root endpoint with basic server information.""" + print(current_config) return { "message": "Mock LLM Server", "version": "0.0.1", "description": "OpenAI-compatible mock LLM server for testing and benchmarking", "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + "model_configuration": current_config, } diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py new file mode 100644 index 000000000..0b2fa42e6 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import lru_cache +from typing import Any, Optional, Union + +import yaml +from openai._utils import lru_cache +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class AppModelConfig(BaseModel): + """Pydantic model to configure the Mock LLM Server.""" + + # Mandatory fields + model: str = Field(..., description="Model name served by mock server") + refusal_text: str = Field(..., description="Refusal response text") + + # Config with default values + refusal_probability: float = Field( + default=0.1, description="Probability of refusal (between 0 and 1)" + ) + # Latency sampled from a truncated-normal distribution. 
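
The api.py hunk above and the config.py module around it introduce the dependency-injection seam for model configuration: a ModelConfigDep alias built from Annotated[AppModelConfig, Depends(get_config)], resolved by FastAPI on every request, which is also what makes app.dependency_overrides usable in the tests. A self-contained sketch of that wiring follows; names mirror this patch, but the snippet is illustrative rather than the patch itself. One caveat worth flagging: combining lru_cache on get_config with a module-level settings global can pin whatever value was cached before load_config runs, which is consistent with a later commit in this series dropping the decorator.

    from typing import Annotated, Optional

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel

    class AppModelConfig(BaseModel):
        model: str
        refusal_text: str
        refusal_probability: float = 0.1

    _settings: Optional[AppModelConfig] = None  # populated by a load_config() step

    def get_config() -> AppModelConfig:
        if _settings is None:
            raise RuntimeError("No configuration loaded")
        return _settings

    app = FastAPI()
    ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)]

    @app.get("/model-info")
    async def model_info(config: ModelConfigDep) -> dict:
        # FastAPI calls get_config() and injects the result here on each request.
        return {"model": config.model, "refusal_probability": config.refusal_probability}
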
+ # Plain Normal distributions have infinite support, and can be negative + latency_min_seconds: float = Field( + default=0.1, description="Minimum latency in seconds" + ) + latency_max_seconds: float = Field( + default=5, description="Maximum latency in seconds" + ) + latency_mean_seconds: float = Field( + default=0.5, description="The average response time in seconds" + ) + latency_std_seconds: float = Field( + default=0.1, description="Standard deviation of response time" + ) + + +settings: Optional[AppModelConfig] = None + + +def load_config(yaml_file: str) -> None: + """Load the Model configuration from YAML file, store in global `settings` var""" + global settings + with open(yaml_file, "r") as f: + config_data = yaml.safe_load(f) + settings = AppModelConfig(**config_data) + + +@lru_cache +def get_config() -> AppModelConfig: + """FastAPI Dependency to inject model configuration""" + print(f"get_config called, settings = {settings}") + print(f"GET_CONFIG CALLED IN PROCESS ID: {os.getpid()}") + + if settings is None: + raise RuntimeError("No configuration loaded") + return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml new file mode 100644 index 000000000..2eebb1063 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -0,0 +1,12 @@ +model: "nvidia/llama-3.1-nemoguard-8b-content-safety" +refusal_probability: 0.01 +refusal_text: | + { + "User Safety": "unsafe", + "Response Safety": "unsafe", + "Safety Categories": "PII/Privacy" + } +latency_min_seconds: 0.1 +latency_max_seconds: 5 +latency_mean_seconds: 0.4 +latency_std_seconds: 0.1 diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index 83e68c049..d3732c97b 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -26,9 +26,7 @@ import uvicorn from api import app - -# # Add the current directory to Python path to import the server module -# sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from config import get_config, load_config, settings def main(): @@ -54,11 +52,20 @@ def main(): help="Log level (default: info)", ) + parser.add_argument( + "--config-file", help="YAML file to configure model", required=True + ) + args = parser.parse_args() + # Load model configuration + load_config(args.config_file) + model_config = get_config() + print(f"Starting Mock LLM Server on {args.host}:{args.port}") print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") print(f"Health check at: http://{args.host}:{args.port}/health") + print(f"Model configuration: {model_config}") print("Press Ctrl+C to stop the server") try: diff --git a/tests/benchmark/mock_model_config.yaml b/tests/benchmark/mock_model_config.yaml new file mode 100644 index 000000000..384a988e5 --- /dev/null +++ b/tests/benchmark/mock_model_config.yaml @@ -0,0 +1,3 @@ +model: "mock_model" +refusal_probability: 0.01 +refusal_text: "I'm sorry, I can't help you with that request" diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index ad24e5303..0bc54423e 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -21,6 +21,7 @@ """ import json +import os import time from typing import 
Any, Dict, List from unittest.mock import patch @@ -29,6 +30,11 @@ from fastapi.testclient import TestClient from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.config import ( + AppModelConfig, + get_config, + load_config, +) from nemoguardrails.benchmark.mock_llm_server.response_data import ( DUMMY_CHAT_RESPONSES, DUMMY_MODELS, @@ -38,17 +44,6 @@ get_dummy_completion_response, ) -# -# # Import the server and its components -# from mock_llm_server.mock_llm_server import ( -# DUMMY_MODELS, -# app, -# calculate_tokens, -# generate_id, -# get_dummy_chat_response, -# get_dummy_completion_response, -# ) - class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -81,6 +76,17 @@ def valid_completion_request(self): # Root endpoint tests def test_root_endpoint(self, client): """Test the root endpoint returns correct information.""" + + mock_config = AppModelConfig( + model="mock_config_model_name", + refusal_text="I'm afraid I can't do that, Dave", + ) + + def override_get_config(): + return mock_config + + app.dependency_overrides[get_config] = override_get_config + response = client.get("/") assert response.status_code == 200 @@ -91,6 +97,8 @@ def test_root_endpoint(self, client): assert "/v1/models" in data["endpoints"] assert "/v1/chat/completions" in data["endpoints"] assert "/v1/completions" in data["endpoints"] + assert data["model_configuration"]["model"] == mock_config.model + assert data["model_configuration"]["refusal_text"] == mock_config.refusal_text # Health check tests def test_health_check(self, client): @@ -558,3 +566,18 @@ def test_openapi_docs_available(self, client): assert "/v1/models" in openapi_data["paths"] assert "/v1/chat/completions" in openapi_data["paths"] assert "/v1/completions" in openapi_data["paths"] + + def test_read_root_with_mock_config(self): + """Tests load_config method correctly populates the `settings` global variable""" + yaml_file = os.path.join(os.path.dirname(__file__), "mock_model_config.yaml") + + # Make sure settings is empty to start with, load and check it's populated + load_config(yaml_file) + config = get_config() + assert config is not None + + # Now check the contents against `mock_model_config.yaml` + assert isinstance(config, AppModelConfig) + assert config.model == "mock_model" + assert config.refusal_probability == 0.01 + assert config.refusal_text == "I'm sorry, I can't help you with that request" From 687e33bce7384bc67532fe206dcc266e1f3eca9f Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 17:08:05 -0500 Subject: [PATCH 4/7] Move FastAPI app import **after** the dependencies are loaded and cached --- nemoguardrails/benchmark/mock_llm_server/config.py | 2 -- nemoguardrails/benchmark/mock_llm_server/run_server.py | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 0b2fa42e6..9945cb98a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -18,7 +18,6 @@ from typing import Any, Optional, Union import yaml -from openai._utils import lru_cache from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -61,7 +60,6 @@ def load_config(yaml_file: str) -> None: settings = AppModelConfig(**config_data) -@lru_cache def get_config() -> AppModelConfig: """FastAPI Dependency to 
inject model configuration""" print(f"get_config called, settings = {settings}") diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py index d3732c97b..0d05756d2 100644 --- a/nemoguardrails/benchmark/mock_llm_server/run_server.py +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -21,12 +21,11 @@ """ import argparse -import os import sys import uvicorn -from api import app -from config import get_config, load_config, settings + +from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config def main(): @@ -62,6 +61,9 @@ def main(): load_config(args.config_file) model_config = get_config() + # Import the app after configuration is loaded. This caches the values in the app Dependencies + from nemoguardrails.benchmark.mock_llm_server.api import app + print(f"Starting Mock LLM Server on {args.host}:{args.port}") print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") print(f"Health check at: http://{args.host}:{args.port}/health") From c0afd8d5096eb0c3ebfc8960b55c23c5e32c63e3 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Wed, 17 Sep 2025 21:06:48 -0500 Subject: [PATCH 5/7] Remove debugging print statements --- nemoguardrails/benchmark/mock_llm_server/api.py | 1 - nemoguardrails/benchmark/mock_llm_server/config.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index ca92ab193..c34816ef2 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -66,7 +66,6 @@ def _validate_request_model( @app.get("/") async def root(current_config: ModelConfigDep): """Root endpoint with basic server information.""" - print(current_config) return { "message": "Mock LLM Server", "version": "0.0.1", diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py b/nemoguardrails/benchmark/mock_llm_server/config.py index 9945cb98a..0f1abe7bb 100644 --- a/nemoguardrails/benchmark/mock_llm_server/config.py +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -62,9 +62,6 @@ def load_config(yaml_file: str) -> None: def get_config() -> AppModelConfig: """FastAPI Dependency to inject model configuration""" - print(f"get_config called, settings = {settings}") - print(f"GET_CONFIG CALLED IN PROCESS ID: {os.getpid()}") - if settings is None: raise RuntimeError("No configuration loaded") return settings From e62f39421f7d60afb2eba05e0563beb03e38a6e8 Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:48:31 -0500 Subject: [PATCH 6/7] Temporary checkin --- .../benchmark/mock_llm_server/api.py | 14 ++++-- .../mock_llm_server/response_data.py | 49 +++++++++++++++++-- tests/benchmark/test_mock_llm_server.py | 6 +++ 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py index c34816ef2..a33b7505e 100644 --- a/nemoguardrails/benchmark/mock_llm_server/api.py +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -15,7 +15,7 @@ import time -from typing import Annotated, Union +from typing import Annotated, Optional, Union from fastapi import Depends, FastAPI, HTTPException @@ -84,13 +84,15 @@ async def list_models(): @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions(request: 
ChatCompletionRequest) -> ChatCompletionResponse: +async def chat_completions( + request: ChatCompletionRequest, config: ModelConfigDep +) -> ChatCompletionResponse: """Create a chat completion.""" # Validate model exists _validate_request_model(request) # Generate dummy response - response_content = get_dummy_chat_response() + response_content = get_dummy_chat_response(config) # Calculate token usage prompt_text = " ".join([msg.content for msg in request.messages]) @@ -127,7 +129,9 @@ async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResp @app.post("/v1/completions", response_model=CompletionResponse) -async def completions(request: CompletionRequest) -> CompletionResponse: +async def completions( + request: CompletionRequest, config: ModelConfigDep +) -> CompletionResponse: """Create a text completion.""" # Validate model exists @@ -140,7 +144,7 @@ async def completions(request: CompletionRequest) -> CompletionResponse: prompt_text = request.prompt # Generate dummy response - response_text = get_dummy_completion_response() + response_text = get_dummy_completion_response(config) # Calculate token usage prompt_tokens = calculate_tokens(prompt_text) diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index 7e3c7e760..abd1bc77c 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -14,7 +14,13 @@ # limitations under the License. +import random import uuid +from typing import Optional + +import numpy as np + +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config DUMMY_MODELS = [ { @@ -65,15 +71,50 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_dummy_chat_response() -> str: +def get_dummy_chat_response(config: AppModelConfig) -> str: """Get a dummy chat response.""" - import random + + if is_refusal(config): + return config.refusal_text return random.choice(DUMMY_CHAT_RESPONSES) -def get_dummy_completion_response() -> str: +def get_dummy_completion_response(config: AppModelConfig) -> str: """Get a dummy completion response.""" - import random + if is_refusal(config): + return config.refusal_text return random.choice(DUMMY_COMPLETION_RESPONSES) + + +def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> float: + """Sample latency for this request using the model's config + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + # Sample from the normal distribution using model config + latency_seconds = np.random.normal( + loc=config.latency_mean_seconds, scale=config.latency_std_seconds, size=1 + ) + + # Truncate distribution's support using min and max config values + latency_seconds = np.clip( + latency_seconds, + a_min=config.latency_min_seconds, + a_max=config.latency_max_seconds, + ) + return float(latency_seconds) + + +def is_refusal(config: AppModelConfig, seed: Optional[int] = None) -> bool: + """Check if the model should return a refusal + Very inefficient to generate each sample singly rather than in batch + """ + if seed: + np.random.seed(seed) + + refusal = np.random.binomial(n=1, p=config.refusal_probability, size=1) + return bool(refusal[0]) diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index 0bc54423e..c53816de1 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ 
b/tests/benchmark/test_mock_llm_server.py @@ -581,3 +581,9 @@ def test_read_root_with_mock_config(self): assert config.model == "mock_model" assert config.refusal_probability == 0.01 assert config.refusal_text == "I'm sorry, I can't help you with that request" + + @patch("nemoguardrails.benchmark.mock_llm_server.config.settings", None) + def test_get_config_raises_exception(self): + """Check if we call `get_config()` without settings set we raise an exception""" + with pytest.raises(RuntimeError, match="No configuration loaded"): + get_config() From 6ddcacac53109a6b6cf3125addfa9cab04accfec Mon Sep 17 00:00:00 2001 From: tgasser-nv <200644301+tgasser-nv@users.noreply.github.com> Date: Mon, 22 Sep 2025 11:25:44 -0500 Subject: [PATCH 7/7] Add refusal probability and tests to check it --- .../mock_llm_server/response_data.py | 20 +++++---- tests/benchmark/test_mock_llm_server.py | 43 +++++++++++++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py index abd1bc77c..38522583a 100644 --- a/nemoguardrails/benchmark/mock_llm_server/response_data.py +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -53,11 +53,11 @@ ] DUMMY_COMPLETION_RESPONSES = [ - " This is a dummy text completion.", - " Here's some mock generated text.", - " This is a sample completion response.", - " Mock completion text for testing purposes.", - " Dummy text generated by the mock server.", + "This is a dummy text completion.", + "Here's some mock generated text.", + "This is a sample completion response.", + "Mock completion text for testing purposes.", + "Dummy text generated by the mock server.", ] @@ -71,18 +71,20 @@ def calculate_tokens(text: str) -> int: return max(1, len(text) // 4) -def get_dummy_chat_response(config: AppModelConfig) -> str: +def get_dummy_chat_response(config: AppModelConfig, seed: Optional[int] = None) -> str: """Get a dummy chat response.""" - if is_refusal(config): + if is_refusal(config, seed): return config.refusal_text return random.choice(DUMMY_CHAT_RESPONSES) -def get_dummy_completion_response(config: AppModelConfig) -> str: +def get_dummy_completion_response( + config: AppModelConfig, seed: Optional[int] = None +) -> str: """Get a dummy completion response.""" - if is_refusal(config): + if is_refusal(config, seed): return config.refusal_text return random.choice(DUMMY_COMPLETION_RESPONSES) diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py index c53816de1..552eb57e1 100644 --- a/tests/benchmark/test_mock_llm_server.py +++ b/tests/benchmark/test_mock_llm_server.py @@ -37,6 +37,7 @@ ) from nemoguardrails.benchmark.mock_llm_server.response_data import ( DUMMY_CHAT_RESPONSES, + DUMMY_COMPLETION_RESPONSES, DUMMY_MODELS, calculate_tokens, generate_id, @@ -44,6 +45,20 @@ get_dummy_completion_response, ) +RANDOM_SEED = 12345 +REFUSAL_TEXT = "I'm sorry Dave, I'm afraid I can't do that" +NO_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=0.0, +) + +ALL_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=1.0, +) + class TestMockLLMServer: """Test class for the Mock LLM Server.""" @@ -375,15 +390,25 @@ def test_calculate_tokens(self): expected_tokens = max(1, len(long_text) // 4) assert calculate_tokens(long_text) == expected_tokens - def test_get_dummy_responses(self): - """Test dummy response 
generation functions.""" - chat_response = get_dummy_chat_response() - assert isinstance(chat_response, str) - assert len(chat_response) > 0 - - completion_response = get_dummy_completion_response() - assert isinstance(completion_response, str) - assert len(completion_response) > 0 + def test_get_dummy_completion_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_completion_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_chat_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_chat_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_completion_response_no_refusal(self): + """Test /completion response generation with P = 0.0 of refusal""" + response = get_dummy_completion_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_COMPLETION_RESPONSES) + + def test_get_dummy_chat_response_no_refusal(self): + """Test /chat/completion response with P = 0.0 of refusal.""" + response = get_dummy_chat_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_CHAT_RESPONSES) # Edge cases and error handling def test_missing_required_fields_chat(self, client):
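
For reference, the sampling helpers added earlier in this series (get_latency_seconds and is_refusal) draw latency from a normal distribution clipped to [latency_min_seconds, latency_max_seconds] and refusals from a Bernoulli trial, and the tests above pin the refusal behaviour at probability 0.0 and 1.0. Below is a batch-oriented sketch of the same behaviour written with numpy's Generator API for reproducibility; it is not the patch's implementation (which samples one value at a time via np.random.seed), and clipping piles probability mass on the bounds rather than producing a true truncated normal, which is acceptable for a mock server but worth keeping in mind when reading benchmark latencies.

    import numpy as np

    def sample_latencies_seconds(
        n: int,
        mean: float = 0.5,
        std: float = 0.1,
        lo: float = 0.1,
        hi: float = 5.0,
        seed: int = 12345,
    ) -> np.ndarray:
        rng = np.random.default_rng(seed)
        return np.clip(rng.normal(loc=mean, scale=std, size=n), lo, hi)

    def sample_refusals(n: int, probability: float = 0.01, seed: int = 12345) -> np.ndarray:
        rng = np.random.default_rng(seed)
        return rng.random(n) < probability  # boolean mask of Bernoulli draws

    latencies = sample_latencies_seconds(1000)
    refusals = sample_refusals(1000)
    print(f"mean latency ~{latencies.mean():.2f}s, refusal rate ~{refusals.mean():.2%}")
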