-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsent_app.py
More file actions
74 lines (53 loc) · 1.92 KB
/
sent_app.py
File metadata and controls
74 lines (53 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from fastapi import FastAPI
import joblib
from pydantic import BaseModel
from typing import Union
from fastapi.responses import RedirectResponse
# Placeholders for the three serialized artifacts. Each name is rebound by the
# joblib.load calls a few lines below, while this module is still importing.
model = None
vectorizer = None
le = None
# FastAPI application instance; the route decorators in this file register
# their handlers on it.
app = FastAPI(title="MoodLens-API")
# Load the artifacts once, at import time. The bare file names are resolved
# relative to the current working directory, and any exception raised by
# joblib.load propagates out of the import.
# NOTE(review): going by the file names these are presumably a fitted
# classifier, a text vectorizer and a label encoder -- confirm against the
# training code. joblib files are pickle-based, so only load trusted artifacts.
model = joblib.load("model.joblib")
vectorizer = joblib.load("vectorizer.joblib")
le = joblib.load("label_encoder.joblib")
@app.get('/')
def root():
    """Send requests for the bare root URL on to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.get('/health')
def health():
    """Report service status and whether each artifact global is set.

    Returns a dict with ``"status"`` fixed to ``"ok"`` plus one entry per
    artifact, whose value is ``"loaded"`` when the matching module-level
    global is not ``None`` and ``"not loaded"`` otherwise.
    """
    # (response key, module-level object) pairs, in response order.
    artifacts = (
        ("model", model),
        ("vectorizer", vectorizer),
        ("label_encoder", le),
    )
    report = {"status": "ok"}
    for key, artifact in artifacts:
        report[key] = "not loaded" if artifact is None else "loaded"
    return report
class TextInput(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Either one text or a list of texts to classify in a single call.
    text: Union[str, list[str]]
def _classify_texts(texts):
    """Return one ``{"Sentiment", "Confidence"}`` dict per entry of *texts*.

    Runs the module-level ``vectorizer``, ``model`` and ``le`` over the whole
    list in a single pass, so the single-text and batch paths of the endpoint
    share exactly the same scoring logic.

    Args:
        texts: Non-empty list of strings to classify.

    Returns:
        A list, in input order, of dicts holding the decoded label under
        ``"Sentiment"`` and a built-in ``float`` under ``"Confidence"``.
    """
    features = vectorizer.transform(texts)
    encoded = model.predict(features)
    labels = le.inverse_transform(encoded)
    # Confidence is the largest class probability of each row. Indexing the
    # row by the predicted label instead would only be right if every label
    # happened to equal its column position in the predict_proba output.
    confidences = model.predict_proba(features).max(axis=1)
    # Cast to built-in float so the value serializes to JSON whatever dtype
    # the probability array uses (presumably NumPy scalars -- TODO confirm).
    return [
        {"Sentiment": label, "Confidence": float(confidence)}
        for label, confidence in zip(labels, confidences)
    ]


@app.post('/predict')
def predict_sentiment(data: TextInput):
    """Classify the sentiment of one text or of a list of texts.

    Args:
        data: Request body; ``data.text`` is a string or a list of strings.

    Returns:
        For a single string, ``{"Sentiment": label, "Confidence": float}``.
        For a list, ``{"results": [...]}`` with one
        ``{"Text", "Sentiment", "Confidence"}`` dict per input, in input
        order; an empty list gives ``{"results": []}``.
        For anything else, ``{"error": "Input must be a string or list of
        strings"}``.
    """
    text = data.text

    # ------- Single text input ----------
    if isinstance(text, str):
        return _classify_texts([text])[0]

    # Defensive fallback: the TextInput annotation should already restrict
    # the payload to these shapes, so this is not expected to trigger.
    if not all(isinstance(item, str) for item in text):
        return {"error": "Input must be a string or list of strings"}

    # ------ Batch input --------------
    # Nothing to classify: answer directly rather than handing zero samples
    # to the vectorizer and model.
    if not text:
        return {"results": []}

    return {
        "results": [
            {"Text": txt, **scored}
            for txt, scored in zip(text, _classify_texts(text))
        ]
    }