forked from KerasKorea/KerasObjectDetector
Merge pull request KerasKorea#39 from minus31/gen
Resolved issues (KerasKorea#15, KerasKorea#33): data generator for RetinaNet
Showing 2 changed files with 304 additions and 0 deletions.
@@ -0,0 +1,160 @@
import os
import errno
import hashlib
import requests
from tqdm import tqdm

def makedirs(path):
    """Create a directory recursively if it does not exist.
    Similar to `mkdir -p`; you can skip checking for existence before calling this function.
    Parameters
    ----------
    path : str
        Path of the desired directory.
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


# Download files with progress bar.

def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.
    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.
    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)  # hash in 1 MiB chunks to bound memory use
            if not data:
                break
            sha1.update(data)

    # compare only up to the shorter digest, so a truncated expected hash matches
    sha1_file = sha1.hexdigest()
    length = min(len(sha1_file), len(sha1_hash))
    return sha1_file[0:length] == sha1_hash[0:length]
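
# Illustrative usage, not part of the commit (the file name and digest below
# are hypothetical). Because the comparison is prefix-based, a truncated
# expected hash also matches:
#
#   if not check_sha1('voc2007.tar', '34ed6885'):
#       print('checksum mismatch')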


def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.
    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store the downloaded file. By default stores to the
        current directory with the same name as in the url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. An existing file is ignored
        when a hash is specified but doesn't match.
    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s" % url)
        total_length = r.headers.get('content-length')
        with open(fname, 'wb') as f:
            if total_length is None:  # no content-length header
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            else:
                total_length = int(total_length)
                for chunk in tqdm(r.iter_content(chunk_size=1024),
                                  total=int(total_length / 1024. + 0.5),
                                  unit='KB', unit_scale=False, dynamic_ncols=True):
                    f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or the download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))

    return fname
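
# Illustrative usage, not part of the commit (URL and path are hypothetical).
# An existing file whose hash already matches is not re-downloaded:
#
#   fname = download('http://example.com/voc2007.tar', path='~/datasets',
#                    sha1_hash='34ed68851bce2a36e2a223fa52c661d592c66b3c')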


"""Prepare PASCAL VOC datasets."""
import os
import shutil
import argparse
import tarfile

_TARGET_DIR = os.path.expanduser('~/.yolk/datasets/voc')


#####################################################################################
# Download and extract VOC datasets into ``path``

def download_voc(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://bit.ly/yolk_voc_train_val2007_tar',
         '34ed68851bce2a36e2a223fa52c661d592c66b3c'),
        ('http://bit.ly/yolk_voc_train_val2012_tar',
         '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'),
        ('http://bit.ly/yolk_voc_test2012_tar',
         '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with tarfile.open(filename) as tar:
            tar.extractall(path=path)


#####################################################################################
# Download the VOC datasets into ``download_dir`` and link them under ``_TARGET_DIR``

def download_pascal(download_dir="~/VOCdevkit", overwrite=True, no_download=False):
    path = os.path.expanduser(download_dir)
    if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \
            or not os.path.isdir(os.path.join(path, 'VOC2012')):
        if no_download:
            raise ValueError('{} is not a valid directory; make sure it is present, '
                             'or leave "--no-download" unset so the data can be fetched.'.format(path))
        else:
            download_voc(path, overwrite=overwrite)
            shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007'))
            shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012'))
            shutil.rmtree(os.path.join(path, 'VOCdevkit'))

    # make symlink
    makedirs(os.path.expanduser('~/.yolk/datasets'))
    if os.path.isdir(_TARGET_DIR):
        os.remove(_TARGET_DIR)  # removes the existing symlink, not the data it points to
    os.symlink(path, _TARGET_DIR)
    print("Downloaded!!!")


if __name__ == "__main__":
    pass
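
A minimal usage sketch for the script above. The module path follows from the relative import in the next file (`from ..preprocessing.download import download_pascal` under package `keras_retinanet.bin`); everything else is illustrative:

    from keras_retinanet.preprocessing.download import download_pascal

    # Downloads and extracts VOC2007/VOC2012 into ~/VOCdevkit if they are
    # missing, then symlinks that directory to ~/.yolk/datasets/voc.
    download_pascal()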
@@ -0,0 +1,144 @@
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)
"""
import argparse
import os
import sys
import warnings
import keras
import keras.preprocessing.image
import tensorflow as tf

# Allow relative imports when being executed as a script.
if __name__ == "__main__" and __package__ is None:
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
    import keras_retinanet.bin  # noqa: F401
    __package__ = "keras_retinanet.bin"

from ..preprocessing.pascal_voc import PascalVocGenerator
from ..utils.anchors import make_shapes_callback
from ..utils.config import read_config_file, parse_anchor_parameters
from ..utils.keras_version import check_keras_version
from ..utils.transform import random_transform_generator
from ..utils.image import random_visual_effect_generator
from ..preprocessing.download import download_pascal


def make_generators(batch_size=32, image_min_side=800, image_max_side=1333,
                    preprocess_image=lambda x: x / 255.,
                    random_transform=True, dataset_type='voc', version="2012"):
    """ Create generators for training and validation.
    Args:
        batch_size       : Size of the batches the generators yield.
        image_min_side   : Images are resized so that the smallest side is at least this value.
        image_max_side   : Images are capped so that the largest side is at most this value.
        preprocess_image : Function that preprocesses an image for the network.
        random_transform : Whether to augment the training data with random transforms.
        dataset_type     : Dataset to use, 'voc' or 'coco'.
        version          : Dataset version (currently unused).
    """

    common_args = {
        'batch_size'       : batch_size,
        'image_min_side'   : image_min_side,
        'image_max_side'   : image_max_side,
        'preprocess_image' : preprocess_image,
    }

    # create a random transform generator for augmenting training data
    if random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
        visual_effect_generator = random_visual_effect_generator(
            contrast_range=(0.9, 1.1),
            brightness_range=(-.1, .1),
            hue_range=(-0.05, 0.05),
            saturation_range=(0.95, 1.05)
        )
    else:
        transform_generator = random_transform_generator(flip_x_chance=0.5)
        visual_effect_generator = None

    # Dataset path
    dataset_path = os.path.join(os.path.expanduser("~"), ".yolk/datasets/")

    if dataset_type == 'coco':
        # import here to prevent an unnecessary dependency on cocoapi
        if not os.path.exists(os.path.join(dataset_path, dataset_type)):
            # download_coco()
            pass

        from ..preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            os.path.join(dataset_path, dataset_type),
            'train2017',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = CocoGenerator(
            os.path.join(dataset_path, dataset_type),
            'val2017',
            shuffle_groups=False,
            **common_args
        )

    elif dataset_type == 'voc':
        if not os.path.exists(os.path.join(dataset_path, dataset_type)):
            download_pascal()

        train_generator = PascalVocGenerator(
            os.path.join(dataset_path, dataset_type),
            'trainval',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = PascalVocGenerator(
            os.path.join(dataset_path, dataset_type),
            'test',
            shuffle_groups=False,
            **common_args
        )

    else:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    return train_generator, validation_generator
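
# Illustrative usage, not part of the commit. The generators are indexable
# batch providers, so a quick smoke check looks like:
#
#   train_gen, val_gen = make_generators(batch_size=2, dataset_type='voc')
#   images, targets = train_gen[0]  # one batch of inputs and training targets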


def make_generator_test():
    """
    Test the make_generators function.
    Usage: python make_generator.py --test
    """
    train_gen, val_gen = make_generators()

    sample = train_gen[0]

    # the images in a batch and their targets must agree on the batch dimension
    assert sample[0].shape[0] == sample[1][0].shape[0]
    print("generator created successfully!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # `--test` is a flag; the original `type=bool` treated any non-empty string
    # (including "False") as True, so store_true is the idiomatic fix
    parser.add_argument('--test', action='store_true', help='run a quick generator smoke test')
    args = parser.parse_args()

    if args.test:
        make_generator_test()
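
A sketch of how these generators would typically feed training, assuming a compiled keras-retinanet `model` already exists (the model construction and the `fit_generator` call are illustrative, not part of this commit):

    train_gen, val_gen = make_generators(batch_size=2)

    # each generator behaves like a keras.utils.Sequence of (inputs, targets)
    model.fit_generator(
        generator=train_gen,
        validation_data=val_gen,
        epochs=1,
    )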