Skip to content

Commit

Permalink
finalize inference on img path or folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Luca committed Jan 5, 2021
1 parent f220178 commit ae40ddb
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 14 deletions.
32 changes: 18 additions & 14 deletions drs_infer.py → inference.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,26 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 23 19:05:08 2020
@author: Nuvilabs-Luca
@author: Nuvilabs - Luca Medeiros, [email protected]
"""

from utils.mmdet.apis import init_detector, inference_detector, show_result_pyplot
import os
from utils.mmdet.apis import init_detector, inference_detector
import numpy as np
import cv2


class Nuvi_RecycleNet():
def __init__(self,
config_file='./model/model_config.py',
checkpoint_file='./model/model_checkpoint.pth',
output_dir='./output/',
threshold=0.5,
config_file,
checkpoint_file,
threshold,
device='cuda:0',
tta=True):

self.config_file = config_file
self.checkpoint_file = checkpoint_file
self.output_dir = output_dir
self.threshold = threshold # Only boxes with the score larger than this will be detected
self.tta = tta # Perform TTA on not detected images

Expand Down Expand Up @@ -59,7 +57,7 @@ def predict(self, img_path):
imageArray = cv2.imread(img_path)
# Run inference using a model on a single picture -> img can be either path or array
result = inference_detector(self.model, imageArray)
res_idxs = [i for i, k in enumerate(result) if k.size != 0 and (k[:,4] > self.threshold).any()]
res_idxs = [[i, k[0]] for i, k in enumerate(result[0]) if k.size != 0 and (k[:,4] > self.threshold).any()]
if self.tta and not res_idxs:
# Determine which augmentations to do when there is no detection. In order
aug_type = ['LR', 'UDR', 'RT', 'Break']
Expand All @@ -73,15 +71,21 @@ def predict(self, img_path):
print('TTA type: ', augment_type)
img_transformed = self.augmentIMG(imageArray, augment_type)
result = inference_detector(self.model, img_transformed)
res_idxs = [i for i, k in enumerate(result) if k.size != 0 and (k[:,4] > self.threshold).any()]
res_idxs = [[i, k[0]] for i, k in enumerate(result[0]) if k.size != 0 and (k[:,4] > self.threshold).any()]

json_result = self.make_json(res_idxs)

return json_result

def make_json(self, results):
pass
json_dict = {'Annotations': []}
for result in results:
label_idx = result[0]
bbox = result[1][:4].tolist()
score = result[1][-1]
label_name = self.classes[label_idx]

dict_result = {'Label': label_name, 'Bbox': bbox, 'Confidence': score}
json_dict['Annotations'].append(dict_result)

recycler = Nuvi_RecycleNet()
json_result = recycler.predict('')
return json_dict
47 changes: 47 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Nuvilabs - Luca Medeiros, [email protected]
"""

import argparse
import os.path as osp
import glob
from drs_infer import Nuvi_RecycleNet


parser = argparse.ArgumentParser(
description='Nuvilabs RecycleNet')
parser.add_argument('--config_file', default='./model/model_config.py', type=str,
help='Model configuration file path.')
parser.add_argument('--checkpoint_file', default='./model/model_checkpoint.pth', type=str,
help='Model checkpoint file path.')
parser.add_argument('--img_path', default='./sample_img.jpg', type=str,
help='Path of image or images.')
parser.add_argument('--threshold', default=0.5, type=float,
help='Only boxes with the score larger than this will be detected.')
parser.add_argument('--use_tta', dest='use_tta', action='store_true',
help='Either use TTA to help detect images.')
parser.set_defaults(use_tta=False)

if __name__ == '__main__':

arg = parser.parse_args()
recyclernet = Nuvi_RecycleNet(arg.config_file,
arg.checkpoint_file,
arg.threshold,
tta=arg.use_tta)
image_list = []
if osp.isdir(arg.img_path):
arg.img_path = osp.join(arg.img_path, '')
img_ext = ('png', 'jpg', 'JPG', 'jpeg')
for ext in img_ext:
image_list.extend(glob.glob(arg.img_path + '*.' + ext))

elif osp.isfile(arg.img_path):
image_list.append(arg.img_path)

for path in image_list:
json_result = recyclernet.predict(path)
print(json_result)

0 comments on commit ae40ddb

Please sign in to comment.