llm_utils.py
import time
from functools import partial

import openai
from openai import OpenAI
from vllm import LLM, SamplingParams

from prompts import icl_user_prompt, icl_ass_prompt

def llm_init(model_name, tensor_parallel_size=1, max_seq_len_to_capture=8192,
             max_tokens=4000, seed=0, temperature=0, frequency_penalty=0):
    """Build a chat-completion callable for a local vLLM model or an OpenAI model."""
    if "gpt" not in model_name:
        # Local model served through vLLM.
        client = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size,
                     max_seq_len_to_capture=max_seq_len_to_capture)
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens,
                                         frequency_penalty=frequency_penalty)
        llm = partial(client.chat, sampling_params=sampling_params, use_tqdm=False)
    else:
        # OpenAI-hosted model; the client reads OPENAI_API_KEY from the environment.
        client = OpenAI()
        llm = partial(client.chat.completions.create, model=model_name, seed=seed,
                      temperature=temperature, max_tokens=max_tokens)
    return llm

def get_outputs(outputs, model_name):
    """Extract the generated text from a vLLM or OpenAI response object."""
    if "gpt" not in model_name:
        return outputs[0].outputs[0].text
    return outputs.choices[0].message.content

def llm_inf(llm, prompts, mode, model_name):
    """Run one inference round, optionally followed by a chain-of-thought (CoT) turn."""
    res = []
    conversation = []  # initialized here so non-'sys' modes do not hit a NameError
    if 'sys' in mode:
        conversation.append({"role": "system", "content": prompts['sys_query']})
    if 'icl' in mode:
        # One in-context-learning demonstration (user/assistant pair).
        conversation.append({"role": "user", "content": icl_user_prompt})
        conversation.append({"role": "assistant", "content": icl_ass_prompt})
    if 'sys' in mode:
        conversation.append({"role": "user", "content": prompts['user_query']})
    outputs = get_outputs(llm(messages=conversation), model_name)
    res.append(outputs)
    if 'sys_cot' in mode:
        if 'clear' in mode:
            conversation = []  # ask the CoT follow-up in a fresh context
        conversation.append({"role": "assistant", "content": outputs})
        conversation.append({"role": "user", "content": prompts['cot_query']})
        outputs = get_outputs(llm(messages=conversation), model_name)
        res.append(outputs)
    elif "dc" in mode:
        # Double-check mode: re-ask with the CoT prompt only when the first
        # answer is missing or explicitly unavailable.
        first = res[0].lower()
        if ('ans:' not in first
                or "ans: not available" in first
                or "ans: no information available" in first):
            conversation.append({"role": "user", "content": prompts['cot_query']})
            res[0] = get_outputs(llm(messages=conversation), model_name)
        res.append("")
    else:
        res.append("")
    return res

def llm_inf_with_retry(llm, each_qa, llm_mode, model_name, max_retries):
    """Wrap llm_inf with exponential backoff on OpenAI rate-limit errors."""
    retries = 0
    while retries < max_retries:
        try:
            return llm_inf(llm, each_qa, llm_mode, model_name)
        except openai.RateLimitError:
            wait_time = (2 ** retries) * 5  # exponential backoff: 5s, 10s, 20s, ...
            print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
            retries += 1
    raise Exception("Max retries exceeded. Please check your rate limits or try again later.")

def llm_inf_all(llm, each_qa, llm_mode, model_name, max_retries=5):
    """Route OpenAI models through the retry wrapper; call local models directly."""
    if 'gpt' in model_name:
        return llm_inf_with_retry(llm, each_qa, llm_mode, model_name, max_retries)
    return llm_inf(llm, each_qa, llm_mode, model_name)
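

# --- Minimal usage sketch (illustrative only; not part of the original module). ---
# The model name and prompt texts below are assumptions; the prompt-dict keys
# (sys_query / user_query / cot_query) and the mode string follow the functions above.
# For "gpt" models, OPENAI_API_KEY must be set in the environment.
if __name__ == "__main__":
    model_name = "gpt-4o-mini"  # hypothetical choice; any OpenAI "gpt" model name works here
    llm = llm_init(model_name)
    each_qa = {
        "sys_query": "You are a helpful assistant.",
        "user_query": "What is the capital of France?",
        "cot_query": "Explain your reasoning step by step.",
    }
    # "sys_cot" runs the main query and then a chain-of-thought follow-up,
    # so the result is a list of two strings: [answer, cot_answer].
    print(llm_inf_all(llm, each_qa, llm_mode="sys_cot", model_name=model_name))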