
Commit e93ccf8

Merge branch 'vllm' into main
2 parents 08095f8 + bc734f9 · commit e93ccf8

4 files changed: 303 additions & 1 deletion

File tree

easyroutine/inference/__init__.py
easyroutine/inference/litellm_model_interface.py
pyproject.toml
tutorial/5_inference.ipynb

easyroutine/inference/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
 from easyroutine.inference.base_model_interface import BaseInferenceModelConfig
-from easyroutine.inference.vllm_model_interface import VLLMInferenceModel, VLLMInferenceModelConfig
+from easyroutine.inference.vllm_model_interface import VLLMInferenceModel, VLLMInferenceModelConfig
+from easyroutine.inference.litellm_model_interface import LiteLLMInferenceModel, LiteLLMInferenceModelConfig
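With this re-export in place, both backends can be imported straight from the package root; a quick sketch (names taken from the diff above):

# both interfaces are now importable from easyroutine.inference
from easyroutine.inference import (
    LiteLLMInferenceModel,
    LiteLLMInferenceModelConfig,
    VLLMInferenceModel,
    VLLMInferenceModelConfig,
)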
easyroutine/inference/litellm_model_interface.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from easyroutine.inference.base_model_interface import BaseInferenceModel, BaseInferenceModelConfig
from typing import List
from dataclasses import dataclass
from litellm import completion, batch_completion


@dataclass
class LiteLLMInferenceModelConfig(BaseInferenceModelConfig):
    """Configuration for the LiteLLM backend: generation defaults plus provider API keys."""
    model_name: str

    n_gpus: int = 0
    dtype: str = 'bfloat16'
    temperature: float = 0
    top_p: float = 0.95
    max_new_tokens: int = 5000

    openai_api_key: str = ''
    anthropic_api_key: str = ''
    xai_api_key: str = ''


class LiteLLMInferenceModel(BaseInferenceModel):

    def __init__(self, config: LiteLLMInferenceModelConfig):
        self.config = config
        self.set_os_env()

    def set_os_env(self):
        # export the configured keys so LiteLLM can pick them up per provider
        import os
        os.environ['OPENAI_API_KEY'] = self.config.openai_api_key
        os.environ['ANTHROPIC_API_KEY'] = self.config.anthropic_api_key
        os.environ['XAI_API_KEY'] = self.config.xai_api_key

    def convert_chat_messages_to_custom_format(self, chat_messages: List[dict[str, str]]) -> List[dict[str, str]]:
        """
        For now, LiteLLM is compatible with the chat template format we use.
        """
        return chat_messages

    def chat(self, chat_messages: List[dict[str, str]], use_tqdm=False, **kwargs) -> list:
        """
        Generate a response based on the provided chat messages.

        Arguments:
            chat_messages (List[dict[str, str]]): List of chat messages to process.
            **kwargs: Additional parameters for the model.

        Returns:
            list: The list of generated choices from the model.
        """
        chat_messages = self.convert_chat_messages_to_custom_format(chat_messages)

        response = completion(
            model=self.config.model_name,
            messages=chat_messages,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            max_tokens=self.config.max_new_tokens,
        )
        return response['choices']

    def batch_chat(self, chat_messages: List[List[dict[str, str]]], use_tqdm=False, **kwargs) -> List[list]:
        """
        Generate responses for a batch of chat conversations.

        Arguments:
            chat_messages (List[List[dict[str, str]]]): List of chat conversations to process.
            **kwargs: Additional parameters for the model.

        Returns:
            List[list]: List of generated responses from the model.
        """
        chat_messages = [self.convert_chat_messages_to_custom_format(msg) for msg in chat_messages]

        responses = batch_completion(
            model=self.config.model_name,
            messages=chat_messages,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            max_tokens=self.config.max_new_tokens,
        )
        return responses
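A minimal usage sketch of the class above (the model id and key value are placeholders, not part of this commit); note that `chat` returns the raw LiteLLM `choices` list, so the text lives under the first choice's message:

from easyroutine.inference import LiteLLMInferenceModel, LiteLLMInferenceModelConfig

config = LiteLLMInferenceModelConfig(
    model_name="gpt-4o-mini",   # placeholder: any LiteLLM-routable model id
    openai_api_key="sk-...",    # placeholder: exported to OPENAI_API_KEY by set_os_env()
)
model = LiteLLMInferenceModel(config)

choices = model.chat([{"role": "user", "content": "Hello!"}])
print(choices[0]["message"]["content"])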

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@ pyyaml = "^6.0.2"
 sentencepiece = "^0.2.0"
 transformers = "4.51.1"
 pydantic = "^2.10.6"
+litellm = "^1.74.0"
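A quick way to confirm the new dependency resolved in the active environment (assumes the project has been reinstalled after this change):

from importlib.metadata import version
print(version("litellm"))  # expected to satisfy the ^1.74.0 constraint above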

tutorial/5_inference.ipynb

Lines changed: 217 additions & 0 deletions
In [1]:
from easyroutine import path_to_parents
path_to_parents(2)

%load_ext autoreload
%autoreload 2

Out:
Changed working directory to: /home/francesco/HistoryRevisionismLLM
# Inference Module

`easyroutine` provides a simple interface for interacting with various LLMs through different backends. Specifically, it supports:

- **vLLM**: a high-performance inference engine for large language models running on GPUs.
- **LiteLLM**: a lightweight interface to the OpenAI, Anthropic, and xAI APIs.

## LiteLLM Inference Model

First, load the API keys from the `.env` file:
In [2]:
from dotenv import load_dotenv
load_dotenv()
# get the OpenAI API key from the .env file
import os
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

Then, initialize the interface with the desired model and API keys:
In [3]:
from easyroutine.inference import LiteLLMInferenceModel, LiteLLMInferenceModelConfig
config = LiteLLMInferenceModelConfig(
    model_name="gpt-4.1-nano-2025-04-14",
    openai_api_key=OPENAI_API_KEY
)
model = LiteLLMInferenceModel(config)

Out (stderr):
/home/francesco/HistoryRevisionismLLM/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Out (stdout):
INFO 07-14 11:25:17 [__init__.py:244] Automatically detected platform cuda.
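Because `LiteLLMInferenceModel.__init__` calls `set_os_env()` (see `litellm_model_interface.py` above), the key passed in the config is exported as an environment variable for LiteLLM to read; a quick sanity check:

import os
assert os.environ['OPENAI_API_KEY'] == OPENAI_API_KEY  # exported by set_os_env()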
All the models available in the `easyroutine.inference` module have an `.append_with_chat_template` method that appends a message to the chat history with the specified role (either "user" or "assistant"). The `.chat` method then handles the translation of the chat history into the model-specific format and returns the response.

`append_with_chat_template` takes a message and a role as input, and returns a chat message in the format required by the model. It can also take a `chat_history` parameter to append the message to an existing chat history, as sketched after the next cell.

In [4]:
chat_message = model.append_with_chat_template(message="What is the capital of France?", role="user")
print(chat_message)

Out:
[{'role': 'user', 'content': 'What is the capital of France?'}]
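A minimal sketch of the multi-turn pattern, assuming `chat_history` accepts the message list returned by a previous call (the parameter name comes from the description above; its exact signature is not shown in this commit):

history = model.append_with_chat_template(message="What is the capital of France?", role="user")
history = model.append_with_chat_template(message="The capital of France is Paris.", role="assistant", chat_history=history)
history = model.append_with_chat_template(message="And what about Portugal?", role="user", chat_history=history)
response = model.chat(history)  # the whole history is sent to the model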
In [5]:
response = model.chat(chat_message)
print(response)

Out:
[Choices(finish_reason='stop', index=0, message=Message(content='The capital of France is Paris.', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'refusal': None}, annotations=[]), provider_specific_fields={})]
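Since `.chat` returns the raw LiteLLM `choices` list, the answer text can be pulled from the first choice (LiteLLM's `Choices` and `Message` objects support both attribute and dict-style access):

print(response[0].message.content)  # 'The capital of France is Paris.'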
## Batched inference
In [6]:
inputs = [
    model.append_with_chat_template(message="What is the capital of Italy?", role="user"),
    model.append_with_chat_template(message="What is the capital of Germany?", role="user"),
    model.append_with_chat_template(message="What is the capital of Spain?", role="user"),
]

In [ ]:
response = model.batch_chat(inputs)
print([response[i]["choices"][0]["message"].content for i in range(len(response))])  # extract the content of the responses

Out:
[ModelResponse(id='chatcmpl-Bt9jD8Iy4A9h6OyHUuqQEN8s7qqqq', created=1752485135, model='gpt-4.1-nano-2025-04-14', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='stop', index=0, message=Message(content='The capital of Italy is Rome.', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'refusal': None}, annotations=[]), provider_specific_fields={})], usage=Usage(completion_tokens=7, prompt_tokens=14, total_tokens=21, completion_tokens_details=CompletionTokensDetailsWrapper(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0, text_tokens=None), prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=0, cached_tokens=0, text_tokens=None, image_tokens=None)), service_tier='default'), ModelResponse(id='chatcmpl-Bt9jDAATOtSORl4CtVSBblbsH3oX4', created=1752485135, model='gpt-4.1-nano-2025-04-14', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='stop', index=0, message=Message(content='The capital of Germany is Berlin.', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'refusal': None}, annotations=[]), provider_specific_fields={})], usage=Usage(completion_tokens=7, prompt_tokens=14, total_tokens=21, completion_tokens_details=CompletionTokensDetailsWrapper(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0, text_tokens=None), prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=0, cached_tokens=0, text_tokens=None, image_tokens=None)), service_tier='default'), ModelResponse(id='chatcmpl-Bt9jDPPEsRluyqeDvuXYKHPJhWVYE', created=1752485135, model='gpt-4.1-nano-2025-04-14', object='chat.completion', system_fingerprint='fp_38343a2f8f', choices=[Choices(finish_reason='stop', index=0, message=Message(content='The capital of Spain is Madrid.', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'refusal': None}, annotations=[]), provider_specific_fields={})], usage=Usage(completion_tokens=7, prompt_tokens=14, total_tokens=21, completion_tokens_details=CompletionTokensDetailsWrapper(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0, text_tokens=None), prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=0, cached_tokens=0, text_tokens=None, image_tokens=None)), service_tier='default')]
Notebook metadata: kernel .venv, Python 3.12.3 (nbformat 4.5).
