-
Notifications
You must be signed in to change notification settings - Fork 655
Expand file tree
/
Copy pathllm_utils.py
More file actions
62 lines (50 loc) · 2.03 KB
/
Copy pathllm_utils.py
File metadata and controls
62 lines (50 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Utility functions for LLM providers.
"""
import logging
from typing import Any, Dict, Optional
from models import ModelProvider, OllamaProvider, GeminiProvider
from prompt import MODEL_PROVIDER_MAPPING, GEMINI_API_KEY
# Module-level logger named after this module; used below to report which
# LLM provider is selected.
logger = logging.getLogger(__name__)
def extract_json_from_response(response_text: str) -> str:
    """
    Extract JSON content from an LLM response.

    Removes every complete ``<think>...</think>`` reasoning section and an
    enclosing markdown code fence (```` ```json ```` in any letter case, or a
    bare ```` ``` ````).

    Args:
        response_text: Text that may contain JSON wrapped in markdown code
            blocks, optionally accompanied by ``<think>`` reasoning blocks.

    Returns:
        Text with reasoning blocks and markdown code block syntax removed,
        stripped of surrounding whitespace. An unterminated ``<think>`` tag
        is left in place.
    """
    open_tag, close_tag = "<think>", "</think>"
    text = response_text.strip()

    # Drop every complete <think>...</think> block. The closing tag is searched
    # only after the opening one so a stray earlier "</think>" cannot yield an
    # inverted slice. Each pass shortens the text, so the loop terminates.
    while True:
        start = text.find(open_tag)
        if start == -1:
            break
        end = text.find(close_tag, start + len(open_tag))
        if end == -1:
            # Unterminated reasoning block: leave the remaining text untouched.
            break
        text = text[:start] + text[end + len(close_tag):]

    # Removing a think block typically leaves a newline in front of the code
    # fence; strip again so the fence checks below can match.
    text = text.strip()

    # Remove the opening fence: "```json" (case-insensitive) or a bare "```".
    if text[:7].lower() == "```json":
        text = text[7:]
    elif text.startswith("```"):
        text = text[3:]

    # Remove the closing fence.
    if text.endswith("```"):
        text = text[:-3]

    return text.strip()
def initialize_llm_provider(model_name: str) -> Any:
    """
    Initialize the appropriate LLM provider based on the model name.

    The provider for ``model_name`` is looked up in ``MODEL_PROVIDER_MAPPING``;
    models missing from the mapping default to Ollama. A Gemini model is served
    by ``GeminiProvider`` only when ``GEMINI_API_KEY`` is truthy; otherwise a
    warning is logged and Ollama is used instead.

    Args:
        model_name: The name of the model to use

    Returns:
        An initialized LLM provider (either OllamaProvider or GeminiProvider)
    """
    model_provider = MODEL_PROVIDER_MAPPING.get(model_name, ModelProvider.OLLAMA)

    if model_provider == ModelProvider.GEMINI:
        if GEMINI_API_KEY:
            logger.info("🔄 Using Google Gemini API provider with model %s", model_name)
            return GeminiProvider(api_key=GEMINI_API_KEY)
        logger.warning("⚠️ Gemini API key not found. Falling back to Ollama.")

    # Ollama is built only on the path that returns it, so selecting Gemini
    # never instantiates an unused OllamaProvider. This log line also runs on
    # the fallback path, so the provider actually used is always reported.
    logger.info("🔄 Using Ollama provider with model %s", model_name)
    return OllamaProvider()