diff --git a/data_processing/README.md b/data_processing/README.md index 99bdc233..36c70c10 100644 --- a/data_processing/README.md +++ b/data_processing/README.md @@ -13,13 +13,14 @@ We also provide a HuggingFace dataset which contains the [pre-computed metadata] 5. ✅ [ETH3D](https://www.eth3d.net/) 6. ✅ [Mapillary Planet Scale Depth & Reconstructions](https://www.mapillary.com/dataset/depth) (MPSD) 7. ✅ [MegaDepth (including Tanks & Temples)](https://www.cs.cornell.edu/projects/megadepth/) -8. ✅ [MVS-Synth](https://phuang17.github.io/DeepMVS/mvs-synth.html) -9. ✅ [Parallel Domain 4D](https://gcd.cs.columbia.edu/#datasets) -10. ✅ [SAIL-VOS 3D](https://sailvos.web.illinois.edu/_site/_site/index.html) -11. ✅ [ScanNet++ v2](https://kaldir.vc.in.tum.de/scannetpp/) -12. ✅ [Spring](https://spring-benchmark.org/) -13. ✅ [TartanAirV2 Wide Baseline](https://uniflowmatch.github.io/) -14. ✅ [UnrealStereo4K](https://github.com/fabiotosi92/SMD-Nets) +8. ✅ [AerialMegaDepth](https://aerial-megadepth.github.io/) +9. ✅ [MVS-Synth](https://phuang17.github.io/DeepMVS/mvs-synth.html) +10. ✅ [Parallel Domain 4D](https://gcd.cs.columbia.edu/#datasets) +11. ✅ [SAIL-VOS 3D](https://sailvos.web.illinois.edu/_site/_site/index.html) +12. ✅ [ScanNet++ v2](https://kaldir.vc.in.tum.de/scannetpp/) +13. ✅ [Spring](https://spring-benchmark.org/) +14. ✅ [TartanAirV2 Wide Baseline](https://uniflowmatch.github.io/) +15. ✅ [UnrealStereo4K](https://github.com/fabiotosi92/SMD-Nets) ## Download Instructions: diff --git a/data_processing/aggregate_scene_names.py b/data_processing/aggregate_scene_names.py index d2a97630..e15c1a96 100644 --- a/data_processing/aggregate_scene_names.py +++ b/data_processing/aggregate_scene_names.py @@ -23,6 +23,7 @@ DL3DV10KSplits, ETH3DSplits, MegaDepthSplits, + AerialMegaDepthSplits, MPSDSplits, ScanNetPPSplits, SpringSplits, @@ -363,6 +364,24 @@ def aggregate(self): super().aggregate(val_split_scenes=self.dataset_split_info.val_split_scenes) +class AerialMegaDepthAggregator(DatasetAggregator): + """Aggregator for AerialMegaDepth dataset.""" + + def __init__(self, root_dir, output_dir, covisibility_version_key="v0"): + super().__init__( + dataset_name="aerialmegadepth", + root_dir=root_dir, + output_dir=output_dir, + covisibility_version_key=covisibility_version_key, + depth_folder="depth", + ) + self.dataset_split_info = AerialMegaDepthSplits() + + def aggregate(self): + """Aggregate the AerialMegaDepth dataset.""" + super().aggregate(val_split_scenes=self.dataset_split_info.val_split_scenes) + + class MPSDAggregator(DatasetAggregator): """Aggregator for MPSD dataset.""" @@ -551,6 +570,7 @@ def main(): "dynamicreplica", "eth3d", "megadepth", + "aerialmegadepth", "mpsd", "mvs_synth", "paralleldomain4d", @@ -567,6 +587,7 @@ def main(): "dynamicreplica", "eth3d", "megadepth", + "aerialmegadepth", "mpsd", "mvs_synth", "paralleldomain4d", @@ -637,6 +658,12 @@ def main(): root_dir=root_dir, output_dir=args.output_dir ) aggregator.aggregate() + elif dataset == "aerialmegadepth": + # AerialMegaDepth + aggregator = AerialMegaDepthAggregator( + root_dir=root_dir, output_dir=args.output_dir + ) + aggregator.aggregate() elif dataset == "mpsd": # MPSD raw_data_root_dir = os.path.join(args.raw_data_root, "mpsd") diff --git a/data_processing/viz_data.py b/data_processing/viz_data.py index 889171c6..6a7afc6f 100644 --- a/data_processing/viz_data.py +++ b/data_processing/viz_data.py @@ -202,6 +202,16 @@ def get_dataset_config(dataset_type): "confidence_key": None, "confidence_thres": 0.0, }, + "aerialmegadepth": { 
+ "root_dir": "/fsx/xrtech/data/aerialmegadepth", + "scene": "0001", + "depth_key": "depth", + "local_frame": False, + "viz_string": "WAI_Viz", + "load_skymask": False, + "confidence_key": None, + "confidence_thres": 0.0, + }, "spring": { "root_dir": "/fsx/xrtech/data/spring", "scene": "0004", @@ -331,6 +341,7 @@ def get_parser(): "blendedmvs", "eth3d", "megadepth", + "aerialmegadepth", "spring", "mpsd", "ase", diff --git a/data_processing/wai_processing/configs/conversion/aerialmegadepth.yaml b/data_processing/wai_processing/configs/conversion/aerialmegadepth.yaml new file mode 100644 index 00000000..e74dee25 --- /dev/null +++ b/data_processing/wai_processing/configs/conversion/aerialmegadepth.yaml @@ -0,0 +1,9 @@ +original_root: # path of raw downloaded dataset +root: # path of wai-formatted dataset + +dataset_name: aerialmegadepth +version: 0.1 +overwrite: True + +scene_filters: + - process_state_not: [conversion, finished] diff --git a/data_processing/wai_processing/configs/launch/aerialmegadepth.yaml b/data_processing/wai_processing/configs/launch/aerialmegadepth.yaml new file mode 100644 index 00000000..bcf8e51b --- /dev/null +++ b/data_processing/wai_processing/configs/launch/aerialmegadepth.yaml @@ -0,0 +1,24 @@ +stage: # set stage via CLI +root: # path of wai-formatted dataset + +gpus: 0 +cpus: 10 +mem: 20 +scenes_per_job: 20 +conda_env: # pass the name of our conda environment +nodelist: + +stages: + conversion: + script: conversion/aerialmegadepth.py + config: conversion/aerialmegadepth.yaml + scenes_per_job: 20 # fast + covisibility: + script: covisibility.py + config: covisibility/covisibility_gt_depth_224x224.yaml + gpus: 1 + moge: + script: run_moge.py + config: moge/default.yaml + additional_cli_params: ['batch_size=1'] + gpus: 1 diff --git a/data_processing/wai_processing/download_scripts/README.md b/data_processing/wai_processing/download_scripts/README.md index f788354b..b9f147e9 100644 --- a/data_processing/wai_processing/download_scripts/README.md +++ b/data_processing/wai_processing/download_scripts/README.md @@ -28,6 +28,10 @@ Use the provided bash script `download_mpsd.sh` and unzip the downloaded zip fil Use the provided python script `download_megadepth.py`. MegaDepth also includes all the Tanks & Temples scenes. **Source:** [MegaDepth Project](https://www.cs.cornell.edu/projects/megadepth/) +## AerialMegaDepth +Use the provided python script `download_aerialmegadepth.py`. +**Source:** [AerialMegaDepth](https://aerial-megadepth.github.io/) + ## MVS-Synth Use the provided python script `download_mvs_synth.py`. **Source:** [MVS-Synth Dataset](https://phuang17.github.io/DeepMVS/mvs-synth.html) diff --git a/data_processing/wai_processing/download_scripts/download_aerialmegadepth.py b/data_processing/wai_processing/download_scripts/download_aerialmegadepth.py new file mode 100644 index 00000000..31f6613d --- /dev/null +++ b/data_processing/wai_processing/download_scripts/download_aerialmegadepth.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +""" +Download AerialMegaDepth dataset from HuggingFace. 
+
+Reference: https://github.com/kvuong2711/aerial-megadepth/blob/main/data_generation/download_data_hf.py
+"""
+
+from __future__ import annotations
+
+import argparse
+import shutil
+from pathlib import Path
+from huggingface_hub import snapshot_download
+from wai_processing.utils.download import (
+    extract_zip_archives,
+)
+
+
+# Configuration for AerialMegaDepth dataset
+REPO_ID = "kvuong2711/aerialmegadepth"
+ALLOW_PATTERNS = ("**.zip", "aerial_megadepth_all.npz")
+DEFAULT_MAX_WORKERS = 8
+ZIP_DIR_NAME = "aerialmegadepth_zip"
+EXTRACT_DIR_NAME = "aerialmegadepth"
+
+
+def download_archives(zip_dir: Path, max_workers: int):
+    """Download dataset archives into ``zip_dir``."""
+    zip_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Downloading {REPO_ID} archives to {zip_dir}...")
+    snapshot_download(
+        repo_id=REPO_ID,
+        repo_type="dataset",
+        local_dir=str(zip_dir),
+        max_workers=max_workers,
+        allow_patterns=list(ALLOW_PATTERNS),
+    )
+    print("Download complete!")
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Download the AerialMegaDepth dataset and optionally extract archives.",
+    )
+    parser.add_argument(
+        "--target_dir",
+        type=str,
+        required=True,
+        help="Base directory for downloaded archives and extracted data.",
+    )
+    parser.add_argument(
+        "--max_workers",
+        type=int,
+        default=DEFAULT_MAX_WORKERS,
+        help="Number of parallel workers used by the Hugging Face downloader.",
+    )
+    return parser
+
+
+def main():
+    parser = build_parser()
+    args = parser.parse_args()
+
+    target_dir = Path(args.target_dir)
+    zip_dir = target_dir / ZIP_DIR_NAME
+    extract_dir = target_dir / EXTRACT_DIR_NAME
+
+    # 1. Download zip files from huggingface
+    download_archives(zip_dir, max_workers=args.max_workers)
+
+    # 2. Extract zip files
+    extract_zip_archives(target_dir=zip_dir, output_dir=extract_dir, n_workers=args.max_workers)
+
+    # 3. Move the aerial_megadepth_all.npz to the extract_dir
+    shutil.move(zip_dir / "aerial_megadepth_all.npz", extract_dir / "aerial_megadepth_all.npz")
+
+    print("All tasks completed successfully.")
+
+
+if __name__ == "__main__":
+    main()
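Note (not part of the patch): before launching the conversion stage below, it can be worth sanity-checking the downloaded layout, since the converter hard-fails when `aerial_megadepth_all.npz` is missing. A minimal sketch, assuming the same `--target_dir` that was passed to `download_aerialmegadepth.py` (the `/path/to/target_dir` placeholder is hypothetical); `images` and `images_scene_name` are the two npz keys the conversion script reads:

```python
from pathlib import Path

import numpy as np

# Hypothetical location: <target_dir>/aerialmegadepth, as produced by the download script
extract_dir = Path("/path/to/target_dir") / "aerialmegadepth"
pairs_path = extract_dir / "aerial_megadepth_all.npz"
assert pairs_path.exists(), f"Pairs file missing at {pairs_path}"

# The conversion script below reads exactly these two arrays
data = np.load(pairs_path, allow_pickle=True)
images, scene_names = data["images"], data["images_scene_name"]

# Entries may be None; count the usable images per scene
per_scene = {}
for image_id, scene in zip(images, scene_names):
    if image_id is not None and isinstance(scene, str):
        per_scene[scene] = per_scene.get(scene, 0) + 1
print(f"{len(per_scene)} scenes with usable images")
```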
diff --git a/data_processing/wai_processing/scripts/conversion/aerialmegadepth.py b/data_processing/wai_processing/scripts/conversion/aerialmegadepth.py
new file mode 100644
index 00000000..e39de2f2
--- /dev/null
+++ b/data_processing/wai_processing/scripts/conversion/aerialmegadepth.py
@@ -0,0 +1,347 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+"""
+Converts AerialMegaDepth dataset to WAI format.
+
+Reference: https://github.com/kvuong2711/aerial-megadepth/blob/main/data_generation/datasets_preprocess/preprocess_aerialmegadepth.py
+"""
+import logging
+import os
+from pathlib import Path
+import cv2
+import h5py
+import numpy as np
+import torch
+from argconf import argconf_parse
+from natsort import natsorted
+from tqdm import tqdm
+from wai_processing.utils.globals import WAI_PROC_CONFIG_PATH
+from wai_processing.utils.wrapper import convert_scenes_wrapper
+
+from mapanything.utils.wai.core import store_data
+from mapanything.utils.wai.scene_frame import _filter_scenes
+
+
+logger = logging.getLogger(__name__)
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+
+
+def _load_kpts_and_poses(root, scene_name, z_only=False, intrinsics=True):
+    """
+    Load camera parameters from the AerialMegaDepth dataset.
+
+    Args:
+        root: Root directory of the AerialMegaDepth dataset
+        scene_name: Name of the scene
+        z_only: If True, only return the principal axis
+        intrinsics: If True, also return camera intrinsics
+
+    Returns:
+        points3D_idxs: Dictionary mapping image IDs to 3D point indices
+        poses: Dictionary mapping image IDs to camera poses
+        image_intrinsics: Dictionary mapping image IDs to camera intrinsics (if intrinsics=True)
+    """
+    if intrinsics:
+        with open(
+            os.path.join(
+                root, scene_name, "sfm_output_localization", "sfm_superpoint+superglue", "localized_dense_metric", "sparse-txt", "cameras.txt"
+            ),
+            "r",
+        ) as f:
+            raw = f.readlines()[3:]  # skip the header
+
+        camera_intrinsics = {}
+        for camera in raw:
+            camera = camera.split(" ")
+            width, height, focal, cx, cy = [float(elem) for elem in camera[2:]]
+            K = np.eye(3)
+            K[0, 0] = focal
+            K[1, 1] = focal
+            K[0, 2] = cx
+            K[1, 2] = cy
+            camera_intrinsics[int(camera[0])] = (
+                (int(width), int(height)),
+                K,
+                (0, 0, 0, 0),
+            )
+
+    with open(
+        os.path.join(root, scene_name, "sfm_output_localization", "sfm_superpoint+superglue", "localized_dense_metric", "sparse-txt", "images.txt"),
+        "r",
+    ) as f:
+        raw = f.read().splitlines()[4:]  # skip the header
+
+    extract_pose = (
+        colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT
+    )
+
+    poses = {}
+    points3D_idxs = {}
+    camera = []
+
+    for image, points in zip(raw[::2], raw[1::2]):
+        image = image.split(" ")
+        points = points.split(" ")
+
+        image_id = image[-1]
+        camera.append(int(image[-2]))
+
+        # extract the camera pose (or just the principal axis if z_only)
+        raw_pose = [float(elem) for elem in image[1:-2]]
+        poses[image_id] = extract_pose(raw_pose)
+
+        current_points3D_idxs = {int(i) for i in points[2::3] if i != "-1"}
+        assert -1 not in current_points3D_idxs
+        points3D_idxs[image_id] = current_points3D_idxs
+
+    if intrinsics:
+        image_intrinsics = {
+            im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)
+        }
+        return points3D_idxs, poses, image_intrinsics
+    else:
+        return points3D_idxs, poses
+
+
+def colmap_raw_pose_to_principal_axis(image_pose):
+    """Convert COLMAP quaternion to principal axis."""
+    qvec = image_pose[:4]
+    qvec = qvec / np.linalg.norm(qvec)
+    w, x, y, z = qvec
+    z_axis = np.float32(
+        [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y]
+    )
+    return z_axis
+
+
+def colmap_raw_pose_to_RT(image_pose):
+    """Convert COLMAP quaternion to rotation matrix and translation vector."""
+    qvec = image_pose[:4]
+    qvec = qvec / np.linalg.norm(qvec)
+    w, x, y, z = qvec
+    R = np.array(
+        [
+            [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
+            [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
+            [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
+        ]
+    )
+    t = image_pose[4:7]
+    # World-to-Camera pose
+    current_pose = np.eye(4)
+    current_pose[:3, :3] = R
+    current_pose[:3, 3] = t
+    return current_pose
+
+
+def process_aerialmegadepth_scene(cfg, scene_name):
+    """
+    Process an AerialMegaDepth scene into the WAI format.
+    Convert the H5 format depth maps to the default WAI depth format (EXR).
+    Load the (already undistorted) intrinsics and save the images, depth maps, intrinsics and poses in the WAI format.
+    Only processes images that are listed in the aerial_megadepth_all.npz file.
+
+    Expected root directory structure for the raw AerialMegaDepth dataset:
+    .
+    └── aerialmegadepth/
+        ├── 0001/
+        │   ├── sfm_output_localization/sfm_superpoint+superglue/localized_dense_metric/
+        │   │   ├── depths/
+        │   │   │   ├── 2910735886_d62fbf91c9_o.jpg.h5
+        │   │   │   ├── ...
+        │   │   ├── images/
+        │   │   │   ├── 2910735886_d62fbf91c9_o.jpg.jpg
+        │   │   │   ├── ...
+        │   │   ├── sparse-txt/
+        │   │   │   ├── cameras.txt
+        │   │   │   ├── images.txt
+        │   │   │   ├── points3D.txt
+        ├── ...
+        ├── aerial_megadepth_all.npz
+    """
+
+    # Create the scene output directories in the WAI root
+    target_scene_root = Path(cfg.root) / scene_name
+    target_scene_root.mkdir(parents=True, exist_ok=True)
+    image_dir = target_scene_root / "images"
+    image_dir.mkdir(parents=True, exist_ok=False)
+    depth_dir = target_scene_root / "depth"
+    depth_dir.mkdir(parents=True, exist_ok=False)
+
+    # Initialize frames list for this scene
+    wai_frames = []
+
+    # Load camera parameters
+    _, pose_w2cam, intrinsics = _load_kpts_and_poses(
+        cfg.original_root, scene_name, intrinsics=True
+    )
+
+    # Get the scene path and dense directory for this scene
+    scene_path = Path(cfg.original_root) / scene_name
+    dense_dir = scene_path / "sfm_output_localization" / "sfm_superpoint+superglue" / "localized_dense_metric"
+
+    # Load aerial_megadepth_all.npz to filter images
+    pairs_path = Path(cfg.original_root) / "aerial_megadepth_all.npz"
+    if not pairs_path.exists():
+        raise FileNotFoundError(
+            f"aerial_megadepth_all.npz not found at {pairs_path}. Cannot proceed without pairs file."
+        )
+
+    # Load pairs data
+    data = np.load(pairs_path, allow_pickle=True)
+    images = data["images"]
+    images_scene_name = data["images_scene_name"]
+
+    # Find images for this scene
+    images_to_process = set()
+    scene_found = False
+
+    # Current scene identifier
+    current_scene = scene_name
+
+    # Collect all images for this scene from the pairs
+    for image_idx, image_id in enumerate(images):
+        if image_id is not None:
+            scene = images_scene_name[image_idx]
+            # Check if this entry belongs to our scene
+            if isinstance(scene, str) and scene == current_scene:
+                scene_found = True
+                images_to_process.add(image_id)
+
+    if not scene_found:
+        logger.warning(
+            f"Scene {scene_name} not found in pairs file. Skipping this scene."
+        )
+        return "skipped", f"Scene {scene_name} not found in pairs file"
+
+    logger.info(
+        f"Found {len(images_to_process)} images to process for scene {scene_name}"
+    )
+
+    # Segmentation masks
+    segmasks_dir = Path(cfg.original_root) / "aerialmegadepth_segmasks" / scene_name
+
+    # Process each image in the scene in natural sorted order
+    for image_id in tqdm(natsorted(images_to_process)):
+        # Get intrinsic data for this image
+        intrinsic_data = intrinsics[image_id]
+
+        # Get image filename
+        img_path = dense_dir / "images" / image_id
+
+        # Every image listed in the pairs file must exist on disk
+        assert img_path.exists(), f"Image not found at {img_path}"
+
+        # The corresponding depth map must exist as well
+        depth_filename = Path(image_id).stem + ".h5"
+        depth_path = dense_dir / "depths" / depth_filename
+
+        assert depth_path.exists(), f"Depth map not found at {depth_path}"
+
+        # Symlink original image to WAI path
+        rel_target_image_path = Path("images") / image_id
+        os.symlink(img_path, target_scene_root / rel_target_image_path)
+
+        # Load depth map from H5 file
+        with h5py.File(depth_path, "r") as hd5:
+            depthmap = np.asarray(hd5["depth"])
+
+        # Get the dimensions of the depth map
+        H, W = depthmap.shape
+
+        # Load segmentation map to filter out invalid depth values at sky regions
+        segmask_path = segmasks_dir / (image_id + ".png")
+        assert segmask_path.exists(), f"Segmentation mask not found at {segmask_path}"
+        segmask = cv2.imread(str(segmask_path))[:, :, 0]
+        depthmap[segmask == 2] = 0  # Remove the sky from the depthmap (ADE20K)
+
+        # Save depth map to EXR file using WAI
+        rel_depth_out_path = Path("depth") / (Path(image_id).stem + ".exr")
+        store_data(
+            target_scene_root / rel_depth_out_path,
+            torch.tensor(depthmap),
+            "depth",
+        )
+
+        # Get intrinsics
+        imsize_pre, K_pre, distortion = intrinsic_data
+
+        # The images are already undistorted, so the intrinsics can be used as-is
+        K_post = K_pre
+
+        # Get camera pose (world to camera)
+        w2cam_pose = pose_w2cam[image_id]
+
+        # Convert to camera to world pose
+        cam2world_pose = np.linalg.inv(w2cam_pose)
+
+        # Store WAI frame metadata
+        wai_frame = {
+            "frame_name": Path(image_id).stem,
+            "image": str(rel_target_image_path),
+            "file_path": str(rel_target_image_path),
+            "depth": str(rel_depth_out_path),
+            "transform_matrix": cam2world_pose.tolist(),
+            "h": H,
+            "w": W,
+            "fl_x": float(K_post[0, 0]),
+            "fl_y": float(K_post[1, 1]),
+            "cx": float(K_post[0, 2]),
+            "cy": float(K_post[1, 2]),
+        }
+        wai_frames.append(wai_frame)
+
+    # Construct scene metadata for the scene
+    scene_meta = {
+        "scene_name": scene_name,
+        "dataset_name": cfg.dataset_name,
+        "version": cfg.version,
+        "shared_intrinsics": False,
+        "camera_model": "PINHOLE",
+        "camera_convention": "opencv",
+        "scale_type": "colmap",
+        "scene_modalities": {},
+        "frames": wai_frames,
+        "frame_modalities": {
+            "image": {"frame_key": "image", "format": "image"},
+            "depth": {
+                "frame_key": "depth",
+                "format": "depth",
+            },
+        },
+    }
+    store_data(target_scene_root / "scene_meta.json", scene_meta, "scene_meta")
+
+
+def get_original_scene_names(
+    cfg,
+):
+    # Get all scene names to process
+    original_scene_names = sorted(os.listdir(cfg.original_root))
+
+    # Collect the scenes to convert
+    all_scene_names = []
+    # Keep only scene directories that contain the expected SfM outputs
+    for scene_name in original_scene_names:
+        scene_path = Path(cfg.original_root) / scene_name
+        if scene_path.is_dir() and "sfm_output_localization" in os.listdir(scene_path):
+            all_scene_names.append(scene_name)
+    # scene filter for batch processing
+    all_scene_names = _filter_scenes(
+        cfg.root, all_scene_names, cfg.get("scene_filters")
+    )
+    return all_scene_names
+
+
+if __name__ == "__main__":
+    cfg = argconf_parse(WAI_PROC_CONFIG_PATH / "conversion/aerialmegadepth.yaml")
+    target_root_dir = Path(cfg.root)
+    target_root_dir.mkdir(parents=True, exist_ok=True)
+    convert_scenes_wrapper(
+        process_aerialmegadepth_scene,
+        cfg,
+        get_original_scene_names_func=get_original_scene_names,
+    )
diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py
index b01f8285..346f1f3e 100644
--- a/mapanything/datasets/__init__.py
+++ b/mapanything/datasets/__init__.py
@@ -15,6 +15,7 @@
 from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI  # noqa
 from mapanything.datasets.wai.eth3d import ETH3DWAI  # noqa
 from mapanything.datasets.wai.megadepth import MegaDepthWAI  # noqa
+from mapanything.datasets.wai.aerialmegadepth import AerialMegaDepthWAI  # noqa
 from mapanything.datasets.wai.mpsd import MPSDWAI  # noqa
 from mapanything.datasets.wai.mvs_synth import MVSSynthWAI  # noqa
 from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI  # noqa
diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py
index d9103e90..e7b3ba36 100644
--- a/mapanything/datasets/utils/data_splits.py
+++ b/mapanything/datasets/utils/data_splits.py
@@ -1591,6 +1591,18 @@ def __init__(self):
         self.val_split_scenes = ["0015_0", "0015_1", "0022_0"]
 
 
+class AerialMegaDepthSplits:
+    """
+    This class contains the information about the splits of the AerialMegaDepth dataset.
+    """
+
+    def __init__(self):
+        """
+        Validation split is based on scenes used in DUSt3R.
+        """
+        self.val_split_scenes = ["0015", "0022"]
+
+
 class SpringSplits:
     """
     This class contains the information about the splits of the Spring dataset.
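Note (not part of the patch): the converter above writes per-frame intrinsics as scalar `fl_x`/`fl_y`/`cx`/`cy` fields and the pose as a camera-to-world `transform_matrix` in OpenCV convention. A minimal sketch of reassembling them from a converted scene, assuming `scene_meta.json` is plain JSON (it is written via `store_data`) and using a hypothetical WAI root path:

```python
import json
from pathlib import Path

import numpy as np

# Hypothetical path to a converted scene inside the WAI root
scene_root = Path("/path/to/wai_root/0001")
with open(scene_root / "scene_meta.json") as f:
    scene_meta = json.load(f)

frame = scene_meta["frames"][0]
# Rebuild the 3x3 pinhole intrinsics stored by the converter
K = np.array(
    [
        [frame["fl_x"], 0.0, frame["cx"]],
        [0.0, frame["fl_y"], frame["cy"]],
        [0.0, 0.0, 1.0],
    ]
)
# Camera-to-world pose (OpenCV convention), the inverse of the COLMAP w2c pose
c2w = np.array(frame["transform_matrix"])
assert K.shape == (3, 3) and c2w.shape == (4, 4)
```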
+ """ + # Initialize the dataset attributes + super().__init__(*args, **kwargs) + self.ROOT = ROOT + self.dataset_metadata_dir = dataset_metadata_dir + self.split = split + self.overfit_num_sets = overfit_num_sets + self.sample_specific_scene = sample_specific_scene + self.specific_scene_name = specific_scene_name + self._load_data() + + # Define the dataset type flags + self.is_metric_scale = True + self.is_synthetic = False + + def _load_data(self): + "Load the precomputed dataset metadata" + # Load the dataset metadata corresponding to the split + split_metadata_path = os.path.join( + self.dataset_metadata_dir, + self.split, + f"aerialmegadepth_scene_list_{self.split}.npy", + ) + split_scene_list = np.load(split_metadata_path, allow_pickle=True) + + # Get the list of all scenes + if not self.sample_specific_scene: + self.scenes = list(split_scene_list) + else: + self.scenes = [self.specific_scene_name] + self.num_of_scenes = len(self.scenes) + + def _get_views(self, sampled_idx, num_views_to_sample, resolution): + # Get the scene name of the sampled index + scene_index = sampled_idx + scene_name = self.scenes[scene_index] + + # Get the metadata corresponding to the scene + scene_root = os.path.join(self.ROOT, scene_name) + scene_meta = load_data( + os.path.join(scene_root, "scene_meta.json"), "scene_meta" + ) + scene_file_names = list(scene_meta["frame_names"].keys()) + num_views_in_scene = len(scene_file_names) + + # Load the scene pairwise covisibility mmap + covisibility_version_key = "v0" + covisibility_map_dir = os.path.join( + scene_root, "covisibility", covisibility_version_key + ) + # Assumes only npy file in directory is covisbility map + covisibility_map_name = next( + f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy") + ) + covisibility_map_path = os.path.join( + scene_root, "covisibility", covisibility_version_key, covisibility_map_name + ) + pairwise_covisibility = load_data(covisibility_map_path, "mmap") + + # Get the indices of the N views in the scene + view_indices = self._sample_view_indices( + num_views_to_sample, num_views_in_scene, pairwise_covisibility + ) + + # Get the views corresponding to the selected view indices + views = [] + for view_index in view_indices: + # Load the data corresponding to the view + view_file_name = scene_file_names[view_index] + view_data = load_frame( + scene_root, + view_file_name, + modalities=["image", "depth"], # "pred_mask/moge2"], + scene_meta=scene_meta, + ) + + # Convert necessary data to numpy + image = view_data["image"].permute(1, 2, 0).numpy() + image = (image * 255).astype(np.uint8) + depthmap = view_data["depth"].numpy().astype(np.float32) + intrinsics = view_data["intrinsics"].numpy().astype(np.float32) + c2w_pose = view_data["extrinsics"].numpy().astype(np.float32) + + # Ensure that the depthmap has all valid values + depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0) + + # Get the non_ambiguous_mask and ensure it matches image resolution + non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int) + non_ambiguous_mask = cv2.resize( + non_ambiguous_mask, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + # Mask out the GT depth using the non_ambiguous_mask + depthmap = np.where(non_ambiguous_mask, depthmap, 0) + + # Resize the data to match the desired resolution + additional_quantities_to_resize = [non_ambiguous_mask] + image, depthmap, intrinsics, additional_quantities_to_resize = ( + self._crop_resize_if_necessary( + image=image, + 
resolution=resolution, + depthmap=depthmap, + intrinsics=intrinsics, + additional_quantities=additional_quantities_to_resize, + ) + ) + non_ambiguous_mask = additional_quantities_to_resize[0] + + # Append the view dictionary to the list of views + views.append( + dict( + img=image, + depthmap=depthmap, + camera_pose=c2w_pose, # cam2world + camera_intrinsics=intrinsics, + non_ambiguous_mask=non_ambiguous_mask, + dataset="AerialMegaDepth", + label=scene_name, + instance=os.path.join("images", str(view_file_name)), + ) + ) + + return views + + +def get_parser(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-rd", "--root_dir", default="/fsx/xrtech/data/aerialmegadepth", type=str + ) + parser.add_argument( + "-dmd", + "--dataset_metadata_dir", + default="/fsx/nkeetha/mapanything_dataset_metadata", + type=str, + ) + parser.add_argument( + "-nv", + "--num_of_views", + default=2, + type=int, + ) + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + import rerun as rr + from tqdm import tqdm + + from mapanything.datasets.base.base_dataset import view_name + from mapanything.utils.image import rgb + from mapanything.utils.viz import script_add_rerun_args + + parser = get_parser() + script_add_rerun_args( + parser + ) # Options: --headless, --connect, --serve, --addr, --save, --stdout + args = parser.parse_args() + + dataset = AerialMegaDepthWAI( + num_views=args.num_of_views, + split="train", + covisibility_thres=0.25, + ROOT=args.root_dir, + dataset_metadata_dir=args.dataset_metadata_dir, + resolution=(518, 336), + aug_crop=16, + transform="colorjitter+grayscale+gaublur", + data_norm_type="dinov2", + ) + # dataset = AerialMegaDepthWAI( + # num_views=args.num_of_views, + # split="val", + # covisibility_thres=0.25, + # ROOT=args.root_dir, + # dataset_metadata_dir=args.dataset_metadata_dir, + # resolution=(518, 336), + # seed=777, + # transform="imgnorm", + # data_norm_type="dinov2", + # ) + print(dataset.get_stats()) + + if args.viz: + rr.script_setup(args, "AerialMegaDepth_Dataloader") + rr.set_time("stable_time", sequence=0) + rr.log("world", rr.ViewCoordinates.RDF, static=True) + + sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False) + + for num, idx in enumerate(tqdm(sampled_indices)): + views = dataset[idx] + assert len(views) == args.num_of_views + sample_name = f"{idx}" + for view_idx in range(args.num_of_views): + sample_name += f" {view_name(views[view_idx])}" + print(sample_name) + for view_idx in range(args.num_of_views): + image = rgb( + views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"] + ) + depthmap = views[view_idx]["depthmap"] + pose = views[view_idx]["camera_pose"] + intrinsics = views[view_idx]["camera_intrinsics"] + pts3d = views[view_idx]["pts3d"] + valid_mask = views[view_idx]["valid_mask"] + if "non_ambiguous_mask" in views[view_idx]: + non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"] + else: + non_ambiguous_mask = None + if "prior_depth_along_ray" in views[view_idx]: + prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"] + else: + prior_depth_along_ray = None + if args.viz: + rr.set_time("stable_time", sequence=num) + base_name = f"world/view_{view_idx}" + pts_name = f"world/view_{view_idx}_pointcloud" + # Log camera info and loaded data + height, width = image.shape[0], image.shape[1] + rr.log( + base_name, + rr.Transform3D( + translation=pose[:3, 3], + mat3x3=pose[:3, :3], + ), + ) + rr.log( + f"{base_name}/pinhole", + 
rr.Pinhole( + image_from_camera=intrinsics, + height=height, + width=width, + camera_xyz=rr.ViewCoordinates.RDF, + ), + ) + rr.log( + f"{base_name}/pinhole/rgb", + rr.Image(image), + ) + rr.log( + f"{base_name}/pinhole/depth", + rr.DepthImage(depthmap), + ) + if prior_depth_along_ray is not None: + rr.log( + f"prior_depth_along_ray_{view_idx}", + rr.DepthImage(prior_depth_along_ray), + ) + if non_ambiguous_mask is not None: + rr.log( + f"{base_name}/pinhole/non_ambiguous_mask", + rr.SegmentationImage(non_ambiguous_mask.astype(int)), + ) + # Log points in 3D + filtered_pts = pts3d[valid_mask] + filtered_pts_col = image[valid_mask] + rr.log( + pts_name, + rr.Points3D( + positions=filtered_pts.reshape(-1, 3), + colors=filtered_pts_col.reshape(-1, 3), + ), + ) diff --git a/train.md b/train.md index e2519730..8ac5d60f 100644 --- a/train.md +++ b/train.md @@ -105,7 +105,7 @@ The scripts include optimized settings for AWS multi-node training with EFA netw ## Dataset Coverage -The training scripts support all 13 training datasets (with appropriate splits) converted to WAI format: +The training scripts support all 14 training datasets (with appropriate splits) converted to WAI format: 1. ✅ [Aria Synthetic Environments](https://www.projectaria.com/datasets/ase/) 2. ✅ [BlendedMVS](https://github.com/YoYo000/BlendedMVS) @@ -113,13 +113,14 @@ The training scripts support all 13 training datasets (with appropriate splits) 4. ✅ [Dynamic Replica](https://dynamic-stereo.github.io/) 5. ✅ [Mapillary Planet Scale Depth & Reconstructions](https://www.mapillary.com/dataset/depth) (MPSD) 6. ✅ [MegaDepth (including Tanks & Temples)](https://www.cs.cornell.edu/projects/megadepth/) -7. ✅ [MVS-Synth](https://phuang17.github.io/DeepMVS/mvs-synth.html) -8. ✅ [Parallel Domain 4D](https://gcd.cs.columbia.edu/#datasets) -9. ✅ [SAIL-VOS 3D](https://sailvos.web.illinois.edu/_site/_site/index.html) -10. ✅ [ScanNet++ v2](https://kaldir.vc.in.tum.de/scannetpp/) -11. ✅ [Spring](https://spring-benchmark.org/) -12. ✅ [TartanAirV2 Wide Baseline](https://uniflowmatch.github.io/) -13. ✅ [UnrealStereo4K](https://github.com/fabiotosi92/SMD-Nets) +7. ✅ [AerialMegaDepth](https://aerial-megadepth.github.io/) +8. ✅ [MVS-Synth](https://phuang17.github.io/DeepMVS/mvs-synth.html) +9. ✅ [Parallel Domain 4D](https://gcd.cs.columbia.edu/#datasets) +10. ✅ [SAIL-VOS 3D](https://sailvos.web.illinois.edu/_site/_site/index.html) +11. ✅ [ScanNet++ v2](https://kaldir.vc.in.tum.de/scannetpp/) +12. ✅ [Spring](https://spring-benchmark.org/) +13. ✅ [TartanAirV2 Wide Baseline](https://uniflowmatch.github.io/) +14. ✅ [UnrealStereo4K](https://github.com/fabiotosi92/SMD-Nets) ## Reproducing Results
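Note (not part of the patch): a standalone sanity check for the COLMAP quaternion handling in the conversion script. This sketch duplicates `colmap_raw_pose_to_RT` so it runs in isolation, then verifies that the identity quaternion yields an orthonormal rotation and that inverting the world-to-camera pose (what the converter stores as `transform_matrix`) negates the translation; the sample quaternion and translation values are arbitrary test inputs:

```python
import numpy as np


def colmap_raw_pose_to_RT(image_pose):
    # Same formula as in scripts/conversion/aerialmegadepth.py
    qvec = image_pose[:4]
    qvec = qvec / np.linalg.norm(qvec)
    w, x, y, z = qvec
    R = np.array(
        [
            [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
            [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
            [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
        ]
    )
    pose = np.eye(4)
    pose[:3, :3] = R
    pose[:3, 3] = image_pose[4:7]
    return pose


# Identity quaternion (w=1) with translation t = (1, 2, 3)
w2c = colmap_raw_pose_to_RT([1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
c2w = np.linalg.inv(w2c)  # what the converter stores as transform_matrix

assert np.allclose(w2c[:3, :3] @ w2c[:3, :3].T, np.eye(3))  # R is orthonormal
assert np.allclose(c2w[:3, 3], [-1.0, -2.0, -3.0])  # inversion negates t
```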