diff --git a/.gitignore b/.gitignore index e44c403..5792ba3 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,8 @@ build/ .vscode/ +yolov8x-seg.pt + # Ignore notebook uses for testing temp.ipynb inspect_localize.ipynb diff --git a/Dockerfile b/Dockerfile index fe111cc..d3d9d6d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -94,6 +94,7 @@ RUN python3 -m pip install --upgrade pip RUN pip install flask flask-cors ffmpeg-python RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 RUN pip install nerfstudio +RUN pip install ultralytics RUN mkdir /dependencies COPY ./third_party/hloc/requirements.txt /dependencies/requirements.txt diff --git a/spatial_server/hloc_localization/map_creation/map_creator.py b/spatial_server/hloc_localization/map_creation/map_creator.py index c2333b5..6ac4735 100644 --- a/spatial_server/hloc_localization/map_creation/map_creator.py +++ b/spatial_server/hloc_localization/map_creation/map_creator.py @@ -19,9 +19,7 @@ from .. import config, load_cache from spatial_server.server import shared_data -from spatial_server.utils.run_command import run_command -from spatial_server.utils.print_log import print_log -from . import map_aligner, map_cleaner, kiri_engine, polycam, video +from . 
"""Remove frequently moving objects from a COLMAP / hloc map.

Images are segmented with a YOLO segmentation model; 2D observations that
fall on a mask of one of ``TARGET_CLASS_IDS`` are dropped, either from the
local-feature HDF5 file or (via their 3D points) from the reconstruction.
"""
import argparse
import os
from pathlib import Path

from ultralytics import YOLO
import numpy as np
import cv2
import pycolmap
import h5py

from ..scale_adjustment import read_write_model


# COCO class IDs to be extracted
TARGET_CLASS_IDS = [0, 1, 2, 3, 5, 7, 14, 15, 16, 24, 25, 26, 28, 36, 39, 40, 41, 42, 43, 44, 45, 56, 63, 64, 65, 66, 67]


def extract_masks(results, target_class_ids=TARGET_CLASS_IDS):
    """Collect the masks of the requested classes from segmentation results.

    Args:
        results: Iterable of segmentation results; each one is read through
            ``res.boxes.cls`` and ``res.masks.data``.
        target_class_ids: Class IDs whose masks should be kept.

    Returns:
        A ``(masks, union_mask)`` tuple. ``masks`` is a list of
        ``(class_id, mask)`` pairs. ``union_mask`` is the bitwise OR of all
        kept masks as ``uint8``, or the integer ``0`` when nothing was kept.
    """
    wanted = set(target_class_ids)
    masks = []
    union_mask = 0
    for res in results:
        # `hasattr` is not enough: the attribute can exist and still be None
        # when the model produced no masks for this image.
        if getattr(res, 'masks', None) is None:
            continue
        for i, cls in enumerate(res.boxes.cls):
            class_id = int(cls)
            if class_id not in wanted:
                continue
            mask = res.masks.data[i].cpu().numpy()
            masks.append((class_id, mask))

            # Bitwise OR operation to get the union of all masks so far
            union_mask = np.bitwise_or(union_mask, mask.astype(np.uint8))
    return masks, union_mask


def _segment_image(seg_model, image_path, conf=0.40):
    """Return the union mask of target objects at image resolution, or None.

    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    height, width = img.shape[:2]

    seg_result = seg_model.predict(source=image_path, conf=conf)
    masks, union_mask = extract_masks(seg_result)
    if not masks:
        return None

    # The model's mask resolution may differ from the image's; nearest
    # neighbour keeps the mask binary.
    return cv2.resize(union_mask, (width, height), interpolation=cv2.INTER_NEAREST)


def _points_on_mask(mask, xys):
    """Return a boolean array telling which ``(x, y)`` points hit the mask.

    A point is looked up at its rounded pixel position; points whose rounded
    position lies outside the mask are reported as not masked.
    """
    xys = np.asarray(xys, dtype=np.float64).reshape(-1, 2)
    height, width = mask.shape[:2]
    cols = np.round(xys[:, 0]).astype(np.int64)
    rows = np.round(xys[:, 1]).astype(np.int64)

    # Bounds are checked on the *rounded* indices so the lookup cannot go
    # out of range for coordinates within half a pixel of the border.
    inside = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
    hits = np.zeros(len(xys), dtype=bool)
    hits[inside] = mask[rows[inside], cols[inside]] != 0
    return hits


def _replace_dataset(group, name, data):
    """Overwrite ``group[name]`` with ``data`` while keeping its attributes."""
    attrs = dict(group[name].attrs)
    del group[name]
    dataset = group.create_dataset(name, data=data)
    for key, value in attrs.items():
        dataset.attrs[key] = value


def remove_masked_keypoints(model_path, features_path, image_dir):
    """Drop local features that lie on masked objects from the features file.

    Args:
        model_path: Path to the COLMAP model; only its image list is used.
        features_path: HDF5 file with one group per image name holding
            ``keypoints``, ``descriptors`` and ``scores``. Modified in place.
        image_dir: Directory containing the images named in the model.

    Raises:
        FileNotFoundError: If an image with stored features cannot be read.
    """
    seg_model = YOLO('yolov8x-seg.pt')
    _, images, _ = read_write_model.read_model(model_path)

    with h5py.File(features_path, 'r+') as f:
        for image in images.values():
            image_name = image.name
            # Nothing to filter for this image; skip before running the model.
            if image_name not in f:
                continue

            mask = _segment_image(seg_model, os.path.join(image_dir, image_name))
            if mask is None:
                continue

            grp = f[image_name]
            keypoints = grp['keypoints'][:]
            keep = ~_points_on_mask(mask, keypoints)
            if keep.all():
                continue

            # Boolean indexing preserves dtype and the (N, 2) / (D, N) layout
            # even when no keypoint survives.
            descriptors = grp['descriptors'][:]
            scores = grp['scores'][:]
            _replace_dataset(grp, 'keypoints', keypoints[keep])
            _replace_dataset(grp, 'descriptors', descriptors[:, keep])
            _replace_dataset(grp, 'scores', scores[keep])


def remove_masked_points3d(model_path, image_dir, output_path=None):
    """Delete 3D points that are observed on masked objects.

    Args:
        model_path: Path to the COLMAP model directory.
        image_dir: Directory containing the images named in the model.
        output_path: Where to write the cleaned model. Defaults to
            ``model_path`` (overwrite in place).

    Raises:
        FileNotFoundError: If an image of the model cannot be read.
    """
    seg_model = YOLO('yolov8x-seg.pt')
    _, images, _ = read_write_model.read_model(model_path)

    point3D_ids_to_mask = set()

    # Iterate through all images in the reconstruction
    for image in images.values():
        mask = _segment_image(seg_model, os.path.join(image_dir, image.name))
        if mask is None:
            continue

        # Find 3D points that correspond to 2D points behind the mask
        hits = _points_on_mask(mask, image.xys)
        masked_ids = np.asarray(image.point3D_ids)[hits]
        # -1 marks a 2D point without an associated 3D point.
        point3D_ids_to_mask.update(int(pid) for pid in masked_ids if pid != -1)

    # Delete 3D points from the reconstruction and all 2D correspondences in images
    reconstruction = pycolmap.Reconstruction(model_path)
    for point3D_id in point3D_ids_to_mask:
        reconstruction.delete_point3D(point3D_id)

    if output_path is None:
        output_path = model_path
    os.makedirs(output_path, exist_ok=True)
    reconstruction.write(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Remove 3D points in the map corresponding to masked (frequently moving) objects.')
    parser.add_argument("--model_path", type=str, required=True, help='The path to the COLMAP model file')
    parser.add_argument('--image_dir', type=str, required=True, help='The path to the image directory')
    parser.add_argument('--output_path', type=str, help='The path to the output destination', default=None)
    args = parser.parse_args()
    remove_masked_points3d(args.model_path, args.image_dir, args.output_path)
--git a/third_party/hloc b/third_party/hloc index 475876c..9b69e4b 160000 --- a/third_party/hloc +++ b/third_party/hloc @@ -1 +1 @@ -Subproject commit 475876c08d2523abd89ec6eb35aee7781f7a6f3b +Subproject commit 9b69e4b1a22967538b4b5c6fd642dd0606f84e90