diff --git a/.gitignore b/.gitignore index 63754ecf1..ca0c9bc5c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,15 @@ resources/icons/.DS_Store resources.py labelImg.egg-info* +zachary_notes.txt *.pyc .*.swp build/ dist/ +uml/ +venv/ tags cscope* diff --git a/dev_log.txt b/dev_log.txt new file mode 100644 index 000000000..5fd0071a8 --- /dev/null +++ b/dev_log.txt @@ -0,0 +1,113 @@ +Todo: + refactor labelImg repository, mainly focus on io operations + refactoring classes: LabelFile, LabelFileFormat, LabelIO + +2022/10/30: + In labelImg.py: + add comments to note some code smells for refactoring + + In labelImg.py, def save_labels(): + code smell: + the function originally uses multiple if else statements to check the label format, + and then calls the corresponding method in LabelFile object to write the file + + solving: + making an IOMAP (dictionary) in LabelFile.py + whenever a label needs to be saved, map the label format to the corresponding reader object + so we have to align the init method of all writers + also the save() method in each writer + for future extensions, writers must follow the format + + result: + unit test passed + + further work: + add interface for writer that restricts the save() method and init() method + (done)link labelformat with labelfile to remove dictionary mapping + +2022/10/31: + In labelImg.py, def set_format(): + code smell: + using if else statements to set each format + + solving: + making LabelFileFormat an abstract class, containing the corresponding writer/reader(io), icon, text + then implement the LabelFileFormat object with current available formats + + whenever the labelformat is specified, then the corresponding icon, text, io will be noticed, + because the abstract class forces the derived class to have the corresponding attributes, + the only if-else statement remaining is the toggle button + + result: + unit test passed + + further work: + solve if-else in def change_format(), probably using a list, + or making a patch to search
all derived classes of LabelFileFormat + +2022/11/6: + In labelFile.py + code smell: + Extracting an abstract class for labelFileFormat does not make sense. + Ideally, an abstract class provides a set of abstract methods + and concrete classes must implement them. + But for the fileformat design, we only need every format to own the corresponding + io as an attribute, not a method + + solving: + thus, I redesign the labelFileFormat class as a concrete class + Each format is an instance of this class, the init method forces each instance to have + the required attributes assigned + + result: + unit test passed + + further work: + explore some design patterns to apply on format/io interfaces + +2022/11/26: + Debugging for application: + though we passed the unit test, I found that the application will crash if some functions are used + 1. the toggle label file format button + 2. save button + 3. openfile button + + the toggle label file format button: + this error occurs when we compare two LabelFileFormat objects, in Python the comparison is based on pointer address + however, here we want to compare only the attributes for the identity. + Thus, a __eq__ method is written into labelFileFormat class, then the issue is solved + + save button: + for each save operation, we initialize a new writer based on the format + it is important to keep each writer consistent in its input parameters + currently, we cannot fuse writer into label file format operation, this will be left to further work + + openfile button: + this error occurs when detecting if the file is available for import + we are checking validity by checking image format and label suffix + the suffixes of all labels are not saved, we can only see one suffix at a time + to solve this issue, we add suffixes attribute in labelFileFormat class + thus we can now see all the suffixes of the objects + + further work: + fuse writer into labelFileFormat object + +2022/12/09: + Today's work: + 1. solving if-else statements in load_file_by_filename() + 2.
redefining clear parameter for writer objects + 3. E2E debugging for labelfile io operation + + Code smell: + long if-else statements for loading label file + + Solve: + search for matching suffix in LabelFileFormat object + then use the method of LabelFileFormat.read to load any file + + further work: abstract interface for writer and reader objects + +2022/12/24: + Today's work: + build abstract class for file reader and file writer + for future extension, any format should inherit the writer and reader class for implementation diff --git a/labelImg.py b/labelImg.py index efd8a2976..22c521b76 100755 --- a/labelImg.py +++ b/labelImg.py @@ -37,7 +37,8 @@ from libs.lightWidget import LightWidget from libs.labelDialog import LabelDialog from libs.colorDialog import ColorDialog -from libs.labelFile import LabelFile, LabelFileError, LabelFileFormat +from libs.labelFile import LabelFile, LabelFileError +from libs.labelFileFormat import LabelFileFormat, PascalVoc, Yolo, CreateML from libs.toolBar import ToolBar from libs.pascal_voc_io import PascalVocReader from libs.pascal_voc_io import XML_EXT @@ -90,7 +91,7 @@ def __init__(self, default_filename=None, default_prefdef_class_file=None, defau # Save as Pascal voc xml self.default_save_dir = default_save_dir - self.label_file_format = settings.get(SETTING_LABEL_FILE_FORMAT, LabelFileFormat.PASCAL_VOC) + self.label_file_format = settings.get(SETTING_LABEL_FILE_FORMAT, PascalVoc) # For loading all image under a directory self.m_img_list = [] @@ -242,20 +243,9 @@ def __init__(self, default_filename=None, default_prefdef_class_file=None, defau save = action(get_str('save'), self.save_file, 'Ctrl+S', 'save', get_str('saveDetail'), enabled=False) - def get_format_meta(format): - """ - returns a tuple containing (title, icon_name) of the selected format - """ - if format == LabelFileFormat.PASCAL_VOC: - return '&PascalVOC', 'format_voc' - elif format == LabelFileFormat.YOLO: - return '&YOLO', 'format_yolo' - elif format ==
LabelFileFormat.CREATE_ML: - return '&CreateML', 'format_createml' - - save_format = action(get_format_meta(self.label_file_format)[0], + save_format = action(self.label_file_format.text, self.change_format, 'Ctrl+Y', - get_format_meta(self.label_file_format)[1], + self.label_file_format.icon, get_str('changeSaveFormat'), enabled=True) save_as = action(get_str('saveAs'), self.save_file_as, @@ -548,35 +538,20 @@ def keyPressEvent(self, event): # Draw rectangle if Ctrl is pressed self.canvas.set_drawing_shape_to_square(True) - # Support Functions # def set_format(self, save_format): - if save_format == FORMAT_PASCALVOC: - self.actions.save_format.setText(FORMAT_PASCALVOC) - self.actions.save_format.setIcon(new_icon("format_voc")) - self.label_file_format = LabelFileFormat.PASCAL_VOC - LabelFile.suffix = XML_EXT - - elif save_format == FORMAT_YOLO: - self.actions.save_format.setText(FORMAT_YOLO) - self.actions.save_format.setIcon(new_icon("format_yolo")) - self.label_file_format = LabelFileFormat.YOLO - LabelFile.suffix = TXT_EXT - - elif save_format == FORMAT_CREATEML: - self.actions.save_format.setText(FORMAT_CREATEML) - self.actions.save_format.setIcon(new_icon("format_createml")) - self.label_file_format = LabelFileFormat.CREATE_ML - LabelFile.suffix = JSON_EXT - + self.actions.save_format.setText(save_format.text) + self.actions.save_format.setIcon(new_icon(save_format.icon)) + self.label_file_format = save_format + def change_format(self): - if self.label_file_format == LabelFileFormat.PASCAL_VOC: - self.set_format(FORMAT_YOLO) - elif self.label_file_format == LabelFileFormat.YOLO: - self.set_format(FORMAT_CREATEML) - elif self.label_file_format == LabelFileFormat.CREATE_ML: - self.set_format(FORMAT_PASCALVOC) + if self.label_file_format in LabelFileFormat.formats: + index = self.label_file_format.formats.index(self.label_file_format) + self.label_file_format = LabelFileFormat.formats[index+1 if index+1 < len(self.label_file_format.formats) else 0] + 
self.set_format(self.label_file_format) else: raise ValueError('Unknown label file format.') + #! todo: error when only label file is change then save + #! this should not be dirty if no image is imported/ no label is plotted self.set_dirty() def no_shapes(self): @@ -881,6 +856,7 @@ def save_labels(self, annotation_file_path): if self.label_file is None: self.label_file = LabelFile() self.label_file.verified = self.canvas.verified + self.label_file.label_file_format = self.label_file_format def format_shape(s): return dict(label=s.label, @@ -893,24 +869,8 @@ def format_shape(s): shapes = [format_shape(shape) for shape in self.canvas.shapes] # Can add different annotation formats here try: - if self.label_file_format == LabelFileFormat.PASCAL_VOC: - if annotation_file_path[-4:].lower() != ".xml": - annotation_file_path += XML_EXT - self.label_file.save_pascal_voc_format(annotation_file_path, shapes, self.file_path, self.image_data, - self.line_color.getRgb(), self.fill_color.getRgb()) - elif self.label_file_format == LabelFileFormat.YOLO: - if annotation_file_path[-4:].lower() != ".txt": - annotation_file_path += TXT_EXT - self.label_file.save_yolo_format(annotation_file_path, shapes, self.file_path, self.image_data, self.label_hist, - self.line_color.getRgb(), self.fill_color.getRgb()) - elif self.label_file_format == LabelFileFormat.CREATE_ML: - if annotation_file_path[-5:].lower() != ".json": - annotation_file_path += JSON_EXT - self.label_file.save_create_ml_format(annotation_file_path, shapes, self.file_path, self.image_data, - self.label_hist, self.line_color.getRgb(), self.fill_color.getRgb()) - else: - self.label_file.save(annotation_file_path, shapes, self.file_path, self.image_data, - self.line_color.getRgb(), self.fill_color.getRgb()) + self.label_file.save(annotation_file_path, shapes, self.file_path, self.image_data, + self.label_hist) print('Image:{0} -> Annotation:{1}'.format(self.file_path, annotation_file_path)) return True except LabelFileError as 
e: @@ -1180,32 +1140,11 @@ def counter_str(self): def show_bounding_box_from_annotation_file(self, file_path): if self.default_save_dir is not None: basename = os.path.basename(os.path.splitext(file_path)[0]) - xml_path = os.path.join(self.default_save_dir, basename + XML_EXT) - txt_path = os.path.join(self.default_save_dir, basename + TXT_EXT) - json_path = os.path.join(self.default_save_dir, basename + JSON_EXT) - - """Annotation file priority: - PascalXML > YOLO - """ - if os.path.isfile(xml_path): - self.load_pascal_xml_by_filename(xml_path) - elif os.path.isfile(txt_path): - self.load_yolo_txt_by_filename(txt_path) - elif os.path.isfile(json_path): - self.load_create_ml_json_by_filename(json_path, file_path) - - else: - xml_path = os.path.splitext(file_path)[0] + XML_EXT - txt_path = os.path.splitext(file_path)[0] + TXT_EXT - json_path = os.path.splitext(file_path)[0] + JSON_EXT - - if os.path.isfile(xml_path): - self.load_pascal_xml_by_filename(xml_path) - elif os.path.isfile(txt_path): - self.load_yolo_txt_by_filename(txt_path) - elif os.path.isfile(json_path): - self.load_create_ml_json_by_filename(json_path, file_path) - + for suffix in LabelFileFormat.suffixes: + label_path = os.path.join(self.default_save_dir, basename + suffix) + if os.path.isfile(label_path): + self.load_label_by_filename(label_path) + break def resizeEvent(self, event): if self.canvas and not self.image.isNull()\ @@ -1321,7 +1260,7 @@ def open_annotation_dialog(self, _value=False): path = os.path.dirname(ustr(self.file_path))\ if self.file_path else '.' 
- if self.label_file_format == LabelFileFormat.PASCAL_VOC: + if self.label_file_format == PascalVoc: filters = "Open Annotation XML file (%s)" % ' '.join(['*.xml']) filename = ustr(QFileDialog.getOpenFileName(self, '%s - Choose a xml file' % __appname__, path, filters)) if filename: @@ -1329,7 +1268,7 @@ def open_annotation_dialog(self, _value=False): filename = filename[0] self.load_pascal_xml_by_filename(filename) - elif self.label_file_format == LabelFileFormat.CREATE_ML: + elif self.label_file_format == CreateML: filters = "Open Annotation JSON file (%s)" % ' '.join(['*.json']) filename = ustr(QFileDialog.getOpenFileName(self, '%s - Choose a json file' % __appname__, path, filters)) @@ -1358,8 +1297,6 @@ def open_dir_dialog(self, _value=False, dir_path=None, silent=False): self.last_open_dir = target_dir_path self.import_dir_images(target_dir_path) self.default_save_dir = target_dir_path - if self.file_path: - self.show_bounding_box_from_annotation_file(file_path=self.file_path) def import_dir_images(self, dir_path): if not self.may_continue() or not dir_path: @@ -1455,7 +1392,7 @@ def open_file(self, _value=False): return path = os.path.dirname(ustr(self.file_path)) if self.file_path else '.' 
formats = ['*.%s' % fmt.data().decode("ascii").lower() for fmt in QImageReader.supportedImageFormats()] - filters = "Image & Label files (%s)" % ' '.join(formats + ['*%s' % LabelFile.suffix]) + filters = "Image & Label files (%s)" % ' '.join(formats + ['*%s' % suffix for suffix in LabelFileFormat.suffixes]) filename,_ = QFileDialog.getOpenFileName(self, '%s - Choose Image or Label file' % __appname__, path, filters) if filename: if isinstance(filename, (tuple, list)): @@ -1485,10 +1422,10 @@ def save_file_as(self, _value=False): def save_file_dialog(self, remove_ext=True): caption = '%s - Choose File' % __appname__ - filters = 'File (*%s)' % LabelFile.suffix + filters = 'File (*%s)' % self.label_file_format.suffix open_dialog_path = self.current_path() dlg = QFileDialog(self, caption, open_dialog_path, filters) - dlg.setDefaultSuffix(LabelFile.suffix[1:]) + dlg.setDefaultSuffix(self.label_file_format.suffix[1:]) dlg.setAcceptMode(QFileDialog.AcceptSave) filename_without_extension = os.path.splitext(self.file_path)[0] dlg.selectFile(filename_without_extension) @@ -1616,44 +1553,18 @@ def load_predefined_classes(self, predef_classes_file): else: self.label_hist.append(line) - def load_pascal_xml_by_filename(self, xml_path): - if self.file_path is None: - return - if os.path.isfile(xml_path) is False: - return - - self.set_format(FORMAT_PASCALVOC) - - t_voc_parse_reader = PascalVocReader(xml_path) - shapes = t_voc_parse_reader.get_shapes() - self.load_labels(shapes) - self.canvas.verified = t_voc_parse_reader.verified - - def load_yolo_txt_by_filename(self, txt_path): - if self.file_path is None: - return - if os.path.isfile(txt_path) is False: - return - - self.set_format(FORMAT_YOLO) - t_yolo_parse_reader = YoloReader(txt_path, self.image) - shapes = t_yolo_parse_reader.get_shapes() - print(shapes) - self.load_labels(shapes) - self.canvas.verified = t_yolo_parse_reader.verified - - def load_create_ml_json_by_filename(self, json_path, file_path): + def 
load_label_by_filename(self, label_path): if self.file_path is None: return - if os.path.isfile(json_path) is False: + if os.path.isfile(label_path) is False: return - - self.set_format(FORMAT_CREATEML) - - create_ml_parse_reader = CreateMLReader(json_path, file_path) - shapes = create_ml_parse_reader.get_shapes() + suffix = os.path.splitext(label_path)[1] + format_id = LabelFileFormat.suffixes.index(suffix) + file_format = LabelFileFormat.formats[format_id] + self.set_format(file_format) + shapes = self.label_file_format.read(label_path, self.image) self.load_labels(shapes) - self.canvas.verified = create_ml_parse_reader.verified + self.canvas.verified = self.label_file_format.file_reader.verified def copy_previous_bounding_boxes(self): current_index = self.m_img_list.index(self.file_path) diff --git a/libs/create_ml_io.py b/libs/create_ml_io.py index 3aca8d676..b300b252a 100644 --- a/libs/create_ml_io.py +++ b/libs/create_ml_io.py @@ -4,25 +4,21 @@ from pathlib import Path from libs.constants import DEFAULT_ENCODING +from libs.io_abstract_class import FileReader, FileWriter import os JSON_EXT = '.json' ENCODE_METHOD = DEFAULT_ENCODING -class CreateMLWriter: - def __init__(self, folder_name, filename, img_size, shapes, output_file, database_src='Unknown', local_img_path=None): - self.folder_name = folder_name - self.filename = filename - self.database_src = database_src - self.img_size = img_size - self.box_list = [] - self.local_img_path = local_img_path - self.verified = False - self.shapes = shapes - self.output_file = output_file +class CreateMLWriter(FileWriter): - def write(self): + def __init__(self, img_folder_name, img_file_name, + img_shape, shapes, filename): + super().__init__(img_folder_name, + img_file_name, img_shape, shapes, filename) + + def save(self, target_file=None, class_list=None): if os.path.isfile(self.output_file): with open(self.output_file, "r") as file: input_data = file.read() @@ -93,18 +89,14 @@ def calculate_coordinates(self, x1, 
x2, y1, y2): return height, width, x, y -class CreateMLReader: +class CreateMLReader(FileReader): + def __init__(self, json_path, file_path): self.json_path = json_path - self.shapes = [] - self.verified = False - self.filename = os.path.basename(file_path) - try: - self.parse_json() - except ValueError: - print("JSON decoding failed") - - def parse_json(self): + self.filename = os.path.basename(json_path) + super().__init__(file_path) + + def parse_file(self): with open(self.json_path, "r") as file: input_data = file.read() @@ -117,9 +109,9 @@ def parse_json(self): if len(self.shapes) > 0: self.shapes = [] for image in output_list: - if image["image"] == self.filename: - for shape in image["annotations"]: - self.add_shape(shape["label"], shape["coordinates"]) + #if os.path.splitext(image["image"])[0] == os.path.splitext(self.filename)[0]: + for shape in image["annotations"]: + self.add_shape(shape["label"], shape["coordinates"]) def add_shape(self, label, bnd_box): x_min = bnd_box["x"] - (bnd_box["width"] / 2) @@ -131,5 +123,3 @@ def add_shape(self, label, bnd_box): points = [(x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)] self.shapes.append((label, points, None, None, True)) - def get_shapes(self): - return self.shapes diff --git a/libs/io_abstract_class.py b/libs/io_abstract_class.py new file mode 100644 index 000000000..b55efd747 --- /dev/null +++ b/libs/io_abstract_class.py @@ -0,0 +1,46 @@ +import abc + +class FileWriter(metaclass=abc.ABCMeta): + + def __init__(self, img_folder_name, img_file_name, + img_shape, shapes, filename): + self.folder_name = img_folder_name + self.filename = img_file_name + self.img_size = img_shape + self.box_list = [] + self.verified = False + self.shapes = shapes + self.output_file = filename + + @abc.abstractmethod + def save(self): + # save labelfile format at location file path + pass + +class FileReader(metaclass=abc.ABCMeta): + + def __init__(self, file_path): + self.shapes = [] + self.file_path = file_path + 
self.verified = False + + try: + self.parse_file() + except: + pass + + @abc.abstractmethod + def add_shape(self): + # append the shape to self.shapes + pass + + @abc.abstractmethod + def parse_file(self): + # parse the label file then add all the shapes into self.shapes + pass + + def get_shapes(self): + # return the shapes of bounding boxes + return self.shapes + + diff --git a/libs/labelFile.py b/libs/labelFile.py index 185570bcb..859e025b5 100644 --- a/libs/labelFile.py +++ b/libs/labelFile.py @@ -7,106 +7,46 @@ from PyQt4.QtGui import QImage import os.path -from enum import Enum - -from libs.create_ml_io import CreateMLWriter -from libs.pascal_voc_io import PascalVocWriter -from libs.pascal_voc_io import XML_EXT -from libs.yolo_io import YOLOWriter - - -class LabelFileFormat(Enum): - PASCAL_VOC = 1 - YOLO = 2 - CREATE_ML = 3 - +from libs.labelFileFormat import LabelFileFormat, PascalVoc, Yolo, CreateML class LabelFileError(Exception): pass - class LabelFile(object): - # It might be changed as window creates. 
By default, using XML ext - # suffix = '.lif' - suffix = XML_EXT def __init__(self, filename=None): + self.label_file_format = PascalVoc self.shapes = () self.image_path = None self.image_data = None self.verified = False - def save_create_ml_format(self, filename, shapes, image_path, image_data, class_list, line_color=None, fill_color=None, database_src=None): - img_folder_name = os.path.basename(os.path.dirname(image_path)) - img_file_name = os.path.basename(image_path) - - image = QImage() - image.load(image_path) - image_shape = [image.height(), image.width(), - 1 if image.isGrayscale() else 3] - writer = CreateMLWriter(img_folder_name, img_file_name, - image_shape, shapes, filename, local_img_path=image_path) - writer.verified = self.verified - writer.write() - return - - - def save_pascal_voc_format(self, filename, shapes, image_path, image_data, - line_color=None, fill_color=None, database_src=None): - img_folder_path = os.path.dirname(image_path) - img_folder_name = os.path.split(img_folder_path)[-1] - img_file_name = os.path.basename(image_path) - # imgFileNameWithoutExt = os.path.splitext(img_file_name)[0] - # Read from file path because self.imageData might be empty if saving to - # Pascal format - if isinstance(image_data, QImage): - image = image_data - else: - image = QImage() - image.load(image_path) - image_shape = [image.height(), image.width(), - 1 if image.isGrayscale() else 3] - writer = PascalVocWriter(img_folder_name, img_file_name, - image_shape, local_img_path=image_path) - writer.verified = self.verified - - for shape in shapes: - points = shape['points'] - label = shape['label'] - # Add Chris - difficult = int(shape['difficult']) - bnd_box = LabelFile.convert_points_to_bnd_box(points) - writer.add_bnd_box(bnd_box[0], bnd_box[1], bnd_box[2], bnd_box[3], label, difficult) - - writer.save(target_file=filename) - return - - def save_yolo_format(self, filename, shapes, image_path, image_data, class_list, - line_color=None, fill_color=None, 
database_src=None): + def save(self, filename, shapes, image_path, image_data, class_list): + if os.path.splitext(filename)[1] != self.label_file_format.suffix: + filename += self.label_file_format.suffix + label_filename = filename img_folder_path = os.path.dirname(image_path) img_folder_name = os.path.split(img_folder_path)[-1] img_file_name = os.path.basename(image_path) - # imgFileNameWithoutExt = os.path.splitext(img_file_name)[0] - # Read from file path because self.imageData might be empty if saving to - # Pascal format if isinstance(image_data, QImage): image = image_data else: image = QImage() image.load(image_path) - image_shape = [image.height(), image.width(), + img_shape = [image.height(), image.width(), 1 if image.isGrayscale() else 3] - writer = YOLOWriter(img_folder_name, img_file_name, - image_shape, local_img_path=image_path) + writer = self.label_file_format.writer(img_folder_name, img_file_name, img_shape, shapes, label_filename) writer.verified = self.verified for shape in shapes: points = shape['points'] label = shape['label'] - # Add Chris difficult = int(shape['difficult']) bnd_box = LabelFile.convert_points_to_bnd_box(points) - writer.add_bnd_box(bnd_box[0], bnd_box[1], bnd_box[2], bnd_box[3], label, difficult) + try: + writer.add_bnd_box(bnd_box[0], bnd_box[1], bnd_box[2], bnd_box[3], label, difficult) + except: + pass writer.save(target_file=filename, class_list=class_list) return @@ -146,7 +86,7 @@ def save(self, filename, shapes, imagePath, imageData, lineColor=None, fillColor @staticmethod def is_label_file(filename): file_suffix = os.path.splitext(filename)[1].lower() - return file_suffix == LabelFile.suffix + return (file_suffix in LabelFileFormat.suffixes) @staticmethod def convert_points_to_bnd_box(points): diff --git a/libs/labelFileFormat.py b/libs/labelFileFormat.py new file mode 100644 index 000000000..774b1658c --- /dev/null +++ b/libs/labelFileFormat.py @@ -0,0 +1,47 @@ +from libs.pascal_voc_io import PascalVocReader, 
PascalVocWriter, XML_EXT +from libs.yolo_io import YoloReader, YOLOWriter, TXT_EXT +from libs.create_ml_io import CreateMLReader, CreateMLWriter, JSON_EXT + +class LabelFileFormat(object): + + formats = [] + suffixes = [] + + def __init__(self, reader, writer, suffix, text, icon): + self.reader = reader + self.writer = writer + self.suffix = suffix + self.text = text + self.icon = icon + self.formats.append(self) + self.suffixes.append(self.suffix) + + def read(self, *args, **kwargs): + self.file_reader = self.reader(*args, **kwargs) + return self.file_reader.get_shapes() + + def write(self, *args, **kwargs): + self.file_writer = self.writer(*args, **kwargs) + self.file_writer.save() + + def __eq__(self, other): + return (self.reader == other.reader and self.writer == other.writer) + + +PascalVoc = LabelFileFormat(PascalVocReader, + PascalVocWriter, + XML_EXT, + 'PascalVOC', + 'format_voc') + +Yolo = LabelFileFormat(YoloReader, + YOLOWriter, + TXT_EXT, + 'Yolo', + 'format_yolo') + +CreateML = LabelFileFormat(CreateMLReader, + CreateMLWriter, + JSON_EXT, + 'CreateML', + 'format_createml') diff --git a/libs/pascal_voc_io.py b/libs/pascal_voc_io.py index d8f7d690b..7711511e0 100644 --- a/libs/pascal_voc_io.py +++ b/libs/pascal_voc_io.py @@ -7,21 +7,21 @@ import codecs from libs.constants import DEFAULT_ENCODING from libs.ustr import ustr +from libs.io_abstract_class import FileReader, FileWriter XML_EXT = '.xml' ENCODE_METHOD = DEFAULT_ENCODING -class PascalVocWriter: +class PascalVocWriter(FileWriter): - def __init__(self, folder_name, filename, img_size, database_src='Unknown', local_img_path=None): - self.folder_name = folder_name - self.filename = filename - self.database_src = database_src - self.img_size = img_size + def __init__(self, img_folder_name, img_file_name, + img_shape, shapes, filename): + super().__init__(img_folder_name, + img_file_name, img_shape, shapes, filename) + self.database_src = "" self.box_list = [] - self.local_img_path = local_img_path - 
self.verified = False + self.local_img_path = None def prettify(self, elem): """ @@ -109,7 +109,7 @@ def append_objects(self, top): y_max = SubElement(bnd_box, 'ymax') y_max.text = str(each_object['ymax']) - def save(self, target_file=None): + def save(self, target_file=None, class_list=None): root = self.gen_xml() self.append_objects(root) out_file = None @@ -124,18 +124,12 @@ def save(self, target_file=None): out_file.close() -class PascalVocReader: +class PascalVocReader(FileReader): - def __init__(self, file_path): + def __init__(self, file_path, image = None): # shapes type: # [labbel, [(x1,y1), (x2,y2), (x3,y3), (x4,y4)], color, color, difficult] - self.shapes = [] - self.file_path = file_path - self.verified = False - try: - self.parse_xml() - except: - pass + super().__init__(file_path) def get_shapes(self): return self.shapes @@ -148,7 +142,7 @@ def add_shape(self, label, bnd_box, difficult): points = [(x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)] self.shapes.append((label, points, None, None, difficult)) - def parse_xml(self): + def parse_file(self): assert self.file_path.endswith(XML_EXT), "Unsupported file format" parser = etree.XMLParser(encoding=ENCODE_METHOD) xml_tree = ElementTree.parse(self.file_path, parser=parser).getroot() diff --git a/libs/yolo_io.py b/libs/yolo_io.py index 192e2c785..ed09b1b51 100644 --- a/libs/yolo_io.py +++ b/libs/yolo_io.py @@ -4,20 +4,17 @@ import os from libs.constants import DEFAULT_ENCODING +from libs.io_abstract_class import FileReader, FileWriter TXT_EXT = '.txt' ENCODE_METHOD = DEFAULT_ENCODING class YOLOWriter: - def __init__(self, folder_name, filename, img_size, database_src='Unknown', local_img_path=None): - self.folder_name = folder_name - self.filename = filename - self.database_src = database_src - self.img_size = img_size - self.box_list = [] - self.local_img_path = local_img_path - self.verified = False + def __init__(self, img_folder_name, img_file_name, + img_shape, shapes, filename): + 
super().__init__(img_folder_name, + img_file_name, img_shape, shapes, filename) def add_bnd_box(self, x_min, y_min, x_max, y_max, name, difficult): bnd_box = {'xmin': x_min, 'ymin': y_min, 'xmax': x_max, 'ymax': y_max} @@ -78,12 +75,11 @@ def save(self, class_list=[], target_file=None): -class YoloReader: +class YoloReader(FileReader): - def __init__(self, file_path, image, class_list_path=None): + def __init__(self, file_path, image): # shapes type: # [labbel, [(x1,y1), (x2,y2), (x3,y3), (x4,y4)], color, color, difficult] - self.shapes = [] self.file_path = file_path if class_list_path is None: @@ -103,12 +99,7 @@ def __init__(self, file_path, image, class_list_path=None): 1 if image.isGrayscale() else 3] self.img_size = img_size - - self.verified = False - # try: - self.parse_yolo_format() - # except: - # pass + super().__init__(file_path) def get_shapes(self): return self.shapes diff --git a/notes.txt b/notes.txt new file mode 100644 index 000000000..444670705 --- /dev/null +++ b/notes.txt @@ -0,0 +1,7 @@ +Unit testing: + python -m unittest tests/test_xx.py + python -m unittest tests/*.py + +Env: + source venv/bin/activate + diff --git a/tests/.gitignore b/tests/.gitignore index a6535f35d..9e34968eb 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1 +1,2 @@ test.xml +tests.json diff --git a/tests/test_io.py b/tests/test_io.py index 7bc31b3af..238c0b19a 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,11 +12,12 @@ def test_upper(self): from pascal_voc_io import PascalVocReader # Test Write/Read - writer = PascalVocWriter('tests', 'test', (512, 512, 1), local_img_path='tests/test.512.512.bmp') - difficult = 1 - writer.add_bnd_box(60, 40, 430, 504, 'person', difficult) - writer.add_bnd_box(113, 40, 450, 403, 'face', difficult) - writer.save('tests/test.xml') + shapes = [ + ['person', [(60, 40), (430, 40), (430, 504), (60, 504)]], + ['face', [(113, 40), (450, 40), (450, 403), (113, 403)]] + ] + writer = PascalVocWriter('tests', 'test', (512, 512, 
1), shapes, 'tests/test.xml') + writer.save() reader = PascalVocReader('tests/test.xml') shapes = reader.get_shapes() @@ -48,11 +49,10 @@ def test_a_write(self): shapes = [person, face] output_file = dir_name + "/tests.json" - writer = CreateMLWriter('tests', 'test.512.512.bmp', (512, 512, 1), shapes, output_file, - local_img_path='tests/test.512.512.bmp') + writer = CreateMLWriter('tests', 'test.512.512.bmp', (512, 512, 1), shapes, output_file) writer.verified = True - writer.write() + writer.save() # check written json with open(output_file, "r") as file: