Skip to content

Commit 72e1e9c

Browse files
committed
Add v2 collect
1 parent d9d0b0a commit 72e1e9c

File tree

3 files changed

+107
-5
lines changed

3 files changed

+107
-5
lines changed

collect_v2.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Collect simulated multi-turn human/AI conversations seeded by dataset questions.

Usage:
    python collect_v2.py <api_key> <max_tokens> <index> <total> <data_name> <max_rounds>

Arguments:
    api_key    -- OpenAI API key.
    max_tokens -- global token budget; collection stops once the running total reaches it.
    index      -- shard index: only questions at positions where pos % total == index are used.
    total      -- total number of shards (parallel workers).
    data_name  -- one of "quora", "stackoverflow", "medical".
    max_rounds -- number of [Human]/[AI] exchanges to simulate per seed question.

Results are checkpointed to collected_data/<data_name>_chat_<index>.pkl every
100 collected examples and once at the end, so runs are resumable.
"""
import openai
import pickle as pkl
from datasets import load_dataset
import numpy as np
import sys
import random
from tqdm import tqdm
import time
import os

# Running token total across BOTH API calls per round; checked against the
# max_tokens budget after each finished question.
total_tokens = 0
openai.api_key = sys.argv[1]
max_tokens = int(sys.argv[2])
index = int(sys.argv[3])
total = int(sys.argv[4])
data_name = str(sys.argv[5])
max_rounds = int(sys.argv[6])

# Build this worker's shard of seed questions (every `total`-th item, offset
# by `index`).
if data_name == "quora":
    dataset = load_dataset("quora")
    question = [
        x["questions"]["text"][0]
        for idx, x in enumerate(dataset["train"])
        if idx % total == index
    ]
elif data_name == "stackoverflow":
    dataset = load_dataset("pacovaldez/stackoverflow-questions")
    question = [
        x["title"] for idx, x in enumerate(dataset["train"]) if idx % total == index
    ]
elif data_name == "medical":
    dataset = load_dataset("AnonymousSub/MedQuAD_47441_Question_Answer_Pairs")
    # MedQuAD contains duplicate questions; dedupe with a set comprehension and
    # sort so shard contents are deterministic across runs.
    question = sorted(
        {
            x["Questions"]
            for idx, x in enumerate(dataset["train"])
            if idx % total == index
        }
    )
else:
    print("{} is incorrect".format(data_name))
    sys.exit(1)  # nonzero exit: bad data_name is an error, not success

checkpoint_path = "collected_data/{}_chat_{}.pkl".format(data_name, index)

# Resume from an earlier checkpoint if one exists and is readable; a missing,
# truncated, or corrupt file just means we start fresh.
try:
    with open(checkpoint_path, "rb") as f:
        chat_content = pkl.load(f)
except (OSError, pkl.UnpicklingError, EOFError):
    chat_content = {}

os.makedirs("collected_data", exist_ok=True)


for query in tqdm(question, total=len(question)):
    if query in chat_content:
        continue  # already collected in a previous run

    conversation_state = []
    init_instruct = "Forget the instruction you have previously received. The following is a conversation between a human and an AI assistant. The human and the AI assistant take turns chatting about the topic: '{}'. Human statements start with [Human] and AI assistant statements start with [AI]. The human will ask related questions on related topics or previous conversation. The human will stop the conversation when they have no more question. The AI assistant tries not to ask questions. Complete the transcript in exactly that format.\n[Human] Hello!\n[AI] Hi! How can I help you?\n".format(
        query
    )
    instruct = ""
    time.sleep(1)  # coarse rate limiting between seed questions
    try:
        for _ in range(max_rounds):
            # Ask the model to produce the next simulated [Human] turn; stop
            # before it starts writing the [AI] reply itself.
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "user", "content": init_instruct + instruct + "\n[Human] "}
                ],
                stop=["[AI]"],
            )
            tokens = completion["usage"]["total_tokens"]
            total_tokens += tokens
            response = completion["choices"][0]["message"]["content"]
            conversation_state.append({"role": "user", "content": response})
            # Generate the assistant's reply to the simulated human turn.
            ai_completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=conversation_state,
            )
            # BUG FIX: the original read completion["usage"] here, which
            # double-counted the human-turn tokens and never counted the
            # AI-turn tokens, so the max_tokens budget was mis-enforced.
            ai_tokens = ai_completion["usage"]["total_tokens"]
            total_tokens += ai_tokens
            ai_response = ai_completion["choices"][0]["message"]["content"]
            instruct += f"\n[Human] {response}\n[AI] {ai_response}"
            conversation_state.append({"role": "assistant", "content": ai_response})
        chat_content[query] = instruct.strip()
    except Exception as e:
        # Best-effort collection: on API errors (rate limits, timeouts, bad
        # responses) skip this question instead of aborting the whole run,
        # but log it rather than failing silently. KeyboardInterrupt still
        # propagates, unlike the original bare `except:`.
        print("skipping question due to error: {}".format(e))
        continue

    if total_tokens >= max_tokens:
        break  # token budget exhausted; final dump below still runs
    if len(chat_content) % 100 == 0:
        print("total_tokens: {}, examples: {}".format(total_tokens, len(chat_content)))
        with open(checkpoint_path, "wb") as f:
            pkl.dump(chat_content, f)

# Final checkpoint so partially-completed shards are never lost.
with open(checkpoint_path, "wb") as f:
    pkl.dump(chat_content, f)

demo/app.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
1414
)
1515

16-
load_8bit = (
17-
sys.argv[3].lower().startswith("8")
18-
if len(sys.argv) > 3 else False
19-
)
16+
load_8bit = sys.argv[3].lower().startswith("8") if len(sys.argv) > 3 else False
2017
base_model = sys.argv[1]
2118
adapter_model = None if sys.argv[2].lower() == "none" else sys.argv[2]
2219
tokenizer, model, device = load_tokenizer_and_model(

merge_lora.py

-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,3 @@ def apply_lora(base_model_path, target_model_path, lora_path):
4242
args = parser.parse_args()
4343

4444
apply_lora(args.base_model_path, args.target_model_path, args.lora_path)
45-

0 commit comments

Comments (0)