forked from brettin/ARC
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathARGO.py
More file actions
53 lines (45 loc) · 1.44 KB
/
ARGO.py
File metadata and controls
53 lines (45 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#
# A wrapper class for the Argonne Argo LLM service
#
import os
import requests
import json
MODEL_GPT35 = "gpt35"
MODEL_GPT4 = "gpt4"


class ArgoWrapper:
    """Minimal client for the Argonne Argo chat endpoint.

    The constructor captures the request settings; each call to
    :meth:`invoke` POSTs one prompt with those settings to ``self.url`` and
    returns the JSON-decoded response body.
    """

    default_url = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/chat/"

    def __init__(self,
                 url=None,
                 model=MODEL_GPT35,
                 system="",
                 temperature=0.8,
                 top_p=0.7,
                 user=os.getenv("USER")) -> None:
        """Store the settings sent with every :meth:`invoke` request.

        Args:
            url: Endpoint to POST to; ``None`` selects ``default_url``.
            model: Model identifier placed in the payload
                (e.g. ``MODEL_GPT35`` or ``MODEL_GPT4``).
            system: System prompt placed in every payload.
            temperature: Value sent as ``"temperature"`` in the payload.
            top_p: Value sent as ``"top_p"`` in the payload.
            user: Value sent as ``"user"`` in the payload.
                NOTE: the default is read from ``$USER`` once, when this
                class is defined, not each time an instance is created.
        """
        self.url = url if url is not None else ArgoWrapper.default_url
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.user = user
        # Keep the caller's system prompt so invoke() actually sends it.
        self.system = system

    def invoke(self, prompt: str, timeout=None):
        """Send *prompt* to the Argo service and return the decoded reply.

        Args:
            prompt: User prompt; sent as a one-element list under ``"prompt"``.
            timeout: Forwarded unchanged to ``requests.post``. ``None`` (the
                default) means no timeout, which matches the original behaviour.

        Returns:
            The response body parsed with ``json.loads``. Its structure is
            defined by the Argo service and is not validated here.

        Raises:
            Exception: If the service responds with any status other than 200.
        """
        headers = {
            "Content-Type": "application/json"
        }
        data = {
            "user": self.user,
            "model": self.model,
            "system": self.system,
            "prompt": [prompt],
            "stop": [],
            "temperature": self.temperature,
            "top_p": self.top_p,
        }
        response = requests.post(self.url, headers=headers,
                                 data=json.dumps(data), timeout=timeout)
        if response.status_code != 200:
            # The status code alone rarely explains a rejection; include the
            # server's body so the failure can be diagnosed.
            raise Exception(
                f"Request failed with status code: {response.status_code}\n"
                f"{response.text}"
            )
        return json.loads(response.text)