diff --git a/nemoguardrails/benchmark/mock_llm_server/__init__.py b/nemoguardrails/benchmark/mock_llm_server/__init__.py new file mode 100644 index 000000000..9ba9d4310 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemoguardrails/benchmark/mock_llm_server/api.py b/nemoguardrails/benchmark/mock_llm_server/api.py new file mode 100644 index 000000000..a33b7505e --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/api.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import Annotated, Optional, Union + +from fastapi import Depends, FastAPI, HTTPException + +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig, get_config +from nemoguardrails.benchmark.mock_llm_server.models import ( + ChatCompletionChoice, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionChoice, + CompletionRequest, + CompletionResponse, + Message, + Model, + ModelsResponse, + Usage, +) +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_MODELS, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + + +def _validate_request_model( + request: Union[CompletionRequest, ChatCompletionRequest], +) -> None: + """Check the Completion or Chat Completion `model` field is in our supported model list""" + available_models = set([model["id"] for model in DUMMY_MODELS]) + if request.model not in available_models: + raise HTTPException( + status_code=400, + detail=f"Model '{request.model}' not found. 
Available models: {available_models}", + ) + + +app = FastAPI( + title="Mock LLM Server", + description="OpenAI-compatible mock LLM server for testing and benchmarking", + version="0.0.1", +) + + +ModelConfigDep = Annotated[AppModelConfig, Depends(get_config)] + + +@app.get("/") +async def root(current_config: ModelConfigDep): + """Root endpoint with basic server information.""" + return { + "message": "Mock LLM Server", + "version": "0.0.1", + "description": "OpenAI-compatible mock LLM server for testing and benchmarking", + "endpoints": ["/v1/models", "/v1/chat/completions", "/v1/completions"], + "model_configuration": current_config, + } + + +@app.get("/v1/models", response_model=ModelsResponse) +async def list_models(): + """List available models.""" + return ModelsResponse( + object="list", data=[Model(**model) for model in DUMMY_MODELS] + ) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def chat_completions( + request: ChatCompletionRequest, config: ModelConfigDep +) -> ChatCompletionResponse: + """Create a chat completion.""" + # Validate model exists + _validate_request_model(request) + + # Generate dummy response + response_content = get_dummy_chat_response(config) + + # Calculate token usage + prompt_text = " ".join([msg.content for msg in request.messages]) + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_content) + + # Create response + completion_id = generate_id("chatcmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = ChatCompletionChoice( + index=i, + message=Message(role="assistant", content=response_content), + finish_reason="stop", + ) + choices.append(choice) + + response = ChatCompletionResponse( + id=completion_id, + object="chat.completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + return response + + +@app.post("/v1/completions", response_model=CompletionResponse) +async def completions( + request: CompletionRequest, config: ModelConfigDep +) -> CompletionResponse: + """Create a text completion.""" + + # Validate model exists + _validate_request_model(request) + + # Handle prompt (can be string or list) + if isinstance(request.prompt, list): + prompt_text = " ".join(request.prompt) + else: + prompt_text = request.prompt + + # Generate dummy response + response_text = get_dummy_completion_response(config) + + # Calculate token usage + prompt_tokens = calculate_tokens(prompt_text) + completion_tokens = calculate_tokens(response_text) + + # Create response + completion_id = generate_id("cmpl") + created_timestamp = int(time.time()) + + choices = [] + for i in range(request.n or 1): + choice = CompletionChoice( + text=response_text, index=i, logprobs=None, finish_reason="stop" + ) + choices.append(choice) + + response = CompletionResponse( + id=completion_id, + object="text_completion", + created=created_timestamp, + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + return response + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "timestamp": int(time.time())} diff --git a/nemoguardrails/benchmark/mock_llm_server/config.py 
b/nemoguardrails/benchmark/mock_llm_server/config.py new file mode 100644 index 000000000..0f1abe7bb --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/config.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import lru_cache +from typing import Any, Optional, Union + +import yaml +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class AppModelConfig(BaseModel): + """Pydantic model to configure the Mock LLM Server.""" + + # Mandatory fields + model: str = Field(..., description="Model name served by mock server") + refusal_text: str = Field(..., description="Refusal response text") + + # Config with default values + refusal_probability: float = Field( + default=0.1, description="Probability of refusal (between 0 and 1)" + ) + # Latency sampled from a truncated-normal distribution. + # Plain Normal distributions have infinite support, and can be negative + latency_min_seconds: float = Field( + default=0.1, description="Minimum latency in seconds" + ) + latency_max_seconds: float = Field( + default=5, description="Maximum latency in seconds" + ) + latency_mean_seconds: float = Field( + default=0.5, description="The average response time in seconds" + ) + latency_std_seconds: float = Field( + default=0.1, description="Standard deviation of response time" + ) + + +settings: Optional[AppModelConfig] = None + + +def load_config(yaml_file: str) -> None: + """Load the Model configuration from YAML file, store in global `settings` var""" + global settings + with open(yaml_file, "r") as f: + config_data = yaml.safe_load(f) + settings = AppModelConfig(**config_data) + + +def get_config() -> AppModelConfig: + """FastAPI Dependency to inject model configuration""" + if settings is None: + raise RuntimeError("No configuration loaded") + return settings diff --git a/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml new file mode 100644 index 000000000..2eebb1063 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml @@ -0,0 +1,12 @@ +model: "nvidia/llama-3.1-nemoguard-8b-content-safety" +refusal_probability: 0.01 +refusal_text: | + { + "User Safety": "unsafe", + "Response Safety": "unsafe", + "Safety Categories": "PII/Privacy" + } +latency_min_seconds: 0.1 +latency_max_seconds: 5 +latency_mean_seconds: 0.4 +latency_std_seconds: 0.1 diff --git a/nemoguardrails/benchmark/mock_llm_server/example_usage.py b/nemoguardrails/benchmark/mock_llm_server/example_usage.py new file mode 100644 index 000000000..278ab8d94 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/example_usage.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# 
SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage of the Mock LLM Server. + +This script demonstrates how to interact with the running mock server +using standard HTTP requests and the OpenAI Python client. +""" + +import json +import time + +import requests + + +def test_with_requests(): + """Test the server using the requests library.""" + base_url = "http://localhost:8000" + + print("Testing Mock LLM Server with requests library...") + print("=" * 50) + + # Test health endpoint + try: + response = requests.get(f"{base_url}/health", timeout=5) + print(f"Health check: {response.status_code} - {response.json()}") + except requests.RequestException as e: + print(f"Health check failed: {e}") + print("Make sure the server is running: python run_server.py") + return + + # Test models endpoint + try: + response = requests.get(f"{base_url}/v1/models", timeout=5) + print(f"\\nModels: {response.status_code}") + models_data = response.json() + for model in models_data["data"]: + print(f" - {model['id']}") + except requests.RequestException as e: + print(f"Models request failed: {e}") + + # Test chat completion + try: + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + } + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nChat completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['message']['content']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Chat completion failed: {e}") + + # Test text completion + try: + completion_payload = { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + } + response = requests.post( + f"{base_url}/v1/completions", + json=completion_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + print(f"\\nText completion: {response.status_code}") + if response.status_code == 200: + data = response.json() + print(f"Response: {data['choices'][0]['text']}") + print(f"Usage: {data['usage']}") + except requests.RequestException as e: + print(f"Text completion failed: {e}") + + +def test_with_openai_client(): + """Test the server using the OpenAI Python client.""" + try: + import openai + except ImportError: + print("\\nOpenAI client not available. 
Install with: pip install openai") + return + + print("\\n" + "=" * 50) + print("Testing with OpenAI client library...") + print("=" * 50) + + # Configure client to use local server + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="dummy-key", # Server doesn't validate, but client requires it + ) + + try: + # Test chat completion + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello from OpenAI client!"}], + ) + print(f"Chat completion response: {response.choices[0].message.content}") + print( + f"Usage: prompt={response.usage.prompt_tokens}, completion={response.usage.completion_tokens}" + ) + + # Test text completion (if supported by client version) + try: + response = client.completions.create( + model="text-davinci-003", prompt="OpenAI client test: ", max_tokens=10 + ) + print(f"Text completion response: {response.choices[0].text}") + except Exception as e: + print(f"Text completion not supported in this OpenAI client version: {e}") + + except Exception as e: + print(f"OpenAI client test failed: {e}") + + +def benchmark_performance(): + """Simple performance benchmark.""" + print("\\n" + "=" * 50) + print("Performance Benchmark") + print("=" * 50) + + base_url = "http://localhost:8000" + num_requests = 10 + + chat_payload = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Benchmark test"}], + "max_tokens": 20, + } + + print(f"Making {num_requests} chat completion requests...") + + start_time = time.time() + successful_requests = 0 + + for i in range(num_requests): + try: + response = requests.post( + f"{base_url}/v1/chat/completions", + json=chat_payload, + headers={"Content-Type": "application/json"}, + timeout=5, + ) + if response.status_code == 200: + successful_requests += 1 + except requests.RequestException: + pass + + end_time = time.time() + total_time = end_time - start_time + + print(f"Results:") + print(f" Total requests: {num_requests}") + print(f" Successful requests: {successful_requests}") + print(f" Total time: {total_time:.2f} seconds") + print(f" Average time per request: {total_time/num_requests:.3f} seconds") + print(f" Requests per second: {num_requests/total_time:.2f}") + + +def main(): + """Main function to run all tests.""" + print("Mock LLM Server Example Usage") + print("=" * 50) + print("Make sure the server is running before running this script:") + print(" python run_server.py") + print() + + # Test with requests + test_with_requests() + + # Test with OpenAI client + test_with_openai_client() + + # Simple benchmark + benchmark_performance() + + print("\\nExample completed!") + + +if __name__ == "__main__": + main() diff --git a/nemoguardrails/benchmark/mock_llm_server/models.py b/nemoguardrails/benchmark/mock_llm_server/models.py new file mode 100644 index 000000000..8634c46a6 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/models.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field + + +class Message(BaseModel): + """Chat message model.""" + + role: str = Field(..., description="The role of the message author") + content: str = Field(..., description="The content of the message") + + +class ChatCompletionRequest(BaseModel): + """Chat completion request model.""" + + model: str = Field(..., description="ID of the model to use") + messages: list[Message] = Field( + ..., description="List of messages comprising the conversation" + ) + max_tokens: Optional[int] = Field( + None, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class CompletionRequest(BaseModel): + """Text completion request model.""" + + model: str = Field(..., description="ID of the model to use") + prompt: Union[str, list[str]] = Field( + ..., description="The prompt(s) to generate completions for" + ) + max_tokens: Optional[int] = Field( + 16, description="Maximum number of tokens to generate", ge=1 + ) + temperature: Optional[float] = Field( + 1.0, description="Sampling temperature", ge=0.0, le=2.0 + ) + top_p: Optional[float] = Field( + 1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0 + ) + n: Optional[int] = Field( + 1, description="Number of completions to generate", ge=1, le=128 + ) + stream: Optional[bool] = Field( + False, description="Whether to stream back partial progress" + ) + logprobs: Optional[int] = Field( + None, description="Include log probabilities", ge=0, le=5 + ) + echo: Optional[bool] = Field( + False, description="Echo back the prompt in addition to completion" + ) + stop: Optional[Union[str, list[str]]] = Field( + None, description="Sequences where the API will stop generating" + ) + presence_penalty: Optional[float] = Field( + 0.0, description="Presence penalty", ge=-2.0, le=2.0 + ) + frequency_penalty: Optional[float] = Field( + 0.0, description="Frequency penalty", ge=-2.0, le=2.0 + ) + best_of: Optional[int] = Field( + 1, description="Number of completions to generate server-side", ge=1 + ) + logit_bias: Optional[dict[str, float]] = Field( + None, description="Modify likelihood of specified tokens" + ) + user: Optional[str] = Field( + None, description="Unique identifier representing your end-user" + ) + + +class Usage(BaseModel): + """Token usage information.""" + + prompt_tokens: int = Field(..., description="Number of tokens in the prompt") + completion_tokens: int = Field( + ..., description="Number of tokens in the 
completion" + ) + total_tokens: int = Field(..., description="Total number of tokens used") + + +class ChatCompletionChoice(BaseModel): + """Chat completion choice.""" + + index: int = Field(..., description="The index of this choice") + message: Message = Field(..., description="The generated message") + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class CompletionChoice(BaseModel): + """Text completion choice.""" + + text: str = Field(..., description="The generated text") + index: int = Field(..., description="The index of this choice") + logprobs: Optional[dict[str, Any]] = Field( + None, description="Log probability information" + ) + finish_reason: str = Field( + ..., description="The reason the model stopped generating" + ) + + +class ChatCompletionResponse(BaseModel): + """Chat completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("chat.completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[ChatCompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class CompletionResponse(BaseModel): + """Text completion response.""" + + id: str = Field(..., description="Unique identifier for the completion") + object: str = Field("text_completion", description="Object type") + created: int = Field( + ..., description="Unix timestamp when the completion was created" + ) + model: str = Field(..., description="The model used for completion") + choices: list[CompletionChoice] = Field( + ..., description="List of completion choices" + ) + usage: Usage = Field(..., description="Token usage information") + + +class Model(BaseModel): + """Model information.""" + + id: str = Field(..., description="Model identifier") + object: str = Field("model", description="Object type") + created: int = Field(..., description="Unix timestamp when the model was created") + owned_by: str = Field(..., description="Organization that owns the model") + + +class ModelsResponse(BaseModel): + """Models list response.""" + + object: str = Field("list", description="Object type") + data: list[Model] = Field(..., description="List of available models") diff --git a/nemoguardrails/benchmark/mock_llm_server/response_data.py b/nemoguardrails/benchmark/mock_llm_server/response_data.py new file mode 100644 index 000000000..38522583a --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/response_data.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import random +import uuid +from typing import Optional + +import numpy as np + +from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig + +DUMMY_MODELS = [ + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai", + }, + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + { + "id": "gpt-4-turbo", + "object": "model", + "created": 1712361441, + "owned_by": "openai", + }, + { + "id": "text-davinci-003", + "object": "model", + "created": 1669599635, + "owned_by": "openai", + }, +] + +DUMMY_CHAT_RESPONSES = [ + "This is a mock response from the LLM server.", + "I'm a dummy AI assistant created for testing purposes.", + "This response is generated by a mock OpenAI-compatible server.", + "Hello! I'm responding with dummy data for benchmarking.", + "This is a simulated conversation response for testing.", +] + +DUMMY_COMPLETION_RESPONSES = [ + "This is a dummy text completion.", + "Here's some mock generated text.", + "This is a sample completion response.", + "Mock completion text for testing purposes.", + "Dummy text generated by the mock server.", +] + + +def generate_id(prefix: str = "chatcmpl") -> str: + """Generate a unique ID for completions.""" + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def calculate_tokens(text: str) -> int: + """Rough token calculation (approximately 4 characters per token).""" + return max(1, len(text) // 4) + + +def get_dummy_chat_response(config: AppModelConfig, seed: Optional[int] = None) -> str: + """Get a dummy chat response.""" + + if is_refusal(config, seed): + return config.refusal_text + + return random.choice(DUMMY_CHAT_RESPONSES) + + +def get_dummy_completion_response( + config: AppModelConfig, seed: Optional[int] = None +) -> str: + """Get a dummy completion response.""" + if is_refusal(config, seed): + return config.refusal_text + + return random.choice(DUMMY_COMPLETION_RESPONSES) + + +def get_latency_seconds(config: AppModelConfig, seed: Optional[int] = None) -> float: + """Sample latency for this request using the model's config. + Very inefficient to generate each sample singly rather than in batch. + """ + if seed is not None: + np.random.seed(seed) + + # Sample from the normal distribution using model config + latency_seconds = np.random.normal( + loc=config.latency_mean_seconds, scale=config.latency_std_seconds, size=1 + ) + + # Truncate the distribution's support using min and max config values + latency_seconds = np.clip( + latency_seconds, + a_min=config.latency_min_seconds, + a_max=config.latency_max_seconds, + ) + return float(latency_seconds[0]) + + +def is_refusal(config: AppModelConfig, seed: Optional[int] = None) -> bool: + """Check if the model should return a refusal. + Very inefficient to generate each sample singly rather than in batch. + """ + if seed is not None: + np.random.seed(seed) + + refusal = np.random.binomial(n=1, p=config.refusal_probability, size=1) + return bool(refusal[0]) diff --git a/nemoguardrails/benchmark/mock_llm_server/run_server.py b/nemoguardrails/benchmark/mock_llm_server/run_server.py new file mode 100644 index 000000000..0d05756d2 --- /dev/null +++ b/nemoguardrails/benchmark/mock_llm_server/run_server.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Startup script for the Mock LLM Server. + +This script starts the FastAPI server with configurable host and port settings. +""" + +import argparse +import sys + +import uvicorn + +from nemoguardrails.benchmark.mock_llm_server.config import get_config, load_config + + +def main(): + parser = argparse.ArgumentParser(description="Run the Mock LLM Server") + parser.add_argument( + "--host", + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind the server to (default: 8000)", + ) + parser.add_argument( + "--reload", action="store_true", help="Enable auto-reload for development" + ) + parser.add_argument( + "--log-level", + default="info", + choices=["critical", "error", "warning", "info", "debug", "trace"], + help="Log level (default: info)", + ) + + parser.add_argument( + "--config-file", help="YAML file to configure model", required=True + ) + + args = parser.parse_args() + + # Load model configuration + load_config(args.config_file) + model_config = get_config() + + # Import the app after configuration is loaded. This caches the values in the app Dependencies + from nemoguardrails.benchmark.mock_llm_server.api import app + + print(f"Starting Mock LLM Server on {args.host}:{args.port}") + print(f"OpenAPI docs available at: http://{args.host}:{args.port}/docs") + print(f"Health check at: http://{args.host}:{args.port}/health") + print(f"Model configuration: {model_config}") + print("Press Ctrl+C to stop the server") + + try: + uvicorn.run( + app=app, + host=args.host, + port=args.port, + reload=args.reload, + log_level=args.log_level, + ) + except KeyboardInterrupt: + print("\nServer stopped by user") + except Exception as e: # pylint: disable=broad-except + print(f"Error starting server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark/mock_model_config.yaml b/tests/benchmark/mock_model_config.yaml new file mode 100644 index 000000000..384a988e5 --- /dev/null +++ b/tests/benchmark/mock_model_config.yaml @@ -0,0 +1,3 @@ +model: "mock_model" +refusal_probability: 0.01 +refusal_text: "I'm sorry, I can't help you with that request" diff --git a/tests/benchmark/test_mock_llm_server.py b/tests/benchmark/test_mock_llm_server.py new file mode 100644 index 000000000..552eb57e1 --- /dev/null +++ b/tests/benchmark/test_mock_llm_server.py @@ -0,0 +1,614 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Unit tests for the Mock LLM FastAPI Server. + +This module contains comprehensive tests for all endpoints and edge cases +of the OpenAI-compatible mock LLM server. +""" + +import json +import os +import time +from typing import Any, Dict, List +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from nemoguardrails.benchmark.mock_llm_server.api import app +from nemoguardrails.benchmark.mock_llm_server.config import ( + AppModelConfig, + get_config, + load_config, +) +from nemoguardrails.benchmark.mock_llm_server.response_data import ( + DUMMY_CHAT_RESPONSES, + DUMMY_COMPLETION_RESPONSES, + DUMMY_MODELS, + calculate_tokens, + generate_id, + get_dummy_chat_response, + get_dummy_completion_response, +) + +RANDOM_SEED = 12345 +REFUSAL_TEXT = "I'm sorry Dave, I'm afraid I can't do that" +NO_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=0.0, +) + +ALL_REFUSAL_CONFIG = AppModelConfig( + model="mock-model", + refusal_text=REFUSAL_TEXT, + refusal_probability=1.0, +) + + +class TestMockLLMServer: + """Test class for the Mock LLM Server.""" + + @pytest.fixture + def client(self): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + @pytest.fixture + def valid_chat_request(self): + """Sample valid chat completion request.""" + return { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "max_tokens": 50, + "temperature": 0.7, + } + + @pytest.fixture + def valid_completion_request(self): + """Sample valid text completion request.""" + return { + "model": "text-davinci-003", + "prompt": "The capital of France is", + "max_tokens": 10, + "temperature": 0.8, + } + + # Root endpoint tests + def test_root_endpoint(self, client): + """Test the root endpoint returns correct information.""" + + mock_config = AppModelConfig( + model="mock_config_model_name", + refusal_text="I'm afraid I can't do that, Dave", + ) + + def override_get_config(): + return mock_config + + app.dependency_overrides[get_config] = override_get_config + + response = client.get("/") + assert response.status_code == 200 + + data = response.json() + assert data["message"] == "Mock LLM Server" + assert data["version"] == "0.0.1" + assert "description" in data + assert "/v1/models" in data["endpoints"] + assert "/v1/chat/completions" in data["endpoints"] + assert "/v1/completions" in data["endpoints"] + assert data["model_configuration"]["model"] == mock_config.model + assert data["model_configuration"]["refusal_text"] == mock_config.refusal_text + + # Health check tests + def test_health_check(self, client): + """Test the health check endpoint.""" + response = client.get("/health") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], int) + + # Models endpoint tests + def test_list_models(self, client): + """Test the models listing endpoint.""" + response = client.get("/v1/models") + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "list" + assert isinstance(data["data"], list) + assert len(data["data"]) == len(DUMMY_MODELS) + + # Check first model structure + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + + def test_models_contain_expected_models(self, client): + """Test that all expected 
models are returned.""" + response = client.get("/v1/models") + data = response.json() + + model_ids = [model["id"] for model in data["data"]] + expected_ids = [model["id"] for model in DUMMY_MODELS] + + assert set(model_ids) == set(expected_ids) + + # Chat completions tests + def test_chat_completions_success(self, client, valid_chat_request): + """Test successful chat completion request.""" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == valid_chat_request["model"] + assert "id" in data + assert "created" in data + assert isinstance(data["created"], int) + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "message" in choice + assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + + # Check usage + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) + + def test_chat_completions_multiple_choices(self, client, valid_chat_request): + """Test chat completion with multiple choices.""" + valid_chat_request["n"] = 3 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 3 + + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i + assert choice["finish_reason"] == "stop" + + def test_chat_completions_invalid_model(self, client, valid_chat_request): + """Test chat completion with invalid model.""" + valid_chat_request["model"] = "invalid-model" + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 400 + + data = response.json() + assert "detail" in data + assert "invalid-model" in data["detail"] + assert "not found" in data["detail"] + + def test_chat_completions_empty_messages(self, client): + """Test chat completion with empty messages.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [], + } + response = client.post("/v1/chat/completions", json=request_data) + # Note: The server currently accepts empty messages and processes them + # This may be acceptable behavior for a mock server + assert response.status_code in [ + 200, + 422, + ] # Allow both success and validation error + + def test_chat_completions_invalid_message_format(self, client): + """Test chat completion with invalid message format.""" + request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"invalid": "format"}], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 422 # Validation error + + def test_chat_completions_parameter_validation(self, client, valid_chat_request): + """Test parameter validation for chat completions.""" + # Test max_tokens validation + valid_chat_request["max_tokens"] = 0 + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + # Test temperature validation + valid_chat_request["max_tokens"] = 50 + valid_chat_request["temperature"] = 3.0 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert 
response.status_code == 422 + + # Test n validation + valid_chat_request["temperature"] = 0.7 + valid_chat_request["n"] = 200 # Out of range + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 422 + + def test_chat_completions_optional_parameters(self, client): + """Test chat completion with various optional parameters.""" + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Test message"}], + "max_tokens": 100, + "temperature": 0.5, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.2, + "stop": ["\\n"], + "user": "test-user", + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + # Text completions tests + def test_completions_success(self, client, valid_completion_request): + """Test successful text completion request.""" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + assert data["model"] == valid_completion_request["model"] + assert "id" in data + assert "created" in data + + # Check choices + assert "choices" in data + assert len(data["choices"]) == 1 + choice = data["choices"][0] + assert choice["index"] == 0 + assert choice["finish_reason"] == "stop" + assert "text" in choice + assert isinstance(choice["text"], str) + + # Check usage + assert "usage" in data + usage = data["usage"] + assert "prompt_tokens" in usage + assert "completion_tokens" in usage + assert "total_tokens" in usage + + def test_completions_list_prompt(self, client): + """Test text completion with list prompt.""" + request_data = { + "model": "text-davinci-003", + "prompt": ["First prompt", "Second prompt"], + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + assert data["object"] == "text_completion" + + def test_completions_invalid_model(self, client, valid_completion_request): + """Test text completion with invalid model.""" + valid_completion_request["model"] = "non-existent-model" + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 400 + + def test_completions_multiple_choices(self, client, valid_completion_request): + """Test text completion with multiple choices.""" + valid_completion_request["n"] = 2 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert len(data["choices"]) == 2 + + def test_completions_parameter_validation(self, client, valid_completion_request): + """Test parameter validation for text completions.""" + # Test max_tokens validation + valid_completion_request["max_tokens"] = -1 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + # Test temperature validation + valid_completion_request["max_tokens"] = 10 + valid_completion_request["temperature"] = -1.0 + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 422 + + def test_completions_optional_parameters(self, client): + """Test text completion with various optional parameters.""" + request_data = { + "model": "gpt-3.5-turbo", + "prompt": "Test prompt", + "max_tokens": 50, + "temperature": 0.8, + "top_p": 0.95, + "n": 1, + "logprobs": 1, + "echo": True, + "stop": ["\\n", "."], + 
"presence_penalty": -0.5, + "frequency_penalty": 0.3, + "best_of": 2, + "user": "test-user-2", + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + # Helper function tests + def test_generate_id_default(self): + """Test ID generation with default prefix.""" + id1 = generate_id() + id2 = generate_id() + + assert id1.startswith("chatcmpl-") + assert id2.startswith("chatcmpl-") + assert id1 != id2 # Should be unique + assert len(id1) == len("chatcmpl-") + 8 # prefix + 8 hex chars + + def test_generate_id_custom_prefix(self): + """Test ID generation with custom prefix.""" + custom_id = generate_id("cmpl") + assert custom_id.startswith("cmpl-") + assert len(custom_id) == len("cmpl-") + 8 + + def test_calculate_tokens(self): + """Test token calculation function.""" + # Test basic calculation + assert calculate_tokens("") == 1 # Minimum 1 token + assert calculate_tokens("a") == 1 + assert calculate_tokens("abcd") == 1 + assert calculate_tokens("abcde") == 1 # 5 chars = 1 token (rounded down) + assert calculate_tokens("abcdefgh") == 2 # 8 chars = 2 tokens + + # Test longer text + long_text = "This is a longer text with multiple words and characters." + expected_tokens = max(1, len(long_text) // 4) + assert calculate_tokens(long_text) == expected_tokens + + def test_get_dummy_completion_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_completion_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_chat_response_refusal(self): + """Test response generation with P = 1.0 of refusal""" + response = get_dummy_chat_response(ALL_REFUSAL_CONFIG, RANDOM_SEED) + assert response == ALL_REFUSAL_CONFIG.refusal_text + + def test_get_dummy_completion_response_no_refusal(self): + """Test /completion response generation with P = 0.0 of refusal""" + response = get_dummy_completion_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_COMPLETION_RESPONSES) + + def test_get_dummy_chat_response_no_refusal(self): + """Test /chat/completion response with P = 0.0 of refusal.""" + response = get_dummy_chat_response(NO_REFUSAL_CONFIG) + assert response in set(DUMMY_CHAT_RESPONSES) + + # Edge cases and error handling + def test_missing_required_fields_chat(self, client): + """Test chat completion with missing required fields.""" + # Missing model + response = client.post("/v1/chat/completions", json={"messages": []}) + assert response.status_code == 422 + + # Missing messages + response = client.post("/v1/chat/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_missing_required_fields_completion(self, client): + """Test text completion with missing required fields.""" + # Missing model + response = client.post("/v1/completions", json={"prompt": "test"}) + assert response.status_code == 422 + + # Missing prompt + response = client.post("/v1/completions", json={"model": "gpt-3.5-turbo"}) + assert response.status_code == 422 + + def test_invalid_json(self, client): + """Test endpoints with invalid JSON.""" + response = client.post( + "/v1/chat/completions", + content="invalid json", + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + def test_empty_request_body(self, client): + """Test endpoints with empty request body.""" + response = client.post("/v1/chat/completions", json={}) + assert response.status_code == 422 + + response = client.post("/v1/completions", 
json={}) + assert response.status_code == 422 + + # Content validation tests + def test_chat_message_content_types(self, client): + """Test chat completion with different message content types.""" + # Test with multiple messages + request_data = { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ], + } + response = client.post("/v1/chat/completions", json=request_data) + assert response.status_code == 200 + + def test_response_structure_consistency(self, client, valid_chat_request): + """Test that response structure is consistent across calls.""" + response1 = client.post("/v1/chat/completions", json=valid_chat_request) + response2 = client.post("/v1/chat/completions", json=valid_chat_request) + + assert response1.status_code == 200 + assert response2.status_code == 200 + + data1 = response1.json() + data2 = response2.json() + + # Structure should be the same + assert set(data1.keys()) == set(data2.keys()) + assert data1["object"] == data2["object"] + assert data1["model"] == data2["model"] + + # IDs should be different + assert data1["id"] != data2["id"] + + def test_concurrent_requests(self, client, valid_chat_request): + """Test handling of concurrent requests.""" + import threading + import time + + results = [] + + def make_request(): + response = client.post("/v1/chat/completions", json=valid_chat_request) + results.append(response.status_code) + + # Create multiple threads + threads = [] + for _ in range(5): + thread = threading.Thread(target=make_request) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # All requests should be successful + assert all(status == 200 for status in results) + assert len(results) == 5 + + # Performance and load tests + def test_response_time_reasonable(self, client, valid_chat_request): + """Test that response times are reasonable.""" + start_time = time.time() + response = client.post("/v1/chat/completions", json=valid_chat_request) + end_time = time.time() + + assert response.status_code == 200 + assert (end_time - start_time) < 1.0 # Should respond within 1 second + + def test_large_prompt_handling(self, client): + """Test handling of large prompts.""" + large_prompt = "A" * 10000 # 10K characters + request_data = { + "model": "text-davinci-003", + "prompt": large_prompt, + "max_tokens": 10, + } + response = client.post("/v1/completions", json=request_data) + assert response.status_code == 200 + + data = response.json() + # Token calculation should handle large text + assert data["usage"]["prompt_tokens"] > 1000 + + # Mock and patch tests + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_chat_response") + def test_chat_completion_response_mocking( + self, mock_response, client, valid_chat_request + ): + """Test mocking of chat response generation.""" + expected_response = "Mocked response for testing chat completions" + mock_response.return_value = expected_response + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["message"]["content"] == expected_response + mock_response.assert_called_once() + + @patch("nemoguardrails.benchmark.mock_llm_server.api.get_dummy_completion_response") + def test_completion_response_mocking( + self, mock_response, 
client, valid_completion_request + ): + """Test mocking of completion response generation.""" + expected_response = "Mocked response to check completion responses" + mock_response.return_value = expected_response + + response = client.post("/v1/completions", json=valid_completion_request) + assert response.status_code == 200 + + data = response.json() + assert data["choices"][0]["text"] == expected_response + mock_response.assert_called_once() + + @patch("time.time") + def test_timestamp_consistency(self, mock_time, client, valid_chat_request): + """Test that timestamps are generated correctly.""" + mock_time.return_value = 1234567890 + + response = client.post("/v1/chat/completions", json=valid_chat_request) + assert response.status_code == 200 + + data = response.json() + assert data["created"] == 1234567890 + + # Documentation and OpenAPI tests + def test_openapi_docs_available(self, client): + """Test that OpenAPI documentation is available.""" + response = client.get("/docs") + assert response.status_code == 200 + + response = client.get("/openapi.json") + assert response.status_code == 200 + + openapi_data = response.json() + assert "openapi" in openapi_data + assert "paths" in openapi_data + assert "/v1/models" in openapi_data["paths"] + assert "/v1/chat/completions" in openapi_data["paths"] + assert "/v1/completions" in openapi_data["paths"] + + def test_read_root_with_mock_config(self): + """Test that `load_config` correctly populates the global `settings` variable.""" + yaml_file = os.path.join(os.path.dirname(__file__), "mock_model_config.yaml") + + # Load the config and check it's populated + load_config(yaml_file) + config = get_config() + assert config is not None + + # Now check the contents against `mock_model_config.yaml` + assert isinstance(config, AppModelConfig) + assert config.model == "mock_model" + assert config.refusal_probability == 0.01 + assert config.refusal_text == "I'm sorry, I can't help you with that request" + + @patch("nemoguardrails.benchmark.mock_llm_server.config.settings", None) + def test_get_config_raises_exception(self): + """Check that calling `get_config()` without settings loaded raises a RuntimeError.""" + with pytest.raises(RuntimeError, match="No configuration loaded"): + get_config()
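
Addendum: AppModelConfig exposes latency_min_seconds, latency_max_seconds, latency_mean_seconds and latency_std_seconds, and response_data.get_latency_seconds samples a delay from the truncated-normal distribution they describe, but neither /v1/chat/completions nor /v1/completions currently applies that delay. The sketch below shows one possible way to wire it in; it is not part of this diff, and the simulate_latency helper name is hypothetical.

import asyncio

from nemoguardrails.benchmark.mock_llm_server.config import AppModelConfig
from nemoguardrails.benchmark.mock_llm_server.response_data import get_latency_seconds


async def simulate_latency(config: AppModelConfig) -> None:
    # Sample a per-request latency from the configured truncated-normal
    # distribution and sleep for that long before returning the mock response.
    latency_seconds = get_latency_seconds(config)
    await asyncio.sleep(latency_seconds)

Each async endpoint could then call "await simulate_latency(config)" before building its response, so benchmark timings would reflect the configured latency distribution rather than FastAPI overhead alone. For manual testing, note that run_server.py requires --config-file, for example "python run_server.py --config-file configs/nvidia-llama-3.1-nemoguard-8b-content-safety.yaml" when run from the mock_llm_server directory.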