From 5e418cd9774779e9d132d0df53bef1d6923d374b Mon Sep 17 00:00:00 2001
From: Luke Hinds
Date: Mon, 13 Jan 2025 12:38:10 +0000
Subject: [PATCH] Introduce new LLM client architecture

Begin migration away from LiteLLM with a modular design:

- Add new llmclient package with provider interface
- Create bridge implementation to gradually migrate from LiteLLM
- Simplify LLMClient class to use new provider system
- Add type definitions for requests/responses

This change sets up the foundation for removing the LiteLLM dependency
while maintaining backward compatibility. Providers will be migrated
incrementally in follow-up changes.

I am trying to do this without stirring up too many regressions along
the way; happy for others to recommend other approaches.
---
A short usage sketch of the new provider interface is appended after
the diff.

 src/codegate/llm_utils/llmclient.py           | 156 ++++++-----------
 src/codegate/llmclient/base.py                |  94 +++++++++++
 .../llmclient/normalizers/__init__.py         |   0
 src/codegate/llmclient/normalizers/base.py    |  42 +++++
 src/codegate/llmclient/providers/anthropic.py | 138 +++++++++++++++
 src/codegate/llmclient/providers/ollama.py    | 157 ++++++++++++++++++
 src/codegate/llmclient/providers/openai.py    | 149 +++++++++++++++++
 src/codegate/llmclient/types.py               |  55 ++++++
 src/codegate/providers/litellmshim/bridge.py  | 150 +++++++++++++++++
 tests/test_llmclient_base.py                  | 155 +++++++++++++++++
 tests/test_llmclient_types.py                 | 137 +++++++++++++++
 11 files changed, 1125 insertions(+), 108 deletions(-)
 create mode 100644 src/codegate/llmclient/base.py
 create mode 100644 src/codegate/llmclient/normalizers/__init__.py
 create mode 100644 src/codegate/llmclient/normalizers/base.py
 create mode 100644 src/codegate/llmclient/providers/anthropic.py
 create mode 100644 src/codegate/llmclient/providers/ollama.py
 create mode 100644 src/codegate/llmclient/providers/openai.py
 create mode 100644 src/codegate/llmclient/types.py
 create mode 100644 src/codegate/providers/litellmshim/bridge.py
 create mode 100644 tests/test_llmclient_base.py
 create mode 100644 tests/test_llmclient_types.py

diff --git a/src/codegate/llm_utils/llmclient.py b/src/codegate/llm_utils/llmclient.py
index 53c77e0a..a4766453 100644
--- a/src/codegate/llm_utils/llmclient.py
+++ b/src/codegate/llm_utils/llmclient.py
@@ -1,26 +1,32 @@
 import json
 from typing import Any, Dict, Optional

-import litellm
 import structlog
-from litellm import acompletion
-from ollama import Client as OllamaClient

 from codegate.config import Config
 from codegate.inference import LlamaCppInferenceEngine
+from codegate.llmclient.base import Message, LLMProvider
+from codegate.providers.litellmshim.bridge import LiteLLMBridgeProvider

 logger = structlog.get_logger("codegate")

-litellm.drop_params = True
-
-
 class LLMClient:
-    """
-    Base class for LLM interactions handling both local and cloud providers.
-
-    This is a kludge before we refactor our providers a bit to be able to pass
-    in all the parameters we need.
- """ + """Base class for LLM interactions handling both local and cloud providers.""" + + @staticmethod + def _create_provider( + provider: str, + model: str = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + ) -> Optional[LLMProvider]: + if provider == "llamacpp": + return None # Handled separately for now + return LiteLLMBridgeProvider( + api_key=api_key or "", + base_url=base_url, + default_model=model + ) @staticmethod async def complete( @@ -33,42 +39,41 @@ async def complete( extra_headers: Optional[Dict[str, str]] = None, **kwargs, ) -> Dict[str, Any]: - """ - Send a completion request to either local or cloud LLM. - - Args: - content: The user message content - system_prompt: The system prompt to use - provider: "local" or "litellm" - model: Model identifier - api_key: API key for cloud providers - base_url: Base URL for cloud providers - **kwargs: Additional arguments for the completion request - - Returns: - Parsed response from the LLM - """ if provider == "llamacpp": return await LLMClient._complete_local(content, system_prompt, model, **kwargs) - return await LLMClient._complete_litellm( - content, - system_prompt, - provider, - model, - api_key, - base_url, - extra_headers, - **kwargs, - ) + + llm_provider = LLMClient._create_provider(provider, model, api_key, base_url) + + try: + messages = [ + Message(role="system", content=system_prompt), + Message(role="user", content=content) + ] + + response = await llm_provider.chat( + messages=messages, + temperature=kwargs.get("temperature", 0), + stream=False, + extra_headers=extra_headers, + **kwargs + ) + + return json.loads(response.message.content) + + except Exception as e: + logger.error(f"LLM completion failed {model} ({content}): {e}") + raise e + finally: + await llm_provider.close() @staticmethod - async def _create_request( - content: str, system_prompt: str, model: str, **kwargs + async def _complete_local( + content: str, + system_prompt: str, + model: str, + **kwargs, ) -> Dict[str, Any]: - """ - Private method to create a request dictionary for LLM completion. 
- """ - return { + request = { "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": content}, @@ -79,16 +84,6 @@ async def _create_request( "temperature": kwargs.get("temperature", 0), } - @staticmethod - async def _complete_local( - content: str, - system_prompt: str, - model: str, - **kwargs, - ) -> Dict[str, Any]: - # Use the private method to create the request - request = await LLMClient._create_request(content, system_prompt, model, **kwargs) - inference_engine = LlamaCppInferenceEngine() result = await inference_engine.chat( f"{Config.get_config().model_base_path}/{request['model']}.gguf", @@ -98,58 +93,3 @@ async def _complete_local( ) return json.loads(result["choices"][0]["message"]["content"]) - - @staticmethod - async def _complete_litellm( - content: str, - system_prompt: str, - provider: str, - model: str, - api_key: str, - base_url: Optional[str] = None, - extra_headers: Optional[Dict[str, str]] = None, - **kwargs, - ) -> Dict[str, Any]: - # Use the private method to create the request - request = await LLMClient._create_request(content, system_prompt, model, **kwargs) - - # We should reuse the same logic in the provider - # but let's do that later - if provider == "vllm": - if not base_url.endswith("/v1"): - base_url = f"{base_url}/v1" - else: - if not model.startswith(f"{provider}/"): - model = f"{provider}/{model}" - - try: - if provider == "ollama": - model = model.split("/")[-1] - response = OllamaClient(host=base_url).chat( - model=model, - messages=request["messages"], - format="json", - options={"temperature": request["temperature"]}, - ) - content = response.message.content - else: - response = await acompletion( - model=model, - messages=request["messages"], - api_key=api_key, - temperature=request["temperature"], - base_url=base_url, - response_format=request["response_format"], - extra_headers=extra_headers, - ) - content = response["choices"][0]["message"]["content"] - - # Clean up code blocks if present - if content.startswith("```"): - content = content.split("\n", 1)[1].rsplit("```", 1)[0].strip() - - return json.loads(content) - - except Exception as e: - logger.error(f"LiteLLM completion failed {model} ({content}): {e}") - raise e diff --git a/src/codegate/llmclient/base.py b/src/codegate/llmclient/base.py new file mode 100644 index 00000000..5f7448a9 --- /dev/null +++ b/src/codegate/llmclient/base.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from typing import AsyncIterator, Dict, List, Optional, Union +from dataclasses import dataclass + +@dataclass +class Message: + """Represents a chat message.""" + role: str + content: str + +@dataclass +class CompletionResponse: + """Represents a completion response from an LLM.""" + text: str + model: str + usage: Dict[str, int] + +@dataclass +class ChatResponse: + """Represents a chat response from an LLM.""" + message: Message + model: str + usage: Dict[str, int] + +class LLMProvider(ABC): + """Abstract base class for LLM providers.""" + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + default_model: Optional[str] = None + ): + """Initialize the LLM provider. 
+ + Args: + api_key: API key for authentication + base_url: Optional custom base URL for the API + default_model: Optional default model to use + """ + self.api_key = api_key + self.base_url = base_url + self.default_model = default_model + + @abstractmethod + async def chat( + self, + messages: List[Message], + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: + """Send a chat request to the LLM. + + Args: + messages: List of messages in the conversation + model: Optional model override + temperature: Sampling temperature + stream: Whether to stream the response + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse or AsyncIterator[ChatResponse] if streaming + """ + pass + + @abstractmethod + async def complete( + self, + prompt: str, + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]: + """Send a completion request to the LLM. + + Args: + prompt: The text prompt + model: Optional model override + temperature: Sampling temperature + stream: Whether to stream the response + **kwargs: Additional provider-specific parameters + + Returns: + CompletionResponse or AsyncIterator[CompletionResponse] if streaming + """ + pass + + @abstractmethod + async def close(self) -> None: + """Close any open connections.""" + pass \ No newline at end of file diff --git a/src/codegate/llmclient/normalizers/__init__.py b/src/codegate/llmclient/normalizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/llmclient/normalizers/base.py b/src/codegate/llmclient/normalizers/base.py new file mode 100644 index 00000000..29a79ca4 --- /dev/null +++ b/src/codegate/llmclient/normalizers/base.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Dict + +from ..types import Message, NormalizedRequest, ChatResponse + +class ModelInputNormalizer(ABC): + @abstractmethod + def normalize(self, data: Dict[str, Any]) -> NormalizedRequest: + """Convert provider-specific request format to SimpleModelRouter format.""" + pass + + @abstractmethod + def denormalize(self, data: NormalizedRequest) -> Dict[str, Any]: + """Convert SimpleModelRouter format back to provider-specific request format.""" + pass + +class ModelOutputNormalizer(ABC): + @abstractmethod + def normalize_streaming( + self, + model_reply: AsyncIterator[Any] + ) -> AsyncIterator[ChatResponse]: + """Convert provider-specific streaming response to SimpleModelRouter format.""" + pass + + @abstractmethod + def normalize(self, model_reply: Any) -> ChatResponse: + """Convert provider-specific response to SimpleModelRouter format.""" + pass + + @abstractmethod + def denormalize(self, normalized_reply: ChatResponse) -> Dict[str, Any]: + """Convert SimpleModelRouter format back to provider-specific response format.""" + pass + + @abstractmethod + def denormalize_streaming( + self, + normalized_reply: AsyncIterator[ChatResponse] + ) -> AsyncIterator[Any]: + """Convert SimpleModelRouter streaming response back to provider-specific format.""" + pass diff --git a/src/codegate/llmclient/providers/anthropic.py b/src/codegate/llmclient/providers/anthropic.py new file mode 100644 index 00000000..bbc67b4e --- /dev/null +++ b/src/codegate/llmclient/providers/anthropic.py @@ -0,0 +1,138 @@ +import httpx +from typing import AsyncIterator, Dict, List, Optional, Union +import json + 
+from codegate.llmclient.base import LLMProvider, Message, ChatResponse, CompletionResponse + +class AnthropicProvider(LLMProvider): + """Anthropic API provider implementation.""" + + def __init__( + self, + api_key: str, + base_url: Optional[str] = "https://api.anthropic.com/v1", + default_model: Optional[str] = "claude-3-opus-20240229" + ): + """Initialize the Anthropic provider. + + Args: + api_key: Anthropic API key + base_url: Optional API base URL override + default_model: Default model to use + """ + super().__init__(api_key, base_url, default_model) + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers={ + "x-api-key": self.api_key, + "anthropic-version": "2023-06-01", + "Content-Type": "application/json" + } + ) + + async def chat( + self, + messages: List[Message], + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: + """Send a chat request to Anthropic.""" + model = model or self.default_model + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in messages], + "temperature": temperature, + "stream": stream, + **kwargs + } + + if stream: + return self._stream_chat(payload) + + async with self._client as client: + response = await client.post("/messages", json=payload) + response.raise_for_status() + data = response.json() + + return ChatResponse( + message=Message( + role="assistant", + content=data["content"][0]["text"] + ), + model=data["model"], + usage={ + "prompt_tokens": data.get("usage", {}).get("input_tokens", 0), + "completion_tokens": data.get("usage", {}).get("output_tokens", 0), + "total_tokens": data.get("usage", {}).get("total_tokens", 0) + } + ) + + async def complete( + self, + prompt: str, + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]: + """Send a completion request to Anthropic.""" + # Convert completion request to chat format since Anthropic uses unified endpoint + messages = [Message(role="user", content=prompt)] + chat_response = await self.chat( + messages=messages, + model=model, + temperature=temperature, + stream=stream, + **kwargs + ) + + if stream: + async def convert_stream(): + async for chunk in chat_response: + yield CompletionResponse( + text=chunk.message.content, + model=chunk.model, + usage=chunk.usage + ) + return convert_stream() + else: + return CompletionResponse( + text=chat_response.message.content, + model=chat_response.model, + usage=chat_response.usage + ) + + async def _stream_chat(self, payload: Dict) -> AsyncIterator[ChatResponse]: + """Handle streaming chat responses.""" + async with self._client as client: + async with client.stream("POST", "/messages", json=payload) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if line.startswith("data: "): + if line.strip() == "data: [DONE]": + break + + data = json.loads(line[6:]) + if "delta" not in data: + continue + + delta = data["delta"] + if "text" not in delta: + continue + + yield ChatResponse( + message=Message( + role="assistant", + content=delta["text"] + ), + model=data["model"], + usage={} # Usage stats only available at end of stream + ) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._client.aclose() \ No newline at end of file diff --git a/src/codegate/llmclient/providers/ollama.py b/src/codegate/llmclient/providers/ollama.py new file 
mode 100644 index 00000000..9cc4a428 --- /dev/null +++ b/src/codegate/llmclient/providers/ollama.py @@ -0,0 +1,157 @@ +import httpx +from typing import AsyncIterator, Dict, List, Optional, Union +import json + +from codegate.llmclient.base import LLMProvider, Message, ChatResponse, CompletionResponse + +class OllamaProvider(LLMProvider): + """Ollama API provider implementation.""" + + def __init__( + self, + api_key: str = "", # Ollama doesn't use API keys by default + base_url: Optional[str] = "http://localhost:11434", + default_model: Optional[str] = "llama2" + ): + """Initialize the Ollama provider. + + Args: + api_key: Not used by default in Ollama + base_url: Optional API base URL override + default_model: Default model to use + """ + super().__init__(api_key, base_url, default_model) + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers={"Content-Type": "application/json"} + ) + + async def chat( + self, + messages: List[Message], + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: + """Send a chat request to Ollama.""" + model = model or self.default_model + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in messages], + "stream": stream, + "options": { + "temperature": temperature, + **kwargs + } + } + + if stream: + return self._stream_chat(payload) + + async with self._client as client: + response = await client.post("/api/chat", json=payload) + response.raise_for_status() + data = response.json() + + return ChatResponse( + message=Message( + role="assistant", + content=data["message"]["content"] + ), + model=model, + usage={ + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + } + ) + + async def complete( + self, + prompt: str, + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]: + """Send a completion request to Ollama.""" + model = model or self.default_model + + payload = { + "model": model, + "prompt": prompt, + "stream": stream, + "options": { + "temperature": temperature, + **kwargs + } + } + + if stream: + return self._stream_completion(payload) + + async with self._client as client: + response = await client.post("/api/generate", json=payload) + response.raise_for_status() + data = response.json() + + return CompletionResponse( + text=data["response"], + model=model, + usage={ + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + } + ) + + async def _stream_chat(self, payload: Dict) -> AsyncIterator[ChatResponse]: + """Handle streaming chat responses.""" + async with self._client as client: + async with client.stream("POST", "/api/chat", json=payload) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + data = json.loads(line) + if "done" in data and data["done"]: + break + + yield ChatResponse( + message=Message( + role="assistant", + content=data["message"]["content"] + ), + model=payload["model"], + usage={ + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + } + ) + + async def 
_stream_completion(self, payload: Dict) -> AsyncIterator[CompletionResponse]: + """Handle streaming completion responses.""" + async with self._client as client: + async with client.stream("POST", "/api/generate", json=payload) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + data = json.loads(line) + if "done" in data and data["done"]: + break + + yield CompletionResponse( + text=data["response"], + model=payload["model"], + usage={ + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + data.get("eval_count", 0) + } + ) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._client.aclose() \ No newline at end of file diff --git a/src/codegate/llmclient/providers/openai.py b/src/codegate/llmclient/providers/openai.py new file mode 100644 index 00000000..30a93f33 --- /dev/null +++ b/src/codegate/llmclient/providers/openai.py @@ -0,0 +1,149 @@ +import httpx +from typing import AsyncIterator, Dict, List, Optional, Union +import json + +from codegate.llmclient.base import LLMProvider, Message, ChatResponse, CompletionResponse + +class OpenAIProvider(LLMProvider): + """OpenAI API provider implementation.""" + + def __init__( + self, + api_key: str, + base_url: Optional[str] = "https://api.openai.com/v1", + default_model: Optional[str] = "gpt-3.5-turbo" + ): + """Initialize the OpenAI provider. + + Args: + api_key: OpenAI API key + base_url: Optional API base URL override + default_model: Default model to use + """ + super().__init__(api_key, base_url, default_model) + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + }, + timeout=60.0 # 60 second timeout for requests + ) + + async def chat( + self, + messages: List[Message], + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: + """Send a chat request to OpenAI.""" + model = model or self.default_model + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in messages], + "temperature": temperature, + "stream": stream, + **kwargs + } + + if stream: + return self._stream_chat(payload) + + response = await self._client.post("/chat/completions", json=payload) + response.raise_for_status() + data = response.json() + + return ChatResponse( + message=Message( + role=data["choices"][0]["message"]["role"], + content=data["choices"][0]["message"]["content"] + ), + model=data["model"], + usage=data["usage"] + ) + + async def complete( + self, + prompt: str, + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]: + """Send a completion request to OpenAI.""" + model = model or self.default_model + + payload = { + "model": model, + "prompt": prompt, + "temperature": temperature, + "stream": stream, + **kwargs + } + + if stream: + return self._stream_completion(payload) + + response = await self._client.post("/completions", json=payload) + response.raise_for_status() + data = response.json() + + return CompletionResponse( + text=data["choices"][0]["text"], + model=data["model"], + usage=data["usage"] + ) + + async def _stream_chat(self, payload: Dict) -> AsyncIterator[ChatResponse]: + """Handle streaming chat responses.""" + async with 
self._client.stream("POST", "/chat/completions", json=payload) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if line.startswith("data: "): + if line.strip() == "data: [DONE]": + break + + data = json.loads(line[6:]) + if not data["choices"]: + continue + + delta = data["choices"][0]["delta"] + if "content" not in delta: + continue + + yield ChatResponse( + message=Message( + role=delta.get("role", "assistant"), + content=delta["content"] + ), + model=data["model"], + usage={} # Usage stats only available at end of stream + ) + + async def _stream_completion(self, payload: Dict) -> AsyncIterator[CompletionResponse]: + """Handle streaming completion responses.""" + async with self._client.stream("POST", "/completions", json=payload) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if line.startswith("data: "): + if line.strip() == "data: [DONE]": + break + + data = json.loads(line[6:]) + if not data["choices"]: + continue + + yield CompletionResponse( + text=data["choices"][0]["text"], + model=data["model"], + usage={} # Usage stats only available at end of stream + ) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._client.aclose() diff --git a/src/codegate/llmclient/types.py b/src/codegate/llmclient/types.py new file mode 100644 index 00000000..b0fe8a2b --- /dev/null +++ b/src/codegate/llmclient/types.py @@ -0,0 +1,55 @@ +from typing import Any, Dict, List, Optional +from pydantic import BaseModel +from dataclasses import dataclass, field + +@dataclass +class Message: + content: str + role: str + +@dataclass +class NormalizedRequest: + messages: List[Message] + model: str + stream: bool = True + options: Dict[str, Any] = field(default_factory=dict) + +class ChatResponse(BaseModel): + id: str + messages: List[Message] + created: int # Unix timestamp + model: str + done: bool = False + + +class Delta(BaseModel): + """Delta represents a change in content for streaming responses.""" + content: Optional[str] = None + role: Optional[str] = None + + +class Choice(BaseModel): + """Choice represents a single completion choice from the model.""" + finish_reason: Optional[str] = None + index: int = 0 + delta: Delta + logprobs: Optional[Any] = None + + +class Response(BaseModel): + """Response represents a model's response to a prompt.""" + id: str + choices: List[Choice] + created: int # Unix timestamp + model: str + object: str = "chat.completion.chunk" + stream: bool = False + + def json(self) -> str: + """Convert the response to a JSON string.""" + return self.model_dump_json(exclude_none=True) + + @property + def message(self) -> Optional[Choice]: + """Get the first choice from the response.""" + return self.choices[0] if self.choices else None \ No newline at end of file diff --git a/src/codegate/providers/litellmshim/bridge.py b/src/codegate/providers/litellmshim/bridge.py new file mode 100644 index 00000000..f22a9a18 --- /dev/null +++ b/src/codegate/providers/litellmshim/bridge.py @@ -0,0 +1,150 @@ +import time +from typing import AsyncIterator, Dict, List, Optional, Union + +import litellm +from litellm import acompletion +from litellm.types.utils import Delta, StreamingChoices + +from codegate.llmclient.base import ( + ChatResponse, + CompletionResponse, + LLMProvider, + Message, +) + +litellm.drop_params = True + +class LiteLLMBridgeProvider(LLMProvider): + """Bridge provider that implements the new LLMProvider interface using LiteLLM.""" + + async def chat( + self, + 
messages: List[Message], + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: + """Send a chat request using LiteLLM.""" + + # Convert messages to LiteLLM format + litellm_messages = [ + {"role": msg.role, "content": msg.content} + for msg in messages + ] + + # Use default model if none specified + model_name = model or self.default_model + if not model_name: + raise ValueError("No model specified") + + # Prepare request + request = { + "model": model_name, + "messages": litellm_messages, + "temperature": temperature, + "stream": stream, + "api_key": self.api_key, + **kwargs + } + + if self.base_url: + request["base_url"] = self.base_url + + async def process_stream() -> AsyncIterator[ChatResponse]: + async for chunk in await acompletion(**request): + if not chunk.choices: + continue + + choice = chunk.choices[0] + if not choice.delta or not choice.delta.content: + continue + + yield ChatResponse( + message=Message( + role="assistant", + content=choice.delta.content + ), + model=chunk.model, + usage={} # Usage stats only available in final response + ) + + if stream: + return process_stream() + else: + response = await acompletion(**request) + if not response.choices: + raise ValueError("No choices in response") + + choice = response.choices[0] + return ChatResponse( + message=Message( + role=choice.message.role, + content=choice.message.content + ), + model=response.model, + usage=response.usage + ) + + async def complete( + self, + prompt: str, + model: Optional[str] = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]: + """Send a completion request using LiteLLM.""" + # Convert to chat format since that's what most models use now + messages = [{"role": "user", "content": prompt}] + + # Use default model if none specified + model_name = model or self.default_model + if not model_name: + raise ValueError("No model specified") + + # Prepare request + request = { + "model": model_name, + "messages": messages, + "temperature": temperature, + "stream": stream, + "api_key": self.api_key, + **kwargs + } + + if self.base_url: + request["base_url"] = self.base_url + + async def process_stream() -> AsyncIterator[CompletionResponse]: + async for chunk in await acompletion(**request): + if not chunk.choices: + continue + + choice = chunk.choices[0] + if not choice.delta or not choice.delta.content: + continue + + yield CompletionResponse( + text=choice.delta.content, + model=chunk.model, + usage={} # Usage stats only available in final response + ) + + if stream: + return process_stream() + else: + response = await acompletion(**request) + if not response.choices: + raise ValueError("No choices in response") + + choice = response.choices[0] + return CompletionResponse( + text=choice.message.content, + model=response.model, + usage=response.usage + ) + + async def close(self) -> None: + """Nothing to close for LiteLLM.""" + pass \ No newline at end of file diff --git a/tests/test_llmclient_base.py b/tests/test_llmclient_base.py new file mode 100644 index 00000000..9845a24d --- /dev/null +++ b/tests/test_llmclient_base.py @@ -0,0 +1,155 @@ +import pytest +from typing import AsyncIterator, List +from unittest.mock import AsyncMock, MagicMock + +from codegate.llmclient.base import ChatResponse, CompletionResponse, LLMProvider, Message + +class MockLLMProvider(LLMProvider): + """Mock provider for testing the base LLM 
functionality.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.chat_mock = AsyncMock() + self.complete_mock = AsyncMock() + self.close_mock = AsyncMock() + + async def chat( + self, + messages: List[Message], + model: str = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ): + return await self.chat_mock(messages, model, temperature, stream, **kwargs) + + async def complete( + self, + prompt: str, + model: str = None, + temperature: float = 0.7, + stream: bool = False, + **kwargs + ): + return await self.complete_mock(prompt, model, temperature, stream, **kwargs) + + async def close(self): + await self.close_mock() + +@pytest.fixture +def mock_provider(): + return MockLLMProvider(api_key="test-key", base_url="http://test", default_model="test-model") + +@pytest.mark.asyncio +async def test_provider_initialization(): + """Test that provider is initialized with correct parameters.""" + provider = MockLLMProvider( + api_key="test-key", + base_url="http://test", + default_model="test-model" + ) + + assert provider.api_key == "test-key" + assert provider.base_url == "http://test" + assert provider.default_model == "test-model" + +@pytest.mark.asyncio +async def test_chat_non_streaming(mock_provider): + """Test non-streaming chat completion.""" + expected_response = ChatResponse( + message=Message(role="assistant", content="Test response"), + model="test-model", + usage={"prompt_tokens": 10, "completion_tokens": 5} + ) + mock_provider.chat_mock.return_value = expected_response + + messages = [ + Message(role="system", content="You are a test assistant"), + Message(role="user", content="Hello") + ] + + response = await mock_provider.chat(messages) + assert response == expected_response + mock_provider.chat_mock.assert_called_once_with( + messages, None, 0.7, False + ) + +@pytest.mark.asyncio +async def test_chat_streaming(mock_provider): + """Test streaming chat completion.""" + async def mock_stream(): + responses = [ + ChatResponse( + message=Message(role="assistant", content="Test"), + model="test-model", + usage={} + ), + ChatResponse( + message=Message(role="assistant", content=" response"), + model="test-model", + usage={} + ) + ] + for response in responses: + yield response + + mock_provider.chat_mock.return_value = mock_stream() + + messages = [Message(role="user", content="Hello")] + responses = [] + + async for response in await mock_provider.chat(messages, stream=True): + responses.append(response) + + assert len(responses) == 2 + assert responses[0].message.content == "Test" + assert responses[1].message.content == " response" + mock_provider.chat_mock.assert_called_once_with( + messages, None, 0.7, True + ) + +@pytest.mark.asyncio +async def test_complete_non_streaming(mock_provider): + """Test non-streaming text completion.""" + expected_response = CompletionResponse( + text="Test response", + model="test-model", + usage={"prompt_tokens": 5, "completion_tokens": 2} + ) + mock_provider.complete_mock.return_value = expected_response + + response = await mock_provider.complete("Hello") + assert response == expected_response + mock_provider.complete_mock.assert_called_once_with( + "Hello", None, 0.7, False + ) + +@pytest.mark.asyncio +async def test_complete_streaming(mock_provider): + """Test streaming text completion.""" + async def mock_stream(): + responses = [ + CompletionResponse(text="Test", model="test-model", usage={}), + CompletionResponse(text=" response", model="test-model", usage={}) + ] + for response in 
responses: + yield response + + mock_provider.complete_mock.return_value = mock_stream() + + responses = [] + async for response in await mock_provider.complete("Hello", stream=True): + responses.append(response) + + assert len(responses) == 2 + assert responses[0].text == "Test" + assert responses[1].text == " response" + mock_provider.complete_mock.assert_called_once_with( + "Hello", None, 0.7, True + ) + +@pytest.mark.asyncio +async def test_close(mock_provider): + """Test that close is called.""" + await mock_provider.close() + mock_provider.close_mock.assert_called_once() \ No newline at end of file diff --git a/tests/test_llmclient_types.py b/tests/test_llmclient_types.py new file mode 100644 index 00000000..332b5794 --- /dev/null +++ b/tests/test_llmclient_types.py @@ -0,0 +1,137 @@ +import json +import pytest +from codegate.llmclient.types import ( + Message, + NormalizedRequest, + ChatResponse, + Delta, + Choice, + Response +) + +def test_message(): + """Test Message dataclass.""" + msg = Message(content="Hello", role="user") + assert msg.content == "Hello" + assert msg.role == "user" + +def test_normalized_request(): + """Test NormalizedRequest dataclass.""" + messages = [ + Message(content="Hello", role="user"), + Message(content="Hi", role="assistant") + ] + request = NormalizedRequest( + messages=messages, + model="test-model", + stream=True, + options={"temperature": 0.7} + ) + + assert request.messages == messages + assert request.model == "test-model" + assert request.stream is True + assert request.options == {"temperature": 0.7} + +def test_chat_response(): + """Test ChatResponse model.""" + messages = [Message(content="Hello", role="user")] + response = ChatResponse( + id="test-id", + messages=messages, + created=1234567890, + model="test-model", + done=True + ) + + assert response.id == "test-id" + assert response.messages == messages + assert response.created == 1234567890 + assert response.model == "test-model" + assert response.done is True + +def test_delta(): + """Test Delta model.""" + delta = Delta(content="Hello", role="assistant") + assert delta.content == "Hello" + assert delta.role == "assistant" + + # Test optional fields + delta = Delta() + assert delta.content is None + assert delta.role is None + +def test_choice(): + """Test Choice model.""" + delta = Delta(content="Hello") + choice = Choice( + finish_reason="stop", + index=0, + delta=delta + ) + + assert choice.finish_reason == "stop" + assert choice.index == 0 + assert choice.delta == delta + assert choice.logprobs is None + +def test_response(): + """Test Response model.""" + delta = Delta(content="Hello") + choice = Choice(delta=delta) + response = Response( + id="test-id", + choices=[choice], + created=1234567890, + model="test-model" + ) + + assert response.id == "test-id" + assert len(response.choices) == 1 + assert response.created == 1234567890 + assert response.model == "test-model" + assert response.object == "chat.completion.chunk" + assert response.stream is False + +def test_response_json(): + """Test Response JSON serialization.""" + delta = Delta(content="Hello") + choice = Choice(delta=delta) + response = Response( + id="test-id", + choices=[choice], + created=1234567890, + model="test-model" + ) + + json_str = response.json() + data = json.loads(json_str) + + assert data["id"] == "test-id" + assert len(data["choices"]) == 1 + assert data["created"] == 1234567890 + assert data["model"] == "test-model" + assert data["object"] == "chat.completion.chunk" + assert data["stream"] is False + 
+def test_response_message(): + """Test Response message property.""" + delta = Delta(content="Hello") + choice = Choice(delta=delta) + response = Response( + id="test-id", + choices=[choice], + created=1234567890, + model="test-model" + ) + + assert response.message == choice + + # Test empty choices + empty_response = Response( + id="test-id", + choices=[], + created=1234567890, + model="test-model" + ) + assert empty_response.message is None \ No newline at end of file
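
Below is a minimal, hypothetical usage sketch of the new LLMProvider interface from src/codegate/llmclient/base.py, exercised through the LiteLLMBridgeProvider added in this patch. It is illustrative only and not part of the diff above: the API key, model identifier, and prompts are placeholders, not values taken from the codebase.

import asyncio

from codegate.llmclient.base import Message
from codegate.providers.litellmshim.bridge import LiteLLMBridgeProvider


async def main() -> None:
    # Placeholder credentials and model name -- adjust for a real deployment.
    provider = LiteLLMBridgeProvider(
        api_key="sk-placeholder",
        base_url=None,  # or a custom endpoint such as a self-hosted vLLM server
        default_model="openai/gpt-4o-mini",
    )
    try:
        # Non-streaming chat: returns a ChatResponse with a single Message.
        reply = await provider.chat(
            messages=[
                Message(role="system", content="You are a helpful assistant."),
                Message(role="user", content="Say hello in one sentence."),
            ],
            temperature=0,
            stream=False,
        )
        print(reply.message.content)

        # Streaming chat: returns an async iterator of partial ChatResponses.
        stream = await provider.chat(
            messages=[Message(role="user", content="Count to three.")],
            stream=True,
        )
        async for chunk in stream:
            print(chunk.message.content, end="", flush=True)
    finally:
        await provider.close()


if __name__ == "__main__":
    asyncio.run(main())

The same calls should work against OpenAIProvider, AnthropicProvider, or OllamaProvider, since each implements the same LLMProvider base class.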