Skip to content

Commit

Permalink
add automatic model search for opencv launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
APrigarina committed Apr 5, 2022
1 parent 10f54a1 commit 6602fb9
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def provide_precision_and_layout(launchers, input_precisions, input_layouts):


def provide_model_type(launcher, arguments):
if 'model_type' in arguments:
if 'model_type' in arguments and arguments.model_type is not None:
launcher['_model_type'] = arguments.model_type
if launcher['framework'] in ['dlsdk', 'openvino', 'g-api'] and 'model_is_blob' in arguments:
launcher['_model_is_blob'] = arguments.model_is_blob
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import re
from collections import OrderedDict
from pathlib import Path
import numpy as np
import cv2

from ..config import PathField, StringField, ConfigError, ListInputsField
from ..logging import print_info
from .launcher import Launcher, LauncherConfigValidator
from ..utils import get_or_parse_value
from ..utils import get_or_parse_value, get_path

DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
Expand Down Expand Up @@ -63,8 +64,11 @@ class OpenCVLauncher(Launcher):
def parameters(cls):
parameters = super().parameters()
parameters.update({
'model': PathField(description="Path to model file."),
'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False),
'model': PathField(description="Path to model file.", file_or_directory=True),
'weights': PathField(
description="Path to weights file.", optional=True,
check_exists=False, file_or_directory=True
),
'device': StringField(
regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
Expand Down Expand Up @@ -100,8 +104,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
raise ConfigError('{} is not supported device'.format(selected_device))

if not self._delayed_model_loading:
self.model = self.get_value_from_config('model')
self.weights = self.get_value_from_config('weights')
self.model, self.weights = self.automatic_model_search(self._model_name,
self.get_value_from_config('model'), self.get_value_from_config('weights'),
self.get_value_from_config('_model_type')
)
self.network = self.create_network(self.model, self.weights)
self._inputs_shapes = self.get_inputs_from_config(self.config)
self.network.setInputsNames(list(self._inputs_shapes.keys()))
Expand Down Expand Up @@ -130,6 +136,71 @@ def batch(self):
def output_blob(self):
return next(iter(self.output_names))

def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None):
model_type_ext = {
'xml': 'xml',
'blob': 'blob',
'onnx': 'onnx',
'caffe': 'prototxt',
'tf': 'pb'
}
def get_model_by_suffix(model_name, model_dir, suffix):
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
if not model_list:
model_list = list(Path(model_dir).glob('*.{}'.format(suffix)))
if not model_list:
model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix)))
return model_list

def get_model():
model = Path(model_cfg)
if not model.is_dir():
accepted_suffixes = list(model_type_ext.values())
if model.suffix[1:] not in accepted_suffixes:
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
print_info('Found model {}'.format(model))
return model, model.suffix == '.blob'
model_list = []
if model_type is not None:
model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type])
else:
for ext in model_type_ext.values():
model_list = get_model_by_suffix(model_name, model, ext)
if model_list:
break
if not model_list:
raise ConfigError('suitable model is not found')
if len(model_list) != 1:
raise ConfigError('More than one model matched, please specify explicitly')
model = model_list[0]
print_info('Found model {}'.format(model))
return model, model.suffix == '.blob'

model, is_blob = get_model()
if is_blob:
return model, None
weights = weights_cfg
if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx':
weights_dir = weights or model.parent
weights_list = []
if model.suffix == '.xml':
weights = Path(weights_dir) / model.name.replace('xml', 'bin')
else:
if model.suffix == '.prototxt':
weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel')))
if not weights_list:
raise ConfigError('Suitable weights is not detected')
if len(weights_list) != 1:
raise ConfigError('Several suitable weights found, please specify required explicitly')
weights = weights_list[0]
if weights is not None:
accepted_weights_suffixes = ['.bin', '.caffemodel']
if weights.suffix not in accepted_weights_suffixes:
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
print_info('Found weights {}'.format(get_path(weights)))

return model, weights

def predict(self, inputs, metadata=None, **kwargs):
"""
Args:
Expand Down

0 comments on commit 6602fb9

Please sign in to comment.