-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
finalize inference on img path or folder
- Loading branch information
Luca
committed
Jan 5, 2021
1 parent
f220178
commit ae40ddb
Showing
2 changed files
with
65 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,26 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Thu Jul 23 19:05:08 2020 | ||
@author: Nuvilabs-Luca | ||
@author: Nuvilabs - Luca Medeiros, [email protected] | ||
""" | ||
|
||
from utils.mmdet.apis import init_detector, inference_detector, show_result_pyplot | ||
import os | ||
from utils.mmdet.apis import init_detector, inference_detector | ||
import numpy as np | ||
import cv2 | ||
|
||
|
||
class Nuvi_RecycleNet(): | ||
def __init__(self, | ||
config_file='./model/model_config.py', | ||
checkpoint_file='./model/model_checkpoint.pth', | ||
output_dir='./output/', | ||
threshold=0.5, | ||
config_file, | ||
checkpoint_file, | ||
threshold, | ||
device='cuda:0', | ||
tta=True): | ||
|
||
self.config_file = config_file | ||
self.checkpoint_file = checkpoint_file | ||
self.output_dir = output_dir | ||
self.threshold = threshold # Only boxes with the score larger than this will be detected | ||
self.tta = tta # Perform TTA on not detected images | ||
|
||
|
@@ -59,7 +57,7 @@ def predict(self, img_path): | |
imageArray = cv2.imread(img_path) | ||
# Run inference using a model on a single picture -> img can be either path or array | ||
result = inference_detector(self.model, imageArray) | ||
res_idxs = [i for i, k in enumerate(result) if k.size != 0 and (k[:,4] > self.threshold).any()] | ||
res_idxs = [[i, k[0]] for i, k in enumerate(result[0]) if k.size != 0 and (k[:,4] > self.threshold).any()] | ||
if self.tta and not res_idxs: | ||
# Determine which augmentations to do when there is no detection. In order | ||
aug_type = ['LR', 'UDR', 'RT', 'Break'] | ||
|
@@ -73,15 +71,21 @@ def predict(self, img_path): | |
print('TTA type: ', augment_type) | ||
img_transformed = self.augmentIMG(imageArray, augment_type) | ||
result = inference_detector(self.model, img_transformed) | ||
res_idxs = [i for i, k in enumerate(result) if k.size != 0 and (k[:,4] > self.threshold).any()] | ||
res_idxs = [[i, k[0]] for i, k in enumerate(result[0]) if k.size != 0 and (k[:,4] > self.threshold).any()] | ||
|
||
json_result = self.make_json(res_idxs) | ||
|
||
return json_result | ||
|
||
def make_json(self, results): | ||
pass | ||
json_dict = {'Annotations': []} | ||
for result in results: | ||
label_idx = result[0] | ||
bbox = result[1][:4].tolist() | ||
score = result[1][-1] | ||
label_name = self.classes[label_idx] | ||
|
||
dict_result = {'Label': label_name, 'Bbox': bbox, 'Confidence': score} | ||
json_dict['Annotations'].append(dict_result) | ||
|
||
recycler = Nuvi_RecycleNet() | ||
json_result = recycler.predict('') | ||
return json_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author: Nuvilabs - Luca Medeiros, [email protected] | ||
""" | ||
|
||
import argparse | ||
import os.path as osp | ||
import glob | ||
from drs_infer import Nuvi_RecycleNet | ||
|
||
|
||
parser = argparse.ArgumentParser( | ||
description='Nuvilabs RecycleNet') | ||
parser.add_argument('--config_file', default='./model/model_config.py', type=str, | ||
help='Model configuration file path.') | ||
parser.add_argument('--checkpoint_file', default='./model/model_checkpoint.pth', type=str, | ||
help='Model checkpoint file path.') | ||
parser.add_argument('--img_path', default='./sample_img.jpg', type=str, | ||
help='Path of image or images.') | ||
parser.add_argument('--threshold', default=0.5, type=float, | ||
help='Only boxes with the score larger than this will be detected.') | ||
parser.add_argument('--use_tta', dest='use_tta', action='store_true', | ||
help='Either use TTA to help detect images.') | ||
parser.set_defaults(use_tta=False) | ||
|
||
if __name__ == '__main__': | ||
|
||
arg = parser.parse_args() | ||
recyclernet = Nuvi_RecycleNet(arg.config_file, | ||
arg.checkpoint_file, | ||
arg.threshold, | ||
tta=arg.use_tta) | ||
image_list = [] | ||
if osp.isdir(arg.img_path): | ||
arg.img_path = osp.join(arg.img_path, '') | ||
img_ext = ('png', 'jpg', 'JPG', 'jpeg') | ||
for ext in img_ext: | ||
image_list.extend(glob.glob(arg.img_path + '*.' + ext)) | ||
|
||
elif osp.isfile(arg.img_path): | ||
image_list.append(arg.img_path) | ||
|
||
for path in image_list: | ||
json_result = recyclernet.predict(path) | ||
print(json_result) |