# cut_avatars.py
import warnings
import torch
from torchvision import transforms
from utils import *
from PIL import Image, ImageDraw, ImageFont
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model checkpoint is loaded by the caller and passed to detect() via the `model` argument
# Transforms
resize = transforms.Resize((300, 300))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def detect(original_image, min_score, max_overlap, top_k, suppress=None, model=None, num=None):
    """Detect avatars in `original_image` with the given SSD model, save a crop of every
    detection under static/detection/crop/, and return the image annotated with boxes and labels."""
    # Transform
    image = normalize(to_tensor(resize(original_image)))

    # Move to default device
    image = image.to(device)

    # Forward prop.
    predicted_locs, predicted_scores = model(image.unsqueeze(0))

    # Detect objects in SSD output
    det_boxes, det_labels, det_scores = model.detect_objects(predicted_locs, predicted_scores,
                                                             min_score=min_score,
                                                             max_overlap=max_overlap, top_k=top_k)

    # Move detections to the CPU
    det_boxes = det_boxes[0].to('cpu')

    # Transform to original image dimensions
    original_dims = torch.FloatTensor(
        [original_image.width, original_image.height,
         original_image.width, original_image.height]).unsqueeze(0)
    det_boxes = det_boxes * original_dims

    # Decode class integer labels
    det_labels = [rev_label_map[l] for l in det_labels[0].to('cpu').tolist()]

    # If no objects were found, the detected labels will be set to ['0.'], i.e. ['background'],
    # in SSD300.detect_objects() in model.py
    if det_labels == ['background']:
        return original_image  # Just return the original image

    # Annotate
    annotated_image = original_image
    draw = ImageDraw.Draw(annotated_image)
    font = ImageFont.truetype("./arial.ttf", 15)

    # Crop and save every detected avatar
    for i in range(det_boxes.size(0)):
        box_location = det_boxes[i].tolist()
        crop_path = "static/detection/crop/" + str(num) + "-crop_" + str(i).zfill(3) + ".jpg"
        nickname_path = "static/detection/nickname/" + str(num) + "-crop_" + str(i).zfill(3) + ".jpg"

        # Crop the avatar image (5-pixel margin around the detected box)
        cropped = annotated_image.crop([box_location[0] - 5, box_location[1] - 5,
                                        box_location[2] + 5, box_location[3] + 5])
        cropped.save(crop_path, "JPEG")

        # Crop the avatar nickname image
        # nickname_cropped = annotated_image.crop([box_location[0] - 20, box_location[1] - 100,
        #                                          box_location[2] + 20, box_location[3] - 150])
        # nickname_cropped.save(nickname_path, "JPEG")

    # Draw boxes and labels, suppressing specific classes if needed
    for i in range(det_boxes.size(0)):
        if suppress is not None:
            if det_labels[i] in suppress:
                continue

        # Boxes
        box_location = det_boxes[i].tolist()
        draw.rectangle(xy=box_location, outline=label_color_map[det_labels[i]])
        draw.rectangle(xy=[l + 1. for l in box_location], outline=label_color_map[det_labels[i]])

        # Text
        text_size = font.getsize(det_labels[i].upper())
        text_location = [box_location[0] + 2., box_location[1] - text_size[1]]
        textbox_location = [box_location[0], box_location[1] - text_size[1],
                            box_location[0] + text_size[0] + 4., box_location[1]]
        draw.rectangle(xy=textbox_location, fill=label_color_map[det_labels[i]])
        draw.text(xy=text_location, text=det_labels[i].upper(), fill='white', font=font)

    del draw

    return annotated_image
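

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumptions: the
# checkpoint path and image paths below are hypothetical, and the checkpoint
# is assumed to store the trained SSD300 under the 'model' key, as in the
# standard SSD300 tutorial code that utils.py / model.py appear to come from.
# The score/overlap/top_k values are example settings, not tuned parameters.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     # Load the trained SSD300 checkpoint and switch to evaluation mode
#     checkpoint = torch.load('./checkpoint_ssd300.pth.tar', map_location=device)
#     model = checkpoint['model'].to(device)
#     model.eval()
#
#     # Run detection on a single image; avatar crops are written to
#     # static/detection/crop/ by detect() itself
#     original_image = Image.open('static/detection/input/000.jpg').convert('RGB')
#     with torch.no_grad():
#         annotated = detect(original_image, min_score=0.2, max_overlap=0.5,
#                            top_k=200, model=model, num=0)
#     annotated.save('static/detection/result/000.jpg', 'JPEG')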