Skip to content

Commit

Permalink
Merge pull request KerasKorea#39 from minus31/gen
Browse files Browse the repository at this point in the history
Resolved issues (KerasKorea#15, KerasKorea#33) Data generator for RetinaNet
  • Loading branch information
MijeongJeon authored Oct 20, 2019
2 parents c32a220 + 3a8d7ef commit b8f1d5c
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 0 deletions.
160 changes: 160 additions & 0 deletions keras_retinanet/preprocessing/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import os
import errno
import hashlib
import requests
from tqdm import tqdm


def makedirs(path):
    """Create a directory recursively if it does not already exist.

    Similar to ``mkdir -p``: an already-existing directory is not an error,
    so callers can skip the existence check.

    Parameters
    ----------
    path : str
        Path of the desired directory.

    Raises
    ------
    OSError
        If creation fails for any reason other than the directory already
        existing (e.g. permission denied, or `path` exists as a regular
        file — the original errno.EEXIST check silently ignored that case).
    """
    # exist_ok=True (Python 3.2+) replaces the manual try/except-EEXIST idiom.
    os.makedirs(path, exist_ok=True)

"""Download files with progress bar."""

def check_sha1(filename, sha1_hash):
    """Check whether the SHA-1 hash of the file content matches the expected hash.

    Only the common prefix of the two hex digests is compared, so a
    truncated expected hash still matches (this preserves the original
    behaviour, which allowed callers to pass shortened hashes).

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected SHA-1 hash in hexadecimal digits (may be truncated).

    Returns
    -------
    bool
        Whether the file content matches the expected hash prefix.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MiB chunks so arbitrarily large files never have to fit
        # in memory; iter() with a b'' sentinel stops at end of file.
        for chunk in iter(lambda: f.read(1048576), b''):
            sha1.update(chunk)

    # The original computed hexdigest() twice; compute it once and reuse it.
    digest = sha1.hexdigest()
    prefix_len = min(len(digest), len(sha1_hash))
    return digest[:prefix_len] == sha1_hash[:prefix_len]

def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL to a local file, with a progress bar.

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store the downloaded file. By default stores to
        the current directory under the same name as in the URL. If `path`
        is an existing directory, the file is placed inside it.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected SHA-1 hash in hexadecimal digits. An existing file is
        ignored (re-downloaded) when the hash is specified but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the HTTP request does not return status 200.
    UserWarning
        Raised (not warned) when the downloaded content fails the hash check.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    # Download when forced, when the file is missing, or when an existing
    # file fails the checksum (stale or partial previous download).
    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...'%(fname, url))
        # Use the response as a context manager so the streaming connection
        # is always released; the original never closed it (connection leak).
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s"%url)
            total_length = r.headers.get('content-length')
            with open(fname, 'wb') as f:
                if total_length is None: # no content length header
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk: # filter out keep-alive new chunks
                            f.write(chunk)
                else:
                    total_length = int(total_length)
                    for chunk in tqdm(r.iter_content(chunk_size=1024),
                                      total=int(total_length / 1024. + 0.5),
                                      unit='KB', unit_scale=False, dynamic_ncols=True):
                        f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname

"""Prepare PASCAL VOC datasets"""
import os
import shutil
import argparse
import tarfile

_TARGET_DIR = os.path.expanduser('~/.yolk/datasets/voc')

#####################################################################################
# Download and extract VOC datasets into ``path``

def download_voc(path, overwrite=False):
    """Download the mirrored PASCAL VOC archives into ``path`` and extract them.

    Each archive's SHA-1 checksum is verified by ``download`` before the
    tarball is extracted in place.
    """
    voc_archives = (
        ('http://bit.ly/yolk_voc_train_val2007_tar',
         '34ed68851bce2a36e2a223fa52c661d592c66b3c'),
        ('http://bit.ly/yolk_voc_train_val2012_tar',
         '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'),
        ('http://bit.ly/yolk_voc_test2012_tar',
         '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
    )
    makedirs(path)
    for archive_url, sha1_checksum in voc_archives:
        archive_file = download(archive_url, path=path, overwrite=overwrite,
                                sha1_hash=sha1_checksum)
        # NOTE(review): extractall trusts member paths inside the archive;
        # the checksum above is the guard against tampered archives.
        with tarfile.open(archive_file) as archive:
            archive.extractall(path=path)


#####################################################################################
# Download and extract the VOC augmented segmentation dataset into ``path``

def download_pascal(download_dir="~/VOCdevkit", overwrite=True, no_download=False):
    """Ensure the PASCAL VOC dataset exists at ``download_dir`` and symlink it.

    Downloads and extracts VOC2007/VOC2012 when they are missing (unless
    ``no_download``), flattens the extracted ``VOCdevkit`` layout, then
    symlinks the dataset directory to the shared ``~/.yolk/datasets/voc``
    location (``_TARGET_DIR``).

    Parameters
    ----------
    download_dir : str, optional
        Directory that should contain ``VOC2007`` and ``VOC2012``.
    overwrite : bool, optional
        Passed through to ``download_voc``: re-download existing archives.
    no_download : bool, optional
        If True, never download; raise instead when the dataset is missing.

    Raises
    ------
    ValueError
        If the dataset is missing and ``no_download`` is True.
    """
    path = os.path.expanduser(download_dir)
    if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \
        or not os.path.isdir(os.path.join(path, 'VOC2012')):
        if no_download:
            raise ValueError(('{} is not a valid directory, make sure it is present.'
                              ' Or you should not disable "--no-download" to grab it'.format(path)))
        else:
            download_voc(path, overwrite=overwrite)
            # Flatten VOCdevkit/VOC20xx -> VOC20xx directly under `path`.
            shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007'))
            shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012'))
            shutil.rmtree(os.path.join(path, 'VOCdevkit'))

    # (Re)create the shared symlink to the dataset.
    makedirs(os.path.expanduser('~/.yolk/datasets'))
    # BUG FIX: the original guarded os.remove with isdir, which follows
    # symlinks — so a real directory at the target raised IsADirectoryError.
    # Only a stale symlink should be removed; a real directory now surfaces
    # as a clear FileExistsError from os.symlink instead.
    if os.path.islink(_TARGET_DIR):
        os.remove(_TARGET_DIR)
    os.symlink(path, _TARGET_DIR)
    print("Downloaded!!!")

if __name__ == "__main__":
pass
144 changes: 144 additions & 0 deletions keras_retinanet/utils/make_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)
"""
import argparse
import os
import sys
import warnings
import keras
import keras.preprocessing.image
import tensorflow as tf

# Allow relative imports when being executed as script.
if __name__ == "__main__" and __package__ is None:
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
import keras_retinanet.bin # noqa: F401
__package__ = "keras_retinanet.bin"

from ..preprocessing.pascal_voc import PascalVocGenerator
from ..utils.anchors import make_shapes_callback
from ..utils.config import read_config_file, parse_anchor_parameters
from ..utils.keras_version import check_keras_version
from ..utils.transform import random_transform_generator
from ..utils.image import random_visual_effect_generator
from ..preprocessing.download import download_pascal

def make_generators(batch_size=32, image_min_side=800, image_max_side=1333, preprocess_image=lambda x : x / 255.,
                    random_transform=True, dataset_type='voc', vesion="2012"):
    """ Create training and validation data generators for RetinaNet.

    Parameters
    ----------
    batch_size : int, optional
        Number of images per batch.
    image_min_side : int, optional
        Images are resized so their shortest side is at least this size.
    image_max_side : int, optional
        Upper bound on the longest image side after resizing.
    preprocess_image : callable, optional
        Function that preprocesses an image for the network; the default
        scales pixel values into [0, 1].
    random_transform : bool, optional
        Whether to apply random geometric and visual augmentation to the
        training data.
    dataset_type : str, optional
        Either 'voc' or 'coco'.
    vesion : str, optional
        Unused. (sic — the misspelled name is kept so existing keyword
        callers do not break; TODO rename to `version` with a deprecation.)

    Returns
    -------
    (train_generator, validation_generator)
        Generator pair for the requested dataset.

    Raises
    ------
    ValueError
        If `dataset_type` is neither 'voc' nor 'coco'.
    """

    # Arguments shared by the training and validation generators.
    common_args = {
        'batch_size' : batch_size,
        'image_min_side' : image_min_side,
        'image_max_side' : image_max_side,
        'preprocess_image' : preprocess_image,
    }

    # create random transform generator for augmenting training data
    if random_transform:
        transform_generator = random_transform_generator(
            min_rotation=-0.1,
            max_rotation=0.1,
            min_translation=(-0.1, -0.1),
            max_translation=(0.1, 0.1),
            min_shear=-0.1,
            max_shear=0.1,
            min_scaling=(0.9, 0.9),
            max_scaling=(1.1, 1.1),
            flip_x_chance=0.5,
            flip_y_chance=0.5,
        )
        visual_effect_generator = random_visual_effect_generator(
            contrast_range=(0.9, 1.1),
            brightness_range=(-.1, .1),
            hue_range=(-0.05, 0.05),
            saturation_range=(0.95, 1.05)
        )
    else:
        # No augmentation beyond a horizontal flip; no visual effects.
        transform_generator = random_transform_generator(flip_x_chance=0.5)
        visual_effect_generator = None


    # Datasets are cached under ~/.yolk/datasets/<dataset_type>.
    dataset_path = os.path.join(os.path.expanduser("~"), ".yolk/datasets/")

    if dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi

        if not os.path.exists(os.path.join(dataset_path, dataset_type)):
            # TODO: COCO auto-download is not implemented yet.
            # download_coco()
            pass

        from ..preprocessing.coco import CocoGenerator

        train_generator = CocoGenerator(
            os.path.join(dataset_path, dataset_type),
            'train2017',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = CocoGenerator(
            os.path.join(dataset_path, dataset_type),
            'val2017',
            shuffle_groups=False,
            **common_args
        )

    elif dataset_type == 'voc':

        # Download and symlink PASCAL VOC on first use.
        if not os.path.exists(os.path.join(dataset_path, dataset_type)):
            download_pascal()

        train_generator = PascalVocGenerator(
            os.path.join(dataset_path, dataset_type),
            'trainval',
            transform_generator=transform_generator,
            visual_effect_generator=visual_effect_generator,
            **common_args
        )

        validation_generator = PascalVocGenerator(
            os.path.join(dataset_path, dataset_type),
            'test',
            shuffle_groups=False,
            **common_args
        )

    else:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    return train_generator, validation_generator


def make_generator_test():
    """Smoke-test the make_generators function.

    Usage: ``python make_generator.py --test True``

    Builds the default (VOC) generator pair, pulls one batch, and checks
    that the image batch and the first target tensor agree on batch size.
    """
    train_gen, val_gen = make_generators()

    # Indexing a generator yields (inputs, targets).
    sample = train_gen[0]

    assert sample[0].shape[0] == sample[1][0].shape[0]
    # Fixed misspelling in the original message ("sucsessfully").
    print("generator created successfully!")

import argparse

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--test', type=bool, default=False, help='dataset directory on disk')
args = parser.parse_args()

if args.test:
make_generator_test()


0 comments on commit b8f1d5c

Please sign in to comment.