diff --git a/keras_segmentation/predict.py b/keras_segmentation/predict.py index 622b82adf..f88824bff 100644 --- a/keras_segmentation/predict.py +++ b/keras_segmentation/predict.py @@ -10,17 +10,31 @@ from time import time from .train import find_latest_checkpoint -from .data_utils.data_loader import get_image_array, get_segmentation_array,\ +from .data_utils.data_loader import get_image_array, get_segmentation_array, \ DATA_LOADER_SEED, class_colors, get_pairs_from_paths from .models.config import IMAGE_ORDERING - random.seed(DATA_LOADER_SEED) -def model_from_checkpoint_path(checkpoints_path): +def load_segmentation_model(model_config: dict, weights: str): + from .models.all_models import model_from_name + model = model_from_name[model_config['model_class']]( + model_config['n_classes'], input_height=model_config['input_height'], + input_width=model_config['input_width']) + print("loaded weights ", weights) + status = model.load_weights(weights) + + if status is not None: + status.expect_partial() + + return model + + +def model_from_checkpoint_path(checkpoints_path): from .models.all_models import model_from_name + assert (os.path.isfile(checkpoints_path+"_config.json") ), "Checkpoint not found." model_config = json.loads( @@ -76,7 +90,8 @@ def get_legends(class_names, colors=class_colors): def overlay_seg_image(inp_img, seg_img): orininal_h = inp_img.shape[0] orininal_w = inp_img.shape[1] - seg_img = cv2.resize(seg_img, (orininal_w, orininal_h), interpolation=cv2.INTER_NEAREST) + seg_img = cv2.resize(seg_img, (orininal_w, orininal_h), + interpolation=cv2.INTER_NEAREST) fused_img = (inp_img/2 + seg_img/2).astype('uint8') return fused_img @@ -108,10 +123,12 @@ def visualize_segmentation(seg_arr, inp_img=None, n_classes=None, if inp_img is not None: original_h = inp_img.shape[0] original_w = inp_img.shape[1] - seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST) + seg_img = cv2.resize(seg_img, (original_w, original_h), + interpolation=cv2.INTER_NEAREST) if (prediction_height is not None) and (prediction_width is not None): - seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST) + seg_img = cv2.resize( + seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST) if inp_img is not None: inp_img = cv2.resize(inp_img, (prediction_width, prediction_height)) @@ -139,13 +156,14 @@ def predict(model=None, inp=None, out_fname=None, model = model_from_checkpoint_path(checkpoints_path) assert (inp is not None) - assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)),\ + assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)), \ "Input should be the CV image or the input file name" if isinstance(inp, six.string_types): inp = cv2.imread(inp, read_image_type) - assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 " + assert (len(inp.shape) == 3 or len(inp.shape) == + 1 or len(inp.shape) == 4), "Image should be h,w,3 " output_width = model.output_width output_height = model.output_height @@ -193,7 +211,6 @@ def predict_multiple(model=None, inps=None, inp_dir=None, out_dir=None, if not os.path.exists(out_dir): os.makedirs(out_dir) - for i, inp in enumerate(tqdm(inps)): if out_dir is None: out_fname = None @@ -235,7 +252,7 @@ def predict_video(model=None, inp=None, output=None, n_classes = model.n_classes cap, video, fps = set_video(inp, output) - while(cap.isOpened()): + while (cap.isOpened()): prev_time = time() ret, frame = cap.read() if frame is not None: @@ -248,7 +265,7 @@ def predict_video(model=None, inp=None, output=None, class_names=class_names, prediction_width=prediction_width, prediction_height=prediction_height - ) + ) else: break print("FPS: {}".format(1/(time() - prev_time))) @@ -268,14 +285,14 @@ def evaluate(model=None, inp_images=None, annotations=None, inp_images_dir=None, annotations_dir=None, checkpoints_path=None, read_image_type=1): if model is None: - assert (checkpoints_path is not None),\ - "Please provide the model or the checkpoints_path" + assert (checkpoints_path is not None), \ + "Please provide the model or the checkpoints_path" model = model_from_checkpoint_path(checkpoints_path) if inp_images is None: - assert (inp_images_dir is not None),\ - "Please provide inp_images or inp_images_dir" - assert (annotations_dir is not None),\ + assert (inp_images_dir is not None), \ + "Please provide inp_images or inp_images_dir" + assert (annotations_dir is not None), \ "Please provide inp_images or inp_images_dir" paths = get_pairs_from_paths(inp_images_dir, annotations_dir) diff --git a/keras_segmentation/train.py b/keras_segmentation/train.py index 27e37f65d..fefb36a38 100755 --- a/keras_segmentation/train.py +++ b/keras_segmentation/train.py @@ -10,6 +10,7 @@ import glob import sys + def find_latest_checkpoint(checkpoints_path, fail_safe=True): # This is legacy code, there should always be a "checkpoint" file in your directory @@ -41,6 +42,7 @@ def get_epoch_number_from_path(path): return latest_epoch_checkpoint + def masked_categorical_crossentropy(gt, pr): from keras.losses import categorical_crossentropy mask = 1 - gt[:, :, 0] @@ -87,7 +89,7 @@ def train(model, read_image_type=1 # cv2.IMREAD_COLOR = 1 (rgb), # cv2.IMREAD_GRAYSCALE = 0, # cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA) - ): + ): from .models.all_models import model_from_name # check if user gives model name instead of the model object if isinstance(model, six.string_types): @@ -124,7 +126,7 @@ def train(model, config_file = checkpoints_path + "_config.json" dir_name = os.path.dirname(config_file) - if ( not os.path.exists(dir_name) ) and len( dir_name ) > 0 : + if (not os.path.exists(dir_name)) and len(dir_name) > 0: os.makedirs(dir_name) with open(config_file, "w") as f: @@ -179,14 +181,14 @@ def train(model, other_inputs_paths=other_inputs_paths, preprocessing=preprocessing, read_image_type=read_image_type) - if callbacks is None and (not checkpoints_path is None) : + if callbacks is None and (not checkpoints_path is None): default_callback = ModelCheckpoint( - filepath=checkpoints_path + ".{epoch:05d}", - save_weights_only=True, - verbose=True - ) + filepath=checkpoints_path + ".{epoch:05d}" + ".weights.h5", + save_weights_only=True, + verbose=True + ) - if sys.version_info[0] < 3: # for pyhton 2 + if sys.version_info[0] < 3: # for pyhton 2 default_callback = CheckpointsCallback(checkpoints_path) callbacks = [ diff --git a/setup.py b/setup.py index 44e38555a..9c067afe0 100644 --- a/setup.py +++ b/setup.py @@ -5,8 +5,8 @@ cv_ver = "" keras_ver = ">=2.0.0" if sys.version_info.major < 3: - cv_ver = "<=4.2.0.32" - keras_ver = "<=2.3.0" + cv_ver = "<=4.2.0.32" + keras_ver = "<=2.3.0" setup(name="keras_segmentation", @@ -19,23 +19,23 @@ url="https://github.com/divamgupta/image-segmentation-keras", packages=find_packages(exclude=["test"]), entry_points={ - 'console_scripts': [ - 'keras_segmentation = keras_segmentation.__main__:main' - ] + 'console_scripts': [ + 'keras_segmentation = keras_segmentation.__main__:main' + ] }, install_requires=[ - "h5py<=2.10.0", - "Keras"+keras_ver, - "imageio==2.5.0", - "imgaug>=0.4.0", - "opencv-python"+cv_ver, - "tqdm"], + "h5py", + "Keras", + "imageio==2.5.0", + "imgaug>=0.4.0", + "opencv-python", + "tqdm"], extras_require={ - # These requires provide different backends available with Keras - "tensorflow": ["tensorflow"], - "cntk": ["cntk"], - "theano": ["theano"], - # Default testing with tensorflow - "tests-default": ["tensorflow", "pytest"] + # These requires provide different backends available with Keras + "tensorflow": ["tensorflow"], + "cntk": ["cntk"], + "theano": ["theano"], + # Default testing with tensorflow + "tests-default": ["tensorflow", "pytest"] } -) + )