-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCLIP_REST_API.py
More file actions
79 lines (66 loc) · 2.63 KB
/
Copy pathCLIP_REST_API.py
File metadata and controls
79 lines (66 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Optional
from fastapi import FastAPI, Form, File, UploadFile, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
import io
from urllib.parse import urlparse
import requests
# Hugging Face checkpoint used for both the processor and the model below.
model_name = "openai/clip-vit-base-patch32"

# Loaded once at import time and shared by every request handler; the tuple
# is evaluated left to right, so the processor loads before the model.
processor, clip_model = (
    CLIPProcessor.from_pretrained(model_name),
    CLIPModel.from_pretrained(model_name),
)

app = FastAPI()
class PredictionResult(BaseModel):
    """Response body of the ``/classify/`` endpoint."""

    # Label with the highest probability among the requested classes.
    prediction: str
    # Mapping of each requested class label to its softmax probability.
    label_probabilities: dict
    # Echo of selected request fields (the endpoint fills "location" and "classifier").
    original_payload: dict
async def get_image_from_upload(image: UploadFile):
    """Decode an uploaded file into a PIL image resized to 224x224.

    Args:
        image: Multipart file upload received by the endpoint.

    Returns:
        A ``PIL.Image.Image`` resized to 224x224.

    Raises:
        HTTPException: 400 if the bytes are not a readable image, or if the
            image is large enough to trip Pillow's decompression-bomb guard.
    """
    try:
        data = await image.read()
        # NOTE(review): forcing 224x224 here squashes non-square images before
        # the CLIP processor applies its own resize/center-crop; confirm that
        # distortion is intended.
        return Image.open(io.BytesIO(data)).resize((224, 224))
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError subclasses OSError, but DecompressionBombError
        # does not; without it oversized images surfaced as a 500, not a 400.
        raise HTTPException(status_code=400, detail="Invalid image file.") from err
async def get_image_from_url(url: str, timeout: float = 10.0):
    """Download an image over HTTP(S) and return it resized to 224x224.

    Args:
        url: Absolute ``http``/``https`` URL supplied by the client.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            slow or unresponsive host cannot hang the request forever.

    Returns:
        A ``PIL.Image.Image`` resized to 224x224.

    Raises:
        HTTPException: 400 if the scheme is not http/https, the download
            fails or times out, or the payload is not a readable image.
    """
    # Local imports keep this block self-contained.
    import asyncio
    import functools

    if urlparse(url).scheme not in ('http', 'https'):
        raise HTTPException(status_code=400, detail="Invalid image URL.")

    # NOTE(review): the URL is client-controlled and fetched server-side, so it
    # can reach internal hosts (SSRF); consider a host allow-list.
    fetch = functools.partial(requests.get, url, timeout=timeout)
    try:
        # requests is blocking; run it in the default executor so the event
        # loop keeps serving other requests while the download is in flight.
        response = await asyncio.get_running_loop().run_in_executor(None, fetch)
        response.raise_for_status()
    except requests.RequestException as err:
        raise HTTPException(status_code=400, detail="Unable to download image from the provided URL.") from err

    try:
        return Image.open(io.BytesIO(response.content)).resize((224, 224))
    except (OSError, Image.DecompressionBombError) as err:
        # DecompressionBombError is not an OSError and previously escaped as a 500.
        raise HTTPException(status_code=400, detail="Unable to download image from the provided URL.") from err
def _parse_classes(raw):
    """Split a comma-separated label string into unique, non-empty labels.

    Whitespace around each label is stripped and first-seen order is kept,
    e.g. ``"cat,dog , cat"`` -> ``["cat", "dog"]``. ``None`` or an empty
    string yields ``[]``.
    """
    if not raw:
        return []
    stripped = (part.strip() for part in raw.split(","))
    # dict.fromkeys de-duplicates while preserving order; duplicate labels
    # would otherwise collapse in the label->probability dict.
    return list(dict.fromkeys(label for label in stripped if label))


@app.post("/classify/")
async def classify_image(
    location: Optional[str] = Form(None),
    classifier: Optional[str] = Form("clip"),
    classes: Optional[str] = Form("nudity, wad, offensive, face-attributes, gore"),
    image: UploadFile = File(None),
):
    """Zero-shot classify an image against a comma-separated list of labels.

    The image comes from the uploaded ``image`` if present, otherwise it is
    downloaded from ``location``.

    Args:
        location: http(s) URL of the image, used when no file is uploaded.
        classifier: Accepted but not used; the response always reports "clip".
        classes: Comma-separated candidate labels; surrounding whitespace is
            ignored and duplicates are collapsed.
        image: Uploaded image file.

    Returns:
        ``{"result": PredictionResult}`` holding the most probable label,
        the per-label softmax probabilities, and the request's location.

    Raises:
        HTTPException: 400 if no image source or no usable label is given,
            or if the image cannot be read or downloaded.
    """
    # Local imports: the module binds only ``torch.nn.functional`` (as F).
    import logging
    import torch

    if not location and not image:
        raise HTTPException(status_code=400, detail="Please provide either 'location' or 'image'.")

    labels = _parse_classes(classes)
    if not labels:
        raise HTTPException(status_code=400, detail="Please provide at least one class label.")

    if image:
        pil_image = await get_image_from_upload(image)
    else:
        pil_image = await get_image_from_url(location)

    # truncation=True keeps over-long labels within CLIP's text length limit
    # instead of failing inside the model.
    inputs = processor(text=labels, images=pil_image, return_tensors="pt", padding=True, truncation=True)
    # Inference only: skip autograd bookkeeping to save memory and time.
    # NOTE(review): this forward pass is CPU/GPU-bound and still runs on the
    # event loop; consider offloading it if latency under load matters.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # A single image was passed, so row 0 holds its scores over all labels.
    probs = F.softmax(outputs.logits_per_image, dim=-1)[0].tolist()

    label_probs = dict(zip(labels, probs))
    logging.getLogger(__name__).debug("CLIP label probabilities: %s", label_probs)
    prediction = max(label_probs, key=label_probs.get)

    result = PredictionResult(
        prediction=prediction,
        label_probabilities=label_probs,
        original_payload={
            "location": location,
            "classifier": "clip",
        },
    )
    return {"result": result}