diff --git a/pyproject.toml b/pyproject.toml index eaa9a6cb..1e262ff6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,8 @@ Issues = "https://github.com/holegots/claude-code-proxy/issues" [project.scripts] claude-code-proxy = "src.main:main" -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "black>=23.0.0", diff --git a/src/api/endpoints.py b/src/api/endpoints.py index 5d03a8a0..5c8a68ea 100644 --- a/src/api/endpoints.py +++ b/src/api/endpoints.py @@ -7,7 +7,11 @@ from src.core.config import config from src.core.logging import logger from src.core.client import OpenAIClient -from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest +from src.models.claude import ( + ClaudeMessagesRequest, + ClaudeTokenCountRequest, + EventLoggingBatchResponse, +) from src.conversion.request_converter import convert_claude_to_openai from src.conversion.response_converter import ( convert_openai_to_claude_response, @@ -28,34 +32,37 @@ custom_headers=custom_headers, ) -async def validate_api_key(x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None)): + +async def validate_api_key( + x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None) +): """Validate the client's API key from either x-api-key header or Authorization header.""" client_api_key = None - + # Extract API key from headers if x_api_key: client_api_key = x_api_key elif authorization and authorization.startswith("Bearer "): client_api_key = authorization.replace("Bearer ", "") - + # Skip validation if ANTHROPIC_API_KEY is not set in the environment if not config.anthropic_api_key: return - + # Validate the client API key if not client_api_key or not config.validate_client_api_key(client_api_key): logger.warning(f"Invalid API key provided by client") raise HTTPException( - status_code=401, - detail="Invalid API key. Please provide a valid Anthropic API key." + status_code=401, detail="Invalid API key. Please provide a valid Anthropic API key." ) + @router.post("/v1/messages") -async def create_message(request: ClaudeMessagesRequest, http_request: Request, _: None = Depends(validate_api_key)): +async def create_message( + request: ClaudeMessagesRequest, http_request: Request, _: None = Depends(validate_api_key) +): try: - logger.debug( - f"Processing Claude request: model={request.model}, stream={request.stream}" - ) + logger.debug(f"Processing Claude request: model={request.model}, stream={request.stream}") # Generate unique request ID for cancellation tracking request_id = str(uuid.uuid4()) @@ -104,12 +111,8 @@ async def create_message(request: ClaudeMessagesRequest, http_request: Request, return JSONResponse(status_code=e.status_code, content=error_response) else: # Non-streaming response - openai_response = await openai_client.create_chat_completion( - openai_request, request_id - ) - claude_response = convert_openai_to_claude_response( - openai_response, request - ) + openai_response = await openai_client.create_chat_completion(openai_request, request_id) + claude_response = convert_openai_to_claude_response(openai_response, request) return claude_response except HTTPException: raise @@ -160,6 +163,51 @@ async def count_tokens(request: ClaudeTokenCountRequest, _: None = Depends(valid raise HTTPException(status_code=500, detail=str(e)) +@router.post("/api/event_logging/batch") +async def event_logging_batch(request: Request) -> EventLoggingBatchResponse: + try: + try: + body = await request.json() + except ValueError: + body = {} + + batch_id = body.get("batch_id") or f"batch_{uuid.uuid4().hex[:16]}" + events = body.get("events", []) + event_items = events if isinstance(events, list) else [events] if events else [] + processed_count = len(event_items) + + if event_items: + logger.debug( + f"Event logging batch received: batch_id={batch_id}, " + f"events_count={processed_count}" + ) + for index, event in enumerate(event_items[:5], start=1): + event_type = ( + event.get("event_type", "unknown") if isinstance(event, dict) else "unknown" + ) + logger.debug(f"Event {index}: type={event_type}") + + remaining_count = processed_count - 5 + if remaining_count > 0: + logger.debug(f"Event logging batch has {remaining_count} additional events") + + return EventLoggingBatchResponse( + success=True, + batch_id=batch_id, + processed_count=processed_count, + message="Events logged successfully", + ) + + except Exception as e: + logger.warning(f"Event logging batch error: {e}") + return EventLoggingBatchResponse( + success=True, + batch_id=f"batch_{uuid.uuid4().hex[:16]}", + processed_count=0, + message="Events received", + ) + + @router.get("/health") async def health_check(): """Health check endpoint""" @@ -228,6 +276,7 @@ async def root(): "endpoints": { "messages": "/v1/messages", "count_tokens": "/v1/messages/count_tokens", + "event_logging_batch": "/api/event_logging/batch", "health": "/health", "test_connection": "/test-connection", }, diff --git a/src/conversion/request_converter.py b/src/conversion/request_converter.py index f341c709..3103f4ae 100644 --- a/src/conversion/request_converter.py +++ b/src/conversion/request_converter.py @@ -1,6 +1,5 @@ import json from typing import Dict, Any, List -from venv import logger from src.core.constants import Constants from src.models.claude import ClaudeMessagesRequest, ClaudeMessage from src.core.config import config @@ -30,17 +29,12 @@ def convert_claude_to_openai( for block in claude_request.system: if hasattr(block, "type") and block.type == Constants.CONTENT_TEXT: text_parts.append(block.text) - elif ( - isinstance(block, dict) - and block.get("type") == Constants.CONTENT_TEXT - ): + elif isinstance(block, dict) and block.get("type") == Constants.CONTENT_TEXT: text_parts.append(block.get("text", "")) system_text = "\n\n".join(text_parts) if system_text.strip(): - openai_messages.append( - {"role": Constants.ROLE_SYSTEM, "content": system_text.strip()} - ) + openai_messages.append({"role": Constants.ROLE_SYSTEM, "content": system_text.strip()}) # Process Claude messages i = 0 @@ -133,7 +127,7 @@ def convert_claude_user_message(msg: ClaudeMessage) -> Dict[str, Any]: """Convert Claude user message to OpenAI format.""" if msg.content is None: return {"role": Constants.ROLE_USER, "content": ""} - + if isinstance(msg.content, str): return {"role": Constants.ROLE_USER, "content": msg.content} @@ -172,7 +166,7 @@ def convert_claude_assistant_message(msg: ClaudeMessage) -> Dict[str, Any]: if msg.content is None: return {"role": Constants.ROLE_ASSISTANT, "content": None} - + if isinstance(msg.content, str): return {"role": Constants.ROLE_ASSISTANT, "content": msg.content} @@ -246,7 +240,7 @@ def parse_tool_result_content(content): else: try: result_parts.append(json.dumps(item, ensure_ascii=False)) - except: + except (TypeError, ValueError): result_parts.append(str(item)) return "\n".join(result_parts).strip() @@ -255,10 +249,10 @@ def parse_tool_result_content(content): return content.get("text", "") try: return json.dumps(content, ensure_ascii=False) - except: + except (TypeError, ValueError): return str(content) try: return str(content) - except: + except Exception: return "Unparseable content" diff --git a/src/conversion/response_converter.py b/src/conversion/response_converter.py index 2980c37d..ba7b5552 100644 --- a/src/conversion/response_converter.py +++ b/src/conversion/response_converter.py @@ -69,9 +69,7 @@ def convert_openai_to_claude_response( "stop_sequence": None, "usage": { "input_tokens": openai_response.get("usage", {}).get("prompt_tokens", 0), - "output_tokens": openai_response.get("usage", {}).get( - "completion_tokens", 0 - ), + "output_tokens": openai_response.get("usage", {}).get("completion_tokens", 0), }, } @@ -112,9 +110,7 @@ async def convert_openai_streaming_to_claude( if not choices: continue except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse chunk: {chunk_data}, error: {e}" - ) + logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}") continue choice = choices[0] @@ -126,10 +122,10 @@ async def convert_openai_streaming_to_claude( yield f"event: {Constants.EVENT_CONTENT_BLOCK_DELTA}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_DELTA, 'index': text_block_index, 'delta': {'type': Constants.DELTA_TEXT, 'text': delta['content']}}, ensure_ascii=False)}\n\n" # Handle tool call deltas with improved incremental processing - if "tool_calls" in delta: + if "tool_calls" in delta and delta["tool_calls"]: for tc_delta in delta["tool_calls"]: tc_index = tc_delta.get("index", 0) - + # Initialize tool call tracking by index if not exists if tc_index not in current_tool_calls: current_tool_calls[tc_index] = { @@ -138,33 +134,37 @@ async def convert_openai_streaming_to_claude( "args_buffer": "", "json_sent": False, "claude_index": None, - "started": False + "started": False, } - + tool_call = current_tool_calls[tc_index] - + # Update tool call ID if provided if tc_delta.get("id"): tool_call["id"] = tc_delta["id"] - + # Update function name and start content block if we have both id and name function_data = tc_delta.get(Constants.TOOL_FUNCTION, {}) if function_data.get("name"): tool_call["name"] = function_data["name"] - + # Start content block when we have complete initial data - if (tool_call["id"] and tool_call["name"] and not tool_call["started"]): + if tool_call["id"] and tool_call["name"] and not tool_call["started"]: tool_block_counter += 1 claude_index = text_block_index + tool_block_counter tool_call["claude_index"] = claude_index tool_call["started"] = True - + yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n" - + # Handle function arguments - if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None: + if ( + "arguments" in function_data + and tool_call["started"] + and function_data["arguments"] is not None + ): tool_call["args_buffer"] += function_data["arguments"] - + # Try to parse complete JSON and send delta when we have valid JSON try: json.loads(tool_call["args_buffer"]) @@ -259,21 +259,21 @@ async def convert_openai_streaming_to_claude_with_cancellation( usage = chunk.get("usage", None) if usage: cache_read_input_tokens = 0 - prompt_tokens_details = usage.get('prompt_tokens_details', {}) + prompt_tokens_details = usage.get("prompt_tokens_details", {}) if prompt_tokens_details: - cache_read_input_tokens = prompt_tokens_details.get('cached_tokens', 0) + cache_read_input_tokens = prompt_tokens_details.get( + "cached_tokens", 0 + ) usage_data = { - 'input_tokens': usage.get('prompt_tokens', 0), - 'output_tokens': usage.get('completion_tokens', 0), - 'cache_read_input_tokens': cache_read_input_tokens + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + "cache_read_input_tokens": cache_read_input_tokens, } choices = chunk.get("choices", []) if not choices: continue except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse chunk: {chunk_data}, error: {e}" - ) + logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}") continue choice = choices[0] @@ -288,7 +288,7 @@ async def convert_openai_streaming_to_claude_with_cancellation( if "tool_calls" in delta and delta["tool_calls"]: for tc_delta in delta["tool_calls"]: tc_index = tc_delta.get("index", 0) - + # Initialize tool call tracking by index if not exists if tc_index not in current_tool_calls: current_tool_calls[tc_index] = { @@ -297,33 +297,37 @@ async def convert_openai_streaming_to_claude_with_cancellation( "args_buffer": "", "json_sent": False, "claude_index": None, - "started": False + "started": False, } - + tool_call = current_tool_calls[tc_index] - + # Update tool call ID if provided if tc_delta.get("id"): tool_call["id"] = tc_delta["id"] - + # Update function name and start content block if we have both id and name function_data = tc_delta.get(Constants.TOOL_FUNCTION, {}) if function_data.get("name"): tool_call["name"] = function_data["name"] - + # Start content block when we have complete initial data - if (tool_call["id"] and tool_call["name"] and not tool_call["started"]): + if tool_call["id"] and tool_call["name"] and not tool_call["started"]: tool_block_counter += 1 claude_index = text_block_index + tool_block_counter tool_call["claude_index"] = claude_index tool_call["started"] = True - + yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n" - + # Handle function arguments - if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None: + if ( + "arguments" in function_data + and tool_call["started"] + and function_data["arguments"] is not None + ): tool_call["args_buffer"] += function_data["arguments"] - + # Try to parse complete JSON and send delta when we have valid JSON try: json.loads(tool_call["args_buffer"]) @@ -360,7 +364,16 @@ async def convert_openai_streaming_to_claude_with_cancellation( yield f"event: error\ndata: {json.dumps(error_event, ensure_ascii=False)}\n\n" return else: - raise + logger.error(f"HTTP error in streaming: {e.status_code} - {e.detail}") + error_event = { + "type": "error", + "error": { + "type": "api_error", + "message": f"HTTP {e.status_code}: {e.detail}", + }, + } + yield f"event: error\ndata: {json.dumps(error_event, ensure_ascii=False)}\n\n" + return except Exception as e: # Handle any streaming errors gracefully logger.error(f"Streaming error: {e}") diff --git a/src/core/client.py b/src/core/client.py index 73aeafde..7f3446d0 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -3,26 +3,30 @@ from fastapi import HTTPException from typing import Optional, AsyncGenerator, Dict, Any from openai import AsyncOpenAI, AsyncAzureOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai._exceptions import APIError, RateLimitError, AuthenticationError, BadRequestError + class OpenAIClient: """Async OpenAI client with cancellation support.""" - - def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None, custom_headers: Optional[Dict[str, str]] = None): + + def __init__( + self, + api_key: str, + base_url: str, + timeout: int = 90, + api_version: Optional[str] = None, + custom_headers: Optional[Dict[str, str]] = None, + ): self.api_key = api_key self.base_url = base_url self.custom_headers = custom_headers or {} - + # Prepare default headers - default_headers = { - "Content-Type": "application/json", - "User-Agent": "claude-proxy/1.0.0" - } - + default_headers = {"Content-Type": "application/json", "User-Agent": "claude-proxy/1.0.0"} + # Merge custom headers with default headers all_headers = {**default_headers, **self.custom_headers} - + # Detect if using Azure and instantiate the appropriate client if api_version: self.client = AsyncAzureOpenAI( @@ -30,39 +34,35 @@ def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: azure_endpoint=base_url, api_version=api_version, timeout=timeout, - default_headers=all_headers + default_headers=all_headers, ) else: self.client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - timeout=timeout, - default_headers=all_headers + api_key=api_key, base_url=base_url, timeout=timeout, default_headers=all_headers ) self.active_requests: Dict[str, asyncio.Event] = {} - - async def create_chat_completion(self, request: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]: + + async def create_chat_completion( + self, request: Dict[str, Any], request_id: Optional[str] = None + ) -> Dict[str, Any]: """Send chat completion to OpenAI API with cancellation support.""" - + # Create cancellation token if request_id provided if request_id: cancel_event = asyncio.Event() self.active_requests[request_id] = cancel_event - + try: # Create task that can be cancelled - completion_task = asyncio.create_task( - self.client.chat.completions.create(**request) - ) - + completion_task = asyncio.create_task(self.client.chat.completions.create(**request)) + if request_id: # Wait for either completion or cancellation cancel_task = asyncio.create_task(cancel_event.wait()) done, pending = await asyncio.wait( - [completion_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED + [completion_task, cancel_task], return_when=asyncio.FIRST_COMPLETED ) - + # Cancel pending tasks for task in pending: task.cancel() @@ -70,19 +70,19 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti await task except asyncio.CancelledError: pass - + # Check if request was cancelled if cancel_task in done: completion_task.cancel() raise HTTPException(status_code=499, detail="Request cancelled by client") - + completion = await completion_task else: completion = await completion_task - + # Convert to dict format that matches the original interface return completion.model_dump() - + except AuthenticationError as e: raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) except RateLimitError as e: @@ -90,48 +90,50 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti except BadRequestError as e: raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) except APIError as e: - status_code = getattr(e, 'status_code', 500) + status_code = getattr(e, "status_code", 500) raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") - + finally: # Clean up active request tracking if request_id and request_id in self.active_requests: del self.active_requests[request_id] - - async def create_chat_completion_stream(self, request: Dict[str, Any], request_id: Optional[str] = None) -> AsyncGenerator[str, None]: + + async def create_chat_completion_stream( + self, request: Dict[str, Any], request_id: Optional[str] = None + ) -> AsyncGenerator[str, None]: """Send streaming chat completion to OpenAI API with cancellation support.""" - + # Create cancellation token if request_id provided if request_id: cancel_event = asyncio.Event() self.active_requests[request_id] = cancel_event - + try: # Ensure stream is enabled request["stream"] = True if "stream_options" not in request: request["stream_options"] = {} request["stream_options"]["include_usage"] = True - + # Create the streaming completion streaming_completion = await self.client.chat.completions.create(**request) - + async for chunk in streaming_completion: # Check for cancellation before yielding each chunk if request_id and request_id in self.active_requests: if self.active_requests[request_id].is_set(): raise HTTPException(status_code=499, detail="Request cancelled by client") - + # Convert chunk to SSE format matching original HTTP client format chunk_dict = chunk.model_dump() chunk_json = json.dumps(chunk_dict, ensure_ascii=False) yield f"data: {chunk_json}" - + # Signal end of stream yield "data: [DONE]" - + except AuthenticationError as e: raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) except RateLimitError as e: @@ -139,11 +141,11 @@ async def create_chat_completion_stream(self, request: Dict[str, Any], request_i except BadRequestError as e: raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) except APIError as e: - status_code = getattr(e, 'status_code', 500) + status_code = getattr(e, "status_code", 500) raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") - + finally: # Clean up active request tracking if request_id and request_id in self.active_requests: @@ -152,33 +154,36 @@ async def create_chat_completion_stream(self, request: Dict[str, Any], request_i def classify_openai_error(self, error_detail: Any) -> str: """Provide specific error guidance for common OpenAI API issues.""" error_str = str(error_detail).lower() - + # Region/country restrictions - if "unsupported_country_region_territory" in error_str or "country, region, or territory not supported" in error_str: + if ( + "unsupported_country_region_territory" in error_str + or "country, region, or territory not supported" in error_str + ): return "OpenAI API is not available in your region. Consider using a VPN or Azure OpenAI service." - + # API key issues if "invalid_api_key" in error_str or "unauthorized" in error_str: return "Invalid API key. Please check your OPENAI_API_KEY configuration." - + # Rate limiting if "rate_limit" in error_str or "quota" in error_str: return "Rate limit exceeded. Please wait and try again, or upgrade your API plan." - + # Model not found if "model" in error_str and ("not found" in error_str or "does not exist" in error_str): return "Model not found. Please check your BIG_MODEL and SMALL_MODEL configuration." - + # Billing issues if "billing" in error_str or "payment" in error_str: return "Billing issue. Please check your OpenAI account billing status." - + # Default: return original message return str(error_detail) - + def cancel_request(self, request_id: str) -> bool: """Cancel an active request by request_id.""" if request_id in self.active_requests: self.active_requests[request_id].set() return True - return False \ No newline at end of file + return False diff --git a/src/core/config.py b/src/core/config.py index 71bb7c95..553f1fdc 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,18 +1,19 @@ import os import sys + # Configuration class Config: def __init__(self): self.openai_api_key = os.environ.get("OPENAI_API_KEY") if not self.openai_api_key: raise ValueError("OPENAI_API_KEY not found in environment variables") - + # Add Anthropic API key for client validation self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") if not self.anthropic_api_key: print("Warning: ANTHROPIC_API_KEY not set. Client API key validation will be disabled.") - + self.openai_base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") self.azure_api_version = os.environ.get("AZURE_API_VERSION") # For Azure OpenAI self.host = os.environ.get("HOST", "0.0.0.0") @@ -20,55 +21,56 @@ def __init__(self): self.log_level = os.environ.get("LOG_LEVEL", "INFO") self.max_tokens_limit = int(os.environ.get("MAX_TOKENS_LIMIT", "4096")) self.min_tokens_limit = int(os.environ.get("MIN_TOKENS_LIMIT", "100")) - + # Connection settings self.request_timeout = int(os.environ.get("REQUEST_TIMEOUT", "90")) self.max_retries = int(os.environ.get("MAX_RETRIES", "2")) - + # Model settings - BIG and SMALL models self.big_model = os.environ.get("BIG_MODEL", "gpt-4o") self.middle_model = os.environ.get("MIDDLE_MODEL", self.big_model) self.small_model = os.environ.get("SMALL_MODEL", "gpt-4o-mini") - + def validate_api_key(self): """Basic API key validation""" if not self.openai_api_key: return False # Basic format check for OpenAI API keys - if not self.openai_api_key.startswith('sk-'): + if not self.openai_api_key.startswith("sk-"): return False return True - + def validate_client_api_key(self, client_api_key): """Validate client's Anthropic API key""" # If no ANTHROPIC_API_KEY is set in environment, skip validation if not self.anthropic_api_key: return True - + # Check if the client's API key matches the expected value return client_api_key == self.anthropic_api_key - + def get_custom_headers(self): """Get custom headers from environment variables""" custom_headers = {} - + # Get all environment variables env_vars = dict(os.environ) - + # Find CUSTOM_HEADER_* environment variables for env_key, env_value in env_vars.items(): - if env_key.startswith('CUSTOM_HEADER_'): + if env_key.startswith("CUSTOM_HEADER_"): # Convert CUSTOM_HEADER_KEY to Header-Key # Remove 'CUSTOM_HEADER_' prefix and convert to header format header_name = env_key[14:] # Remove 'CUSTOM_HEADER_' prefix - + if header_name: # Make sure it's not empty # Convert underscores to hyphens for HTTP header format - header_name = header_name.replace('_', '-') + header_name = header_name.replace("_", "-") custom_headers[header_name] = env_value - + return custom_headers + try: config = Config() print(f" Configuration loaded: API_KEY={'*' * 20}..., BASE_URL='{config.openai_base_url}'") diff --git a/src/core/constants.py b/src/core/constants.py index 737f557c..589895ff 100644 --- a/src/core/constants.py +++ b/src/core/constants.py @@ -1,22 +1,22 @@ -# Constants for better maintainability +# Constants for better maintainability class Constants: ROLE_USER = "user" ROLE_ASSISTANT = "assistant" ROLE_SYSTEM = "system" ROLE_TOOL = "tool" - + CONTENT_TEXT = "text" CONTENT_IMAGE = "image" CONTENT_TOOL_USE = "tool_use" CONTENT_TOOL_RESULT = "tool_result" - + TOOL_FUNCTION = "function" - + STOP_END_TURN = "end_turn" STOP_MAX_TOKENS = "max_tokens" STOP_TOOL_USE = "tool_use" STOP_ERROR = "error" - + EVENT_MESSAGE_START = "message_start" EVENT_MESSAGE_STOP = "message_stop" EVENT_MESSAGE_DELTA = "message_delta" @@ -24,6 +24,6 @@ class Constants: EVENT_CONTENT_BLOCK_STOP = "content_block_stop" EVENT_CONTENT_BLOCK_DELTA = "content_block_delta" EVENT_PING = "ping" - + DELTA_TEXT = "text_delta" - DELTA_INPUT_JSON = "input_json_delta" \ No newline at end of file + DELTA_INPUT_JSON = "input_json_delta" diff --git a/src/core/logging.py b/src/core/logging.py index 87376bb9..3de386a1 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -5,17 +5,17 @@ log_level = config.log_level.split()[0].upper() # Validate and set default if invalid -valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] if log_level not in valid_levels: - log_level = 'INFO' + log_level = "INFO" # Logging Configuration logging.basicConfig( level=getattr(logging, log_level), - format='%(asctime)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Configure uvicorn to be quieter for uvicorn_logger in ["uvicorn", "uvicorn.access", "uvicorn.error"]: - logging.getLogger(uvicorn_logger).setLevel(logging.WARNING) \ No newline at end of file + logging.getLogger(uvicorn_logger).setLevel(logging.WARNING) diff --git a/src/core/model_manager.py b/src/core/model_manager.py index 5495f317..5202b192 100644 --- a/src/core/model_manager.py +++ b/src/core/model_manager.py @@ -1,9 +1,10 @@ from src.core.config import config + class ModelManager: def __init__(self, config): self.config = config - + def map_claude_model_to_openai(self, claude_model: str) -> str: """Map Claude model names to OpenAI model names based on BIG/SMALL pattern""" # If it's already an OpenAI model, return as-is @@ -11,20 +12,24 @@ def map_claude_model_to_openai(self, claude_model: str) -> str: return claude_model # If it's other supported models (ARK/Doubao/DeepSeek), return as-is - if (claude_model.startswith("ep-") or claude_model.startswith("doubao-") or - claude_model.startswith("deepseek-")): + if ( + claude_model.startswith("ep-") + or claude_model.startswith("doubao-") + or claude_model.startswith("deepseek-") + ): return claude_model - + # Map based on model naming patterns model_lower = claude_model.lower() - if 'haiku' in model_lower: + if "haiku" in model_lower: return self.config.small_model - elif 'sonnet' in model_lower: + elif "sonnet" in model_lower: return self.config.middle_model - elif 'opus' in model_lower: + elif "opus" in model_lower: return self.config.big_model else: # Default to big model for unknown models return self.config.big_model -model_manager = ModelManager(config) \ No newline at end of file + +model_manager = ModelManager(config) diff --git a/src/main.py b/src/main.py index 8f1e0f34..b52a118e 100644 --- a/src/main.py +++ b/src/main.py @@ -21,9 +21,7 @@ def main(): print("Optional environment variables:") print(" ANTHROPIC_API_KEY - Expected Anthropic API key for client validation") print(" If set, clients must provide this exact API key") - print( - f" OPENAI_BASE_URL - OpenAI API base URL (default: https://api.openai.com/v1)" - ) + print(f" OPENAI_BASE_URL - OpenAI API base URL (default: https://api.openai.com/v1)") print(f" BIG_MODEL - Model for opus requests (default: gpt-4o)") print(f" MIDDLE_MODEL - Model for sonnet requests (default: gpt-4o)") print(f" SMALL_MODEL - Model for haiku requests (default: gpt-4o-mini)") @@ -54,11 +52,11 @@ def main(): # Parse log level - extract just the first word to handle comments log_level = config.log_level.split()[0].lower() - + # Validate and set default if invalid - valid_levels = ['debug', 'info', 'warning', 'error', 'critical'] + valid_levels = ["debug", "info", "warning", "error", "critical"] if log_level not in valid_levels: - log_level = 'info' + log_level = "info" # Start server uvicorn.run( diff --git a/src/models/claude.py b/src/models/claude.py index 91caff0c..5f844cc9 100644 --- a/src/models/claude.py +++ b/src/models/claude.py @@ -1,41 +1,60 @@ from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional, Union, Literal + class ClaudeContentBlockText(BaseModel): type: Literal["text"] text: str + class ClaudeContentBlockImage(BaseModel): type: Literal["image"] source: Dict[str, Any] + class ClaudeContentBlockToolUse(BaseModel): type: Literal["tool_use"] id: str name: str input: Dict[str, Any] + class ClaudeContentBlockToolResult(BaseModel): type: Literal["tool_result"] tool_use_id: str content: Union[str, List[Dict[str, Any]], Dict[str, Any]] + class ClaudeSystemContent(BaseModel): type: Literal["text"] text: str + class ClaudeMessage(BaseModel): role: Literal["user", "assistant"] - content: Union[str, List[Union[ClaudeContentBlockText, ClaudeContentBlockImage, ClaudeContentBlockToolUse, ClaudeContentBlockToolResult]]] + content: Union[ + str, + List[ + Union[ + ClaudeContentBlockText, + ClaudeContentBlockImage, + ClaudeContentBlockToolUse, + ClaudeContentBlockToolResult, + ] + ], + ] + class ClaudeTool(BaseModel): name: str description: Optional[str] = None input_schema: Dict[str, Any] + class ClaudeThinkingConfig(BaseModel): enabled: bool = True + class ClaudeMessagesRequest(BaseModel): model: str max_tokens: int @@ -51,6 +70,7 @@ class ClaudeMessagesRequest(BaseModel): tool_choice: Optional[Dict[str, Any]] = None thinking: Optional[ClaudeThinkingConfig] = None + class ClaudeTokenCountRequest(BaseModel): model: str messages: List[ClaudeMessage] @@ -58,3 +78,10 @@ class ClaudeTokenCountRequest(BaseModel): tools: Optional[List[ClaudeTool]] = None thinking: Optional[ClaudeThinkingConfig] = None tool_choice: Optional[Dict[str, Any]] = None + + +class EventLoggingBatchResponse(BaseModel): + success: bool + batch_id: str + processed_count: int + message: str diff --git a/start_proxy.py b/start_proxy.py index 713708cd..5f3ae85c 100644 --- a/start_proxy.py +++ b/start_proxy.py @@ -5,9 +5,9 @@ import os # Add src to Python path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) from src.main import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test_cancellation.py b/test_cancellation.py index f4d7da7c..01808728 100644 --- a/test_cancellation.py +++ b/test_cancellation.py @@ -6,13 +6,12 @@ import asyncio import httpx -import json -import time + async def test_non_streaming_cancellation(): """Test cancellation for non-streaming requests.""" print("๐Ÿงช Testing non-streaming request cancellation...") - + async with httpx.AsyncClient(timeout=30) as client: try: # Start a long-running request @@ -23,29 +22,33 @@ async def test_non_streaming_cancellation(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 1000, "messages": [ - {"role": "user", "content": "Write a very long story about a journey through space that takes at least 500 words."} - ] - } + { + "role": "user", + "content": "Write a very long story about a journey through space that takes at least 500 words.", + } + ], + }, ) ) - + # Cancel after 2 seconds await asyncio.sleep(2) task.cancel() - + try: await task print("โŒ Request should have been cancelled") except asyncio.CancelledError: print("โœ… Non-streaming request cancelled successfully") - + except Exception as e: print(f"โŒ Non-streaming test error: {e}") + async def test_streaming_cancellation(): """Test cancellation for streaming requests.""" print("\n๐Ÿงช Testing streaming request cancellation...") - + async with httpx.AsyncClient(timeout=30) as client: try: # Start streaming request @@ -56,37 +59,41 @@ async def test_streaming_cancellation(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 1000, "messages": [ - {"role": "user", "content": "Write a very long story about a journey through space that takes at least 500 words."} + { + "role": "user", + "content": "Write a very long story about a journey through space that takes at least 500 words.", + } ], - "stream": True - } + "stream": True, + }, ) as response: if response.status_code == 200: print("โœ… Streaming request started successfully") - + # Read a few chunks then simulate client disconnect chunk_count = 0 async for line in response.aiter_lines(): if line.strip(): chunk_count += 1 print(f"๐Ÿ“ฆ Received chunk {chunk_count}: {line[:100]}...") - + # Simulate client disconnect after 3 chunks if chunk_count >= 3: print("๐Ÿ”Œ Simulating client disconnect...") break - + print("โœ… Streaming request cancelled successfully") else: print(f"โŒ Streaming request failed: {response.status_code}") - + except Exception as e: print(f"โŒ Streaming test error: {e}") + async def test_server_running(): """Test if the server is running.""" print("๐Ÿ” Checking if server is running...") - + try: async with httpx.AsyncClient(timeout=5) as client: response = await client.get("http://localhost:8082/health") @@ -101,23 +108,24 @@ async def test_server_running(): print("๐Ÿ’ก Make sure to start the server with: python start_proxy.py") return False + async def main(): """Main test function.""" print("๐Ÿš€ Starting HTTP request cancellation tests") print("=" * 50) - + # Check if server is running if not await test_server_running(): return - + print("\n" + "=" * 50) - + # Test non-streaming cancellation await test_non_streaming_cancellation() - - # Test streaming cancellation + + # Test streaming cancellation await test_streaming_cancellation() - + print("\n" + "=" * 50) print("โœ… All cancellation tests completed!") print("\n๐Ÿ’ก Note: The actual cancellation behavior depends on:") @@ -126,5 +134,6 @@ async def main(): print(" - Server response to client disconnection") print(" - Whether the underlying OpenAI API supports cancellation") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/test_main.py b/tests/test_main.py index 3f8212db..f0046717 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,11 +2,71 @@ import asyncio import json + import httpx +import pytest from dotenv import load_dotenv load_dotenv() +pytestmark = pytest.mark.asyncio + + +async def test_event_batch_list(): + """Test batch logging of event list.""" + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8082/api/event_logging/batch", + json={ + "batch_id": "batch_test_123", + "events": [ + {"event_type": "message_start"}, + {"event_type": "message_stop"}, + ], + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "success": True, + "batch_id": "batch_test_123", + "processed_count": 2, + "message": "Events logged successfully", + } + + +async def test_event_batch_single(): + """Test batch wrapping of single event.""" + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8082/api/event_logging/batch", + json={"events": {"event_type": "message_delta"}}, + ) + + body = response.json() + assert response.status_code == 200 + assert body["success"] is True + assert body["batch_id"].startswith("batch_") + assert body["processed_count"] == 1 + assert body["message"] == "Events logged successfully" + + +async def test_event_batch_invalid_json(): + """Test batch handling of invalid JSON.""" + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8082/api/event_logging/batch", + content="not-json", + headers={"content-type": "application/json"}, + ) + + body = response.json() + assert response.status_code == 200 + assert body["success"] is True + assert body["batch_id"].startswith("batch_") + assert body["processed_count"] == 0 + assert body["message"] == "Events logged successfully" + async def test_basic_chat(): """Test basic chat completion.""" @@ -16,12 +76,10 @@ async def test_basic_chat(): json={ "model": "claude-3-5-sonnet-20241022", "max_tokens": 100, - "messages": [ - {"role": "user", "content": "Hello, how are you?"} - ] - } + "messages": [{"role": "user", "content": "Hello, how are you?"}], + }, ) - + print("Basic chat response:") print(json.dumps(response.json(), indent=2)) @@ -35,11 +93,9 @@ async def test_streaming_chat(): json={ "model": "claude-3-5-haiku-20241022", "max_tokens": 150, - "messages": [ - {"role": "user", "content": "Tell me a short joke"} - ], - "stream": True - } + "messages": [{"role": "user", "content": "Tell me a short joke"}], + "stream": True, + }, ) as response: print("\nStreaming response:") async for line in response.aiter_lines(): @@ -56,7 +112,10 @@ async def test_function_calling(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 200, "messages": [ - {"role": "user", "content": "What's the weather like in New York? Please use the weather function."} + { + "role": "user", + "content": "What's the weather like in New York? Please use the weather function.", + } ], "tools": [ { @@ -67,22 +126,22 @@ async def test_function_calling(): "properties": { "location": { "type": "string", - "description": "The location to get weather for" + "description": "The location to get weather for", }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"], - "description": "Temperature unit" - } + "description": "Temperature unit", + }, }, - "required": ["location"] - } + "required": ["location"], + }, } ], - "tool_choice": {"type": "auto"} - } + "tool_choice": {"type": "auto"}, + }, ) - + print("\nFunction calling response:") print(json.dumps(response.json(), indent=2)) @@ -96,12 +155,10 @@ async def test_with_system_message(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 100, "system": "You are a helpful assistant that always responds in haiku format.", - "messages": [ - {"role": "user", "content": "Explain what AI is"} - ] - } + "messages": [{"role": "user", "content": "Explain what AI is"}], + }, ) - + print("\nSystem message response:") print(json.dumps(response.json(), indent=2)) @@ -111,7 +168,7 @@ async def test_multimodal(): async with httpx.AsyncClient() as client: # Sample base64 image (1x1 pixel transparent PNG) sample_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU8PJAAAAASUVORK5CYII=" - + response = await client.post( "http://localhost:8082/v1/messages", json={ @@ -127,15 +184,15 @@ async def test_multimodal(): "source": { "type": "base64", "media_type": "image/png", - "data": sample_image - } - } - ] + "data": sample_image, + }, + }, + ], } - ] - } + ], + }, ) - + print("\nMultimodal response:") print(json.dumps(response.json(), indent=2)) @@ -161,26 +218,28 @@ async def test_conversation_with_tool_use(): "properties": { "expression": { "type": "string", - "description": "Mathematical expression to calculate" + "description": "Mathematical expression to calculate", } }, - "required": ["expression"] - } + "required": ["expression"], + }, } - ] - } + ], + }, ) - + print("\nTool call response:") result1 = response1.json() print(json.dumps(result1, indent=2)) - + # Simulate tool execution and send result if result1.get("content"): - tool_use_blocks = [block for block in result1["content"] if block.get("type") == "tool_use"] + tool_use_blocks = [ + block for block in result1["content"] if block.get("type") == "tool_use" + ] if tool_use_blocks: tool_block = tool_use_blocks[0] - + # Second message with tool result response2 = await client.post( "http://localhost:8082/v1/messages", @@ -188,7 +247,10 @@ async def test_conversation_with_tool_use(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 100, "messages": [ - {"role": "user", "content": "Calculate 25 * 4 using the calculator tool"}, + { + "role": "user", + "content": "Calculate 25 * 4 using the calculator tool", + }, {"role": "assistant", "content": result1["content"]}, { "role": "user", @@ -196,14 +258,14 @@ async def test_conversation_with_tool_use(): { "type": "tool_result", "tool_use_id": tool_block["id"], - "content": "100" + "content": "100", } - ] - } - ] - } + ], + }, + ], + }, ) - + print("\nTool result response:") print(json.dumps(response2.json(), indent=2)) @@ -217,10 +279,10 @@ async def test_token_counting(): "model": "claude-3-5-sonnet-20241022", "messages": [ {"role": "user", "content": "This is a test message for token counting."} - ] - } + ], + }, ) - + print("\nToken count response:") print(json.dumps(response.json(), indent=2)) @@ -232,7 +294,7 @@ async def test_health_and_connection(): health_response = await client.get("http://localhost:8082/health") print("\nHealth check:") print(json.dumps(health_response.json(), indent=2)) - + # Connection test connection_response = await client.get("http://localhost:8082/test-connection") print("\nConnection test:") @@ -243,8 +305,11 @@ async def main(): """Run all tests.""" print("๐Ÿงช Testing Claude to OpenAI Proxy") print("=" * 50) - + try: + await test_event_batch_list() + await test_event_batch_single() + await test_event_batch_invalid_json() await test_health_and_connection() await test_token_counting() await test_basic_chat() @@ -253,13 +318,13 @@ async def main(): await test_multimodal() await test_function_calling() await test_conversation_with_tool_use() - + print("\nโœ… All tests completed!") - + except Exception as e: print(f"\nโŒ Test failed: {e}") print("Make sure the server is running with a valid OPENAI_API_KEY") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())