-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathyolo_extract_objects.py
executable file
·80 lines (60 loc) · 2.74 KB
/
yolo_extract_objects.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
# Copyright 2021 Sergei Solodovnikov
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from PIL import Image
def extract_object_by_id(class_id: int, dataset: Path, target: Path):
    """Crop every object of one YOLO class out of a dataset's images.

    For each ``*.jpg`` directly inside ``dataset`` the sibling ``<stem>.txt``
    label file is read. Every label line is expected to hold five
    whitespace-separated fields: ``class x_center y_center width height``,
    where the four box values are multiplied by the image width/height to
    obtain pixel coordinates. Each box whose class equals ``class_id`` is
    cropped and saved as ``<target>/<class_id>/<stem>_<line index>.png``.

    Args:
        class_id: YOLO class id whose objects should be extracted.
        dataset: Directory containing the ``.jpg`` images and ``.txt`` labels.
        target: Output root; a sub-directory named after ``class_id`` is
            created inside it (including missing parents).

    Raises:
        ValueError: If a non-blank label line does not have exactly five
            fields, or a field that has to be converted is not numeric.
    """
    target = target / str(class_id)
    target.mkdir(parents=True, exist_ok=True)

    for yolo_img_path in dataset.glob('*.jpg'):
        yolo_txt_path = dataset / f"{yolo_img_path.stem}.txt"
        if not yolo_txt_path.exists():
            print(f"{yolo_img_path.name} doesn't have YOLO txt file")
            continue

        print(f"Process: {yolo_img_path.name}")

        # Parse the labels before touching the image so that images without
        # any object of the requested class are never opened.
        matches = []
        with yolo_txt_path.open(mode='r') as yolo_txt:
            for i, yolo_line in enumerate(yolo_txt):
                # split() without arguments tolerates the trailing newline,
                # tabs and repeated spaces; blank lines yield no fields.
                fields = yolo_line.split()
                if not fields:
                    continue
                if len(fields) != 5:
                    raise ValueError(
                        f"{yolo_txt_path}:{i + 1}: expected 5 fields "
                        f"(class xc yc w h), got {len(fields)}")

                obj_class_id, obj_xc, obj_yc, obj_w, obj_h = fields
                if int(obj_class_id) != class_id:
                    continue

                # Keep the label line index: it is part of the output name.
                matches.append((i,
                                float(obj_xc),
                                float(obj_yc),
                                float(obj_w),
                                float(obj_h)))

        if not matches:
            continue

        with Image.open(yolo_img_path) as img:
            img_w = img.width
            img_h = img.height

            for i, rel_xc, rel_yc, rel_w, rel_h in matches:
                # Scale the box centre and half-extents to pixels.
                obj_xc = img_w * rel_xc
                obj_yc = img_h * rel_yc
                obj_hw = img_w * rel_w * 0.5
                obj_hh = img_h * rel_h * 0.5

                box = (obj_xc - obj_hw,
                       obj_yc - obj_hh,
                       obj_xc + obj_hw,
                       obj_yc + obj_hh)
                img.crop(box).save(
                    target.joinpath(f"{yolo_img_path.stem}_{i}.png"))
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the extraction.
    parser = argparse.ArgumentParser(
        description='Helper to crop object images from YOLO dataset')
    # type=int lets argparse reject a non-integer id with a usage error
    # instead of an uncaught ValueError traceback.
    parser.add_argument('--class_id',
                        required=True,
                        type=int,
                        help='YOLO class id to crop')
    # argparse also applies `type` to string defaults, so both options
    # always arrive as Path objects.
    parser.add_argument('--dataset',
                        default='./yolo/dataset/',
                        type=Path,
                        help='Path to the YOLO dataset')
    parser.add_argument('--target',
                        default='./yolo/objects/',
                        type=Path,
                        help='Path where cropped objects should be stored')
    args = parser.parse_args()

    extract_object_by_id(args.class_id, args.dataset, args.target)