From b9272319b1f379f8e0a8cef132cb39341b4d2e39 Mon Sep 17 00:00:00 2001 From: Mehar Khurana Date: Tue, 7 Oct 2025 02:05:53 -0400 Subject: [PATCH 1/2] COLMAP demo: handle partial pose info --- scripts/demo_inference_on_colmap_outputs.py | 47 ++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/scripts/demo_inference_on_colmap_outputs.py b/scripts/demo_inference_on_colmap_outputs.py index 94ddb4e6..60875218 100644 --- a/scripts/demo_inference_on_colmap_outputs.py +++ b/scripts/demo_inference_on_colmap_outputs.py @@ -100,6 +100,14 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): views_example = [] processed_count = 0 + # Get a list of all colmap image names + colmap_image_names = set(img_info.name for img_info in images_colmap.values()) + # Find the unposed images (in images/ but not in colmap) + unposed_images = available_images - colmap_image_names + + if verbose: + print(f"Found {len(unposed_images)} images without COLMAP poses") + # Process images in COLMAP order for img_id, img_info in images_colmap.items(): # Apply stride @@ -172,9 +180,46 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): print(f"Warning: Failed to load data for {img_name}: {e}") processed_count += 1 continue + + # process unposed images (without COLMAP poses) + for img_name in unposed_images: + # Apply stride + if processed_count % stride != 0: + processed_count += 1 + continue + + image_path = os.path.join(images_folder, img_name) + + try: + # Load image + image = Image.open(image_path).convert("RGB") + image_array = np.array(image).astype(np.uint8) # (H, W, 3) - [0, 255] + + # Convert to tensor + image_tensor = torch.from_numpy(image_array) # (H, W, 3) + + view = { + "img": image_tensor, # (H, W, 3) - [0, 255] + # No intrinsics or pose available + } + + views_example.append(view) + processed_count += 1 + + if verbose: + print( + f"Loaded unposed view {len(views_example) - 1}: {img_name} (shape: {image_array.shape})" + ) + + except Exception as e: + if verbose: + print(f"Warning: Failed to load data for {img_name}: {e}") + processed_count += 1 + continue + if not views_example: - raise ValueError("No valid COLMAP data found") + raise ValueError("No valid images found") if verbose: print(f"Successfully loaded {len(views_example)} views with stride={stride}") From dd3d43914e6a0424e21b5c1a60c040ceebe4c6ca Mon Sep 17 00:00:00 2001 From: Mehar Khurana Date: Fri, 19 Dec 2025 09:44:51 -0500 Subject: [PATCH 2/2] update colmap scripts; fix bugs --- mapanything/third_party/track_predict.py | 353 ----------- mapanything/third_party/vggsfm_tracker.py | 141 ----- mapanything/third_party/vggsfm_utils.py | 340 ----------- scripts/demo_colmap.py | 644 +++++--------------- scripts/demo_inference_on_colmap_outputs.py | 409 +++++++------ 5 files changed, 370 insertions(+), 1517 deletions(-) delete mode 100644 mapanything/third_party/track_predict.py delete mode 100644 mapanything/third_party/vggsfm_tracker.py delete mode 100644 mapanything/third_party/vggsfm_utils.py diff --git a/mapanything/third_party/track_predict.py b/mapanything/third_party/track_predict.py deleted file mode 100644 index 8f653f3e..00000000 --- a/mapanything/third_party/track_predict.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-# -# Modified from https://github.com/facebookresearch/vggt - -import numpy as np -import torch - -from .vggsfm_utils import ( - build_vggsfm_tracker, - calculate_index_mappings, - extract_keypoints, - generate_rank_by_dino, - initialize_feature_extractors, - predict_tracks_in_chunks, - switch_tensor_order, -) - - -def predict_tracks( - images, - conf=None, - points_3d=None, - max_query_pts=2048, - query_frame_num=5, - keypoint_extractor="aliked+sp", - max_points_num=163840, - fine_tracking=True, - complete_non_vis=True, -): - """ - Predict tracks for the given images and masks. - - TODO: support non-square images - TODO: support masks - - - This function predicts the tracks for the given images and masks using the specified query method - and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames. - - Args: - images: Tensor of shape [S, 3, H, W] containing the input images. - conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None. - points_3d: Tensor containing 3D points. Default is None. - max_query_pts: Maximum number of query points. Default is 2048. - query_frame_num: Number of query frames to use. Default is 5. - keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp". - max_points_num: Maximum number of points to process at once. Default is 163840. - fine_tracking: Whether to use fine tracking. Default is True. - complete_non_vis: Whether to augment non-visible frames. Default is True. - - Returns: - pred_tracks: Numpy array containing the predicted tracks. - pred_vis_scores: Numpy array containing the visibility scores for the tracks. - pred_confs: Numpy array containing the confidence scores for the tracks. - pred_points_3d: Numpy array containing the 3D points for the tracks. - pred_colors: Numpy array containing the point colors for the tracks. 
(0, 255) - """ - - device = images.device - dtype = images.dtype - tracker = build_vggsfm_tracker().to(device, dtype) - - # Find query frames - query_frame_indexes = generate_rank_by_dino( - images, query_frame_num=query_frame_num, device=device - ) - - # Add the first image to the front if not already present - if 0 in query_frame_indexes: - query_frame_indexes.remove(0) - query_frame_indexes = [0, *query_frame_indexes] - - # TODO: add the functionality to handle the masks - keypoint_extractors = initialize_feature_extractors( - max_query_pts, extractor_method=keypoint_extractor, device=device - ) - - pred_tracks = [] - pred_vis_scores = [] - pred_confs = [] - pred_points_3d = [] - pred_colors = [] - - fmaps_for_tracker = tracker.process_images_to_fmaps(images) - - if fine_tracking: - print("For faster inference, consider disabling fine_tracking") - - for query_index in query_frame_indexes: - print(f"Predicting tracks for query frame {query_index}") - pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - device, - ) - - pred_tracks.append(pred_track) - pred_vis_scores.append(pred_vis) - pred_confs.append(pred_conf) - pred_points_3d.append(pred_point_3d) - pred_colors.append(pred_color) - - if complete_non_vis: - pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors = ( - _augment_non_visible_frames( - pred_tracks, - pred_vis_scores, - pred_confs, - pred_points_3d, - pred_colors, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - min_vis=500, - non_vis_thresh=0.1, - device=device, - ) - ) - - pred_tracks = np.concatenate(pred_tracks, axis=1) - pred_vis_scores = np.concatenate(pred_vis_scores, axis=1) - pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None - pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None - pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None - - # from vggt.utils.visual_track import visualize_tracks_on_images - # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals") - - return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors - - -def _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num, - fine_tracking, - device, -): - """ - Process a single query frame for track prediction. 
- - Args: - query_index: Index of the query frame - images: Tensor of shape [S, 3, H, W] containing the input images - conf: Confidence tensor - points_3d: 3D points tensor - fmaps_for_tracker: Feature maps for the tracker - keypoint_extractors: Initialized feature extractors - tracker: VGG-SFM tracker - max_points_num: Maximum number of points to process at once - fine_tracking: Whether to use fine tracking - device: Device to use for computation - - Returns: - pred_track: Predicted tracks - pred_vis: Visibility scores for the tracks - pred_conf: Confidence scores for the tracks - pred_point_3d: 3D points for the tracks - pred_color: Point colors for the tracks (0, 255) - """ - frame_num, _, height, width = images.shape - - query_image = images[query_index] - query_points = extract_keypoints( - query_image, keypoint_extractors, round_keypoints=False - ) - query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)] - - # Extract the color at the keypoint locations - query_points_long = query_points.squeeze(0).round().long() - pred_color = images[query_index][ - :, query_points_long[:, 1], query_points_long[:, 0] - ] - pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8) - - # Query the confidence and points_3d at the keypoint locations - if (conf is not None) and (points_3d is not None): - assert height == width - assert conf.shape[-2] == conf.shape[-1] - assert conf.shape[:3] == points_3d.shape[:3] - scale = conf.shape[-1] / width - - query_points_scaled = (query_points.squeeze(0) * scale).round().long() - query_points_scaled = query_points_scaled.cpu().numpy() - - pred_conf = conf[query_index][ - query_points_scaled[:, 1], query_points_scaled[:, 0] - ] - pred_point_3d = points_3d[query_index][ - query_points_scaled[:, 1], query_points_scaled[:, 0] - ] - - # heuristic to remove low confidence points - # should I export this as an input parameter? 
- valid_mask = pred_conf > 1.2 - if valid_mask.sum() > 512: - query_points = query_points[:, valid_mask] # Make sure shape is compatible - pred_conf = pred_conf[valid_mask] - pred_point_3d = pred_point_3d[valid_mask] - pred_color = pred_color[valid_mask] - else: - pred_conf = None - pred_point_3d = None - - reorder_index = calculate_index_mappings(query_index, frame_num, device=device) - - images_feed, fmaps_feed = switch_tensor_order( - [images, fmaps_for_tracker], reorder_index, dim=0 - ) - images_feed = images_feed[None] # add batch dimension - fmaps_feed = fmaps_feed[None] # add batch dimension - - all_points_num = images_feed.shape[1] * query_points.shape[1] - - # Don't need to be scared, this is just chunking to make GPU happy - if all_points_num > max_points_num: - num_splits = (all_points_num + max_points_num - 1) // max_points_num - query_points = torch.chunk(query_points, num_splits, dim=1) - else: - query_points = [query_points] - - pred_track, pred_vis, _ = predict_tracks_in_chunks( - tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking - ) - - pred_track, pred_vis = switch_tensor_order( - [pred_track, pred_vis], reorder_index, dim=1 - ) - - pred_track = pred_track.squeeze(0).float().cpu().numpy() - pred_vis = pred_vis.squeeze(0).float().cpu().numpy() - - return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color - - -def _augment_non_visible_frames( - pred_tracks: list, # ← running list of np.ndarrays - pred_vis_scores: list, # ← running list of np.ndarrays - pred_confs: list, # ← running list of np.ndarrays for confidence scores - pred_points_3d: list, # ← running list of np.ndarrays for 3D points - pred_colors: list, # ← running list of np.ndarrays for colors - images: torch.Tensor, - conf, - points_3d, - fmaps_for_tracker, - keypoint_extractors, - tracker, - max_points_num: int, - fine_tracking: bool, - *, - min_vis: int = 500, - non_vis_thresh: float = 0.1, - device: torch.device = None, -): - """ - Augment tracking for frames with insufficient visibility. - - Args: - pred_tracks: List of numpy arrays containing predicted tracks. - pred_vis_scores: List of numpy arrays containing visibility scores. - pred_confs: List of numpy arrays containing confidence scores. - pred_points_3d: List of numpy arrays containing 3D points. - pred_colors: List of numpy arrays containing point colors. - images: Tensor of shape [S, 3, H, W] containing the input images. - conf: Tensor of shape [S, 1, H, W] containing confidence scores - points_3d: Tensor containing 3D points - fmaps_for_tracker: Feature maps for the tracker - keypoint_extractors: Initialized feature extractors - tracker: VGG-SFM tracker - max_points_num: Maximum number of points to process at once - fine_tracking: Whether to use fine tracking - min_vis: Minimum visibility threshold - non_vis_thresh: Non-visibility threshold - device: Device to use for computation - - Returns: - Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists. 
- """ - last_query = -1 - final_trial = False - cur_extractors = keypoint_extractors # may be replaced on the final trial - - while True: - # Visibility per frame - vis_array = np.concatenate(pred_vis_scores, axis=1) - - # Count frames with sufficient visibility using numpy - sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1) - non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist() - - if len(non_vis_frames) == 0: - break - - print("Processing non visible frames:", non_vis_frames) - - # Decide the frames & extractor for this round - if non_vis_frames[0] == last_query: - # Same frame failed twice - final "all-in" attempt - final_trial = True - cur_extractors = initialize_feature_extractors( - 2048, extractor_method="sp+sift+aliked", device=device - ) - query_frame_list = non_vis_frames # blast them all at once - else: - query_frame_list = [non_vis_frames[0]] # Process one at a time - - last_query = non_vis_frames[0] - - # Run the tracker for every selected frame - for query_index in query_frame_list: - new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query( - query_index, - images, - conf, - points_3d, - fmaps_for_tracker, - cur_extractors, - tracker, - max_points_num, - fine_tracking, - device, - ) - pred_tracks.append(new_track) - pred_vis_scores.append(new_vis) - pred_confs.append(new_conf) - pred_points_3d.append(new_point_3d) - pred_colors.append(new_color) - - if final_trial: - break # Stop after final attempt - - return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors diff --git a/mapanything/third_party/vggsfm_tracker.py b/mapanything/third_party/vggsfm_tracker.py deleted file mode 100644 index 039b3e3b..00000000 --- a/mapanything/third_party/vggsfm_tracker.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# Modified from https://github.com/facebookresearch/vggt - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .track_modules.base_track_predictor import BaseTrackerPredictor -from .track_modules.blocks import BasicEncoder, ShallowEncoder -from .track_modules.track_refine import refine_track - - -class TrackerPredictor(nn.Module): - def __init__(self, **extra_args): - super(TrackerPredictor, self).__init__() - """ - Initializes the tracker predictor. 
- - Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, - check track_modules/base_track_predictor.py - - Both coarse_fnet and fine_fnet are constructed as a 2D CNN network - check track_modules/blocks.py for BasicEncoder and ShallowEncoder - """ - # Define coarse predictor configuration - coarse_stride = 4 - self.coarse_down_ratio = 2 - - # Create networks directly instead of using instantiate - self.coarse_fnet = BasicEncoder(stride=coarse_stride) - self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride) - - # Create fine predictor with stride = 1 - self.fine_fnet = ShallowEncoder(stride=1) - self.fine_predictor = BaseTrackerPredictor( - stride=1, - depth=4, - corr_levels=3, - corr_radius=3, - latent_dim=32, - hidden_size=256, - fine=True, - use_spaceatt=False, - ) - - def forward( - self, - images, - query_points, - fmaps=None, - coarse_iters=6, - inference=True, - fine_tracking=True, - fine_chunk=40960, - ): - """ - Args: - images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W. - query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2. - fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None. - coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6. - inference (bool, optional): Whether to perform inference. Defaults to True. - fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True. - - Returns: - tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score. - """ - - if fmaps is None: - batch_num, frame_num, image_dim, height, width = images.shape - reshaped_image = images.reshape( - batch_num * frame_num, image_dim, height, width - ) - fmaps = self.process_images_to_fmaps(reshaped_image) - fmaps = fmaps.reshape( - batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1] - ) - - if inference: - torch.cuda.empty_cache() - - # Coarse prediction - coarse_pred_track_lists, pred_vis = self.coarse_predictor( - query_points=query_points, - fmaps=fmaps, - iters=coarse_iters, - down_ratio=self.coarse_down_ratio, - ) - coarse_pred_track = coarse_pred_track_lists[-1] - - if inference: - torch.cuda.empty_cache() - - if fine_tracking: - # Refine the coarse prediction - fine_pred_track, pred_score = refine_track( - images, - self.fine_fnet, - self.fine_predictor, - coarse_pred_track, - compute_score=False, - chunk=fine_chunk, - ) - - if inference: - torch.cuda.empty_cache() - else: - fine_pred_track = coarse_pred_track - pred_score = torch.ones_like(pred_vis) - - return fine_pred_track, coarse_pred_track, pred_vis, pred_score - - def process_images_to_fmaps(self, images): - """ - This function processes images for inference. - - Args: - images (torch.Tensor): The images to be processed with shape S x 3 x H x W. - - Returns: - torch.Tensor: The processed feature maps. - """ - if self.coarse_down_ratio > 1: - # whether or not scale down the input images to save memory - fmaps = self.coarse_fnet( - F.interpolate( - images, - scale_factor=1 / self.coarse_down_ratio, - mode="bilinear", - align_corners=True, - ) - ) - else: - fmaps = self.coarse_fnet(images) - - return fmaps diff --git a/mapanything/third_party/vggsfm_utils.py b/mapanything/third_party/vggsfm_utils.py deleted file mode 100644 index 22eb1170..00000000 --- a/mapanything/third_party/vggsfm_utils.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# Modified from https://github.com/facebookresearch/vggt - -import logging -import warnings - -import torch -import torch.nn.functional as F -from lightglue import ALIKED, SIFT, SuperPoint - -from .vggsfm_tracker import TrackerPredictor - -# Suppress verbose logging from dependencies -logging.getLogger("dinov2").setLevel(logging.WARNING) -warnings.filterwarnings("ignore", message="xFormers is available") -warnings.filterwarnings("ignore", message="dinov2") - -# Constants -_RESNET_MEAN = [0.485, 0.456, 0.406] -_RESNET_STD = [0.229, 0.224, 0.225] - - -def build_vggsfm_tracker(model_path=None): - """ - Build and initialize the VGGSfM tracker. - - Args: - model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace. - - Returns: - Initialized tracker model in eval mode. - """ - tracker = TrackerPredictor() - - if model_path is None: - default_url = ( - "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt" - ) - tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url)) - else: - tracker.load_state_dict(torch.load(model_path)) - - tracker.eval() - return tracker - - -def generate_rank_by_dino( - images, - query_frame_num, - image_size=336, - model_name="dinov2_vitb14_reg", - device="cuda", - spatial_similarity=False, -): - """ - Generate a ranking of frames using DINO ViT features. - - Args: - images: Tensor of shape (S, 3, H, W) with values in range [0, 1] - query_frame_num: Number of frames to select - image_size: Size to resize images to before processing - model_name: Name of the DINO model to use - device: Device to run the model on - spatial_similarity: Whether to use spatial token similarity or CLS token similarity - - Returns: - List of frame indices ranked by their representativeness - """ - # Resize images to the target size - images = F.interpolate( - images, (image_size, image_size), mode="bilinear", align_corners=False - ) - - # Load DINO model - dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name) - dino_v2_model.eval() - dino_v2_model = dino_v2_model.to(device) - - # Normalize images using ResNet normalization - resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) - resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) - images_resnet_norm = (images - resnet_mean) / resnet_std - - with torch.no_grad(): - frame_feat = dino_v2_model(images_resnet_norm, is_training=True) - - # Process features based on similarity type - if spatial_similarity: - frame_feat = frame_feat["x_norm_patchtokens"] - frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) - - # Compute the similarity matrix - frame_feat_norm = frame_feat_norm.permute(1, 0, 2) - similarity_matrix = torch.bmm( - frame_feat_norm, frame_feat_norm.transpose(-1, -2) - ) - similarity_matrix = similarity_matrix.mean(dim=0) - else: - frame_feat = frame_feat["x_norm_clstoken"] - frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) - similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) - - distance_matrix = 100 - similarity_matrix.clone() - - # Ignore self-pairing - similarity_matrix.fill_diagonal_(-100) - similarity_sum = similarity_matrix.sum(dim=1) - - # Find the most common frame - most_common_frame_index = torch.argmax(similarity_sum).item() - - # Conduct FPS sampling starting from the most common frame - fps_idx = farthest_point_sampling( - distance_matrix, 
query_frame_num, most_common_frame_index - ) - - # Clean up all tensors and models to free memory - del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix - del dino_v2_model - torch.cuda.empty_cache() - - return fps_idx - - -def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): - """ - Farthest point sampling algorithm to select diverse frames. - - Args: - distance_matrix: Matrix of distances between frames - num_samples: Number of frames to select - most_common_frame_index: Index of the first frame to select - - Returns: - List of selected frame indices - """ - distance_matrix = distance_matrix.clamp(min=0) - N = distance_matrix.size(0) - - # Initialize with the most common frame - selected_indices = [most_common_frame_index] - check_distances = distance_matrix[selected_indices] - - while len(selected_indices) < num_samples: - # Find the farthest point from the current set of selected points - farthest_point = torch.argmax(check_distances) - selected_indices.append(farthest_point.item()) - - check_distances = distance_matrix[farthest_point] - # Mark already selected points to avoid selecting them again - check_distances[selected_indices] = 0 - - # Break if all points have been selected - if len(selected_indices) == N: - break - - return selected_indices - - -def calculate_index_mappings(query_index, S, device=None): - """ - Construct an order that switches [query_index] and [0] - so that the content of query_index would be placed at [0]. - - Args: - query_index: Index to swap with 0 - S: Total number of elements - device: Device to place the tensor on - - Returns: - Tensor of indices with the swapped order - """ - new_order = torch.arange(S) - new_order[0] = query_index - new_order[query_index] = 0 - if device is not None: - new_order = new_order.to(device) - return new_order - - -def switch_tensor_order(tensors, order, dim=1): - """ - Reorder tensors along a specific dimension according to the given order. - - Args: - tensors: List of tensors to reorder - order: Tensor of indices specifying the new order - dim: Dimension along which to reorder - - Returns: - List of reordered tensors - """ - return [ - torch.index_select(tensor, dim, order) if tensor is not None else None - for tensor in tensors - ] - - -def initialize_feature_extractors( - max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda" -): - """ - Initialize feature extractors that can be reused based on a method string. 
- - Args: - max_query_num: Maximum number of keypoints to extract - det_thres: Detection threshold for keypoint extraction - extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") - device: Device to run extraction on - - Returns: - Dictionary of initialized extractors - """ - extractors = {} - methods = extractor_method.lower().split("+") - - for method in methods: - method = method.strip() - if method == "aliked": - aliked_extractor = ALIKED( - max_num_keypoints=max_query_num, detection_threshold=det_thres - ) - extractors["aliked"] = aliked_extractor.to(device).eval() - elif method == "sp": - sp_extractor = SuperPoint( - max_num_keypoints=max_query_num, detection_threshold=det_thres - ) - extractors["sp"] = sp_extractor.to(device).eval() - elif method == "sift": - sift_extractor = SIFT(max_num_keypoints=max_query_num) - extractors["sift"] = sift_extractor.to(device).eval() - else: - print(f"Warning: Unknown feature extractor '{method}', ignoring.") - - if not extractors: - print( - f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default." - ) - aliked_extractor = ALIKED( - max_num_keypoints=max_query_num, detection_threshold=det_thres - ) - extractors["aliked"] = aliked_extractor.to(device).eval() - - return extractors - - -def extract_keypoints(query_image, extractors, round_keypoints=True): - """ - Extract keypoints using pre-initialized feature extractors. - - Args: - query_image: Input image tensor (3xHxW, range [0, 1]) - extractors: Dictionary of initialized extractors - - Returns: - Tensor of keypoint coordinates (1xNx2) - """ - query_points = None - - with torch.no_grad(): - for extractor_name, extractor in extractors.items(): - query_points_data = extractor.extract(query_image, invalid_mask=None) - extractor_points = query_points_data["keypoints"] - if round_keypoints: - extractor_points = extractor_points.round() - - if query_points is not None: - query_points = torch.cat([query_points, extractor_points], dim=1) - else: - query_points = extractor_points - - return query_points - - -def predict_tracks_in_chunks( - track_predictor, - images_feed, - query_points_list, - fmaps_feed, - fine_tracking, - num_splits=None, - fine_chunk=40960, -): - """ - Process a list of query points to avoid memory issues. - - Args: - track_predictor (object): The track predictor object used for predicting tracks. - images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images. - query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points. - fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker. - fine_tracking (bool): Whether to perform fine tracking. - num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility. - - Returns: - tuple: A tuple containing the concatenated predicted tracks, visibility, and scores. 
- """ - # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility - if not isinstance(query_points_list, (list, tuple)): - query_points = query_points_list - if num_splits is None: - num_splits = 1 - query_points_list = torch.chunk(query_points, num_splits, dim=1) - - # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple) - if isinstance(query_points_list, tuple): - query_points_list = list(query_points_list) - - fine_pred_track_list = [] - pred_vis_list = [] - pred_score_list = [] - - for split_points in query_points_list: - # Feed into track predictor for each split - fine_pred_track, _, pred_vis, pred_score = track_predictor( - images_feed, - split_points, - fmaps=fmaps_feed, - fine_tracking=fine_tracking, - fine_chunk=fine_chunk, - ) - fine_pred_track_list.append(fine_pred_track) - pred_vis_list.append(pred_vis) - pred_score_list.append(pred_score) - - # Concatenate the results from all splits - fine_pred_track = torch.cat(fine_pred_track_list, dim=2) - pred_vis = torch.cat(pred_vis_list, dim=2) - - if pred_score is not None: - pred_score = torch.cat(pred_score_list, dim=2) - else: - pred_score = None - - return fine_pred_track, pred_vis, pred_score diff --git a/scripts/demo_colmap.py b/scripts/demo_colmap.py index 2cf103d2..a83458f6 100644 --- a/scripts/demo_colmap.py +++ b/scripts/demo_colmap.py @@ -4,7 +4,7 @@ # found in the LICENSE file in the root directory of this source tree. """ -Demo script to get MapAnything outputs in COLMAP format. Optionally can also run BA on outputs. +Demo script to get MapAnything outputs in COLMAP format. Reference: VGGT (https://github.com/facebookresearch/vggt/blob/main/demo_colmap.py) """ @@ -13,6 +13,9 @@ import copy import glob import os +from PIL.ImageOps import exif_transpose +import PIL +import tqdm os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -29,9 +32,9 @@ batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track, ) -from mapanything.third_party.track_predict import predict_tracks from mapanything.utils.geometry import closed_form_pose_inverse, depthmap_to_world_frame -from mapanything.utils.image import rgb +from mapanything.utils.image import rgb, load_images, find_closest_aspect_ratio +from mapanything.utils.cropping import rescale_image_and_other_optional_info, crop_image_and_other_optional_info from mapanything.utils.misc import seed_everything from mapanything.utils.viz import predictions_to_glb from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT @@ -45,10 +48,16 @@ def parse_args(): parser = argparse.ArgumentParser(description="MapAnything COLMAP Demo") parser.add_argument( - "--scene_dir", + "--images_dir", type=str, required=True, - help="Directory containing the scene images", + help="Directory containing input images", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save COLMAP outputs (defaults to images_dir parent)", ) parser.add_argument( "--seed", type=int, default=42, help="Random seed for reproducibility" @@ -60,185 +69,19 @@ def parse_args(): help="Use memory efficient inference for reconstruction (trades off speed)", ) parser.add_argument( - "--conf_thres_value", + "--conf_percentile", type=float, - default=0.0, - help="Confidence threshold value for depth filtering (used only without BA)", - ) - parser.add_argument( - "--save_glb", - action="store_true", - default=False, - help="Save dense reconstruction (without BA) 
as GLB file", - ) - parser.add_argument( - "--use_ba", action="store_true", default=False, help="Use BA for reconstruction" - ) - ######### BA parameters ######### - parser.add_argument( - "--max_reproj_error", - type=float, - default=8.0, - help="Maximum reprojection error for reconstruction", - ) - parser.add_argument( - "--shared_camera", - action="store_true", - default=False, - help="Use shared camera for all images", - ) - parser.add_argument( - "--camera_type", - type=str, - default="SIMPLE_PINHOLE", - help="Camera type for reconstruction", - ) - parser.add_argument( - "--vis_thresh", type=float, default=0.2, help="Visibility threshold for tracks" - ) - parser.add_argument( - "--query_frame_num", type=int, default=8, help="Number of frames to query" - ) - parser.add_argument( - "--max_query_pts", type=int, default=4096, help="Maximum number of query points" + default=10, + help="The percentile to use for the confidence threshold for depth filtering. Defaults to 10.", ) parser.add_argument( - "--fine_tracking", + "--apache", action="store_true", - default=True, - help="Use fine tracking (slower but more accurate)", + help="Use Apache 2.0 licensed model (facebook/map-anything-apache)", ) return parser.parse_args() -def load_and_preprocess_images_square( - image_path_list, target_size=1024, data_norm_type=None -): - """ - Load and preprocess images by center padding to square and resizing to target size. - Also returns the position information of original pixels after transformation. - - Args: - image_path_list (list): List of paths to image files - target_size (int, optional): Target size for both width and height. Defaults to 1024. - data_norm_type (str, optional): Image normalization type. See UniCeption IMAGE_NORMALIZATION_DICT keys. Defaults to None (no normalization). - - Returns: - tuple: ( - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), - torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image - ) - - Raises: - ValueError: If the input list is empty or if an invalid data_norm_type is provided - """ - # Check for empty list - if len(image_path_list) == 0: - raise ValueError("At least 1 image is required") - - images = [] - original_coords = [] # Renamed from position_info to be more descriptive - - # Set up normalization based on data_norm_type - if data_norm_type is None: - # No normalization, just convert to tensor - img_transform = tvf.ToTensor() - elif data_norm_type in IMAGE_NORMALIZATION_DICT.keys(): - # Use the specified normalization - img_norm = IMAGE_NORMALIZATION_DICT[data_norm_type] - img_transform = tvf.Compose( - [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)] - ) - else: - raise ValueError( - f"Unknown image normalization type: {data_norm_type}. 
Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}" - ) - - for image_path in image_path_list: - # Open image - img = Image.open(image_path) - - # If there's an alpha channel, blend onto white background - if img.mode == "RGBA": - background = Image.new("RGBA", img.size, (255, 255, 255, 255)) - img = Image.alpha_composite(background, img) - - # Convert to RGB - img = img.convert("RGB") - - # Get original dimensions - width, height = img.size - - # Make the image square by padding the shorter dimension - max_dim = max(width, height) - - # Calculate padding - left = (max_dim - width) // 2 - top = (max_dim - height) // 2 - - # Calculate scale factor for resizing - scale = target_size / max_dim - - # Calculate final coordinates of original image in target space - x1 = left * scale - y1 = top * scale - x2 = (left + width) * scale - y2 = (top + height) * scale - - # Store original image coordinates and scale - original_coords.append(np.array([x1, y1, x2, y2, width, height])) - - # Create a new black square image and paste original - square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) - square_img.paste(img, (left, top)) - - # Resize to target size - square_img = square_img.resize( - (target_size, target_size), Image.Resampling.BICUBIC - ) - - # Convert to tensor and apply normalization - img_tensor = img_transform(square_img) - images.append(img_tensor) - - # Stack all images - images = torch.stack(images) - original_coords = torch.from_numpy(np.array(original_coords)).float() - - # Add additional dimension if single image to ensure correct shape - if len(image_path_list) == 1: - if images.dim() == 3: - images = images.unsqueeze(0) - original_coords = original_coords.unsqueeze(0) - - return images, original_coords - - -def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: - """ - If mask has more than max_trues True values, - randomly keep only max_trues of them and set the rest to False. - """ - # 1D positions of all True entries - true_indices = np.flatnonzero(mask) # shape = (N_true,) - - # if already within budget, return as-is - if true_indices.size <= max_trues: - return mask - - # randomly pick which True positions to keep - sampled_indices = np.random.choice( - true_indices, size=max_trues, replace=False - ) # shape = (max_trues,) - - # build new flat mask: True only at sampled positions - limited_flat_mask = np.zeros(mask.size, dtype=bool) - limited_flat_mask[sampled_indices] = True - - # restore original shape - return limited_flat_mask.reshape(mask.shape) - - def create_pixel_coordinate_grid(num_frames, height, width): """ Creates a grid of pixel coordinates and frame indices for all frames. 
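# Illustrative, self-contained sketch (not part of the patch): how the (S, H, W, 3)
# grid produced by create_pixel_coordinate_grid() is consumed later in this script —
# a boolean validity mask filters the per-pixel 3D points, and the same mask indexes
# the grid so every surviving point keeps its (x, y, frame) provenance. Shapes below
# are toy values chosen only for this example.
import numpy as np

num_frames, height, width = 2, 4, 5
y_grid, x_grid = np.indices((height, width), dtype=np.float32)
f_idx = np.arange(num_frames, dtype=np.float32)[:, None, None]
points_xyf = np.stack(
    (
        np.broadcast_to(x_grid[None], (num_frames, height, width)),
        np.broadcast_to(y_grid[None], (num_frames, height, width)),
        np.broadcast_to(f_idx, (num_frames, height, width)),
    ),
    axis=-1,
)  # (S, H, W, 3): pixel x, pixel y, frame index

points_3d = np.random.rand(num_frames, height, width, 3)  # stand-in for predicted points
valid_mask = points_3d[..., 2] > 0.5  # same idea as the zero-depth filter in demo_fn
kept_xyz = points_3d[valid_mask]      # (N, 3) world-frame points
kept_xyf = points_xyf[valid_mask]     # (N, 3) provenance of each kept point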
@@ -269,374 +112,161 @@ def create_pixel_coordinate_grid(num_frames, height, width): return points_xyf -def run_mapanything( - model, - images, - dtype, - resolution=518, - image_normalization_type="dinov2", - memory_efficient_inference=False, -): - # Images: [V, 3, H, W] - # Check image shape - assert len(images.shape) == 4 - assert images.shape[1] == 3 - - # Hard-coded to use 518 for MapAnything - images = F.interpolate( - images, size=(resolution, resolution), mode="bilinear", align_corners=False - ) - - # Run inference - views = [] - for view_idx in range(images.shape[0]): - view = { - "img": images[view_idx][None], # Add batch dimension - "data_norm_type": [image_normalization_type], - } - views.append(view) - predictions = model.infer( - views, memory_efficient_inference=memory_efficient_inference - ) +def demo_fn(model, device, dtype, args, images_dir, output_dir): + # Print configuration + print(f"\nProcessing images from: {images_dir}") - # Process predictions - ( - all_extrinsics, - all_intrinsics, - all_depth_maps, - all_depth_confs, - all_pts3d, - all_img_no_norm, - all_masks, - ) = ( - [], - [], - [], - [], - [], - [], - [], - ) - for pred in predictions: - # Compute 3D points from depth, intrinsics, and camera pose - depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W) - intrinsics_torch = pred["intrinsics"][0] # (3, 3) - camera_pose_torch = pred["camera_poses"][0] # (4, 4) - pts3d, valid_mask = depthmap_to_world_frame( - depthmap_torch, intrinsics_torch, camera_pose_torch - ) - - # Extract mask from predictions and combine with valid depth mask - mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) - mask = mask & valid_mask.cpu().numpy() # Combine with valid depth mask - - # Convert tensors to numpy arrays - extrinsic = ( - closed_form_pose_inverse(pred["camera_poses"])[0].cpu().numpy() - ) # c2w -> w2c - intrinsic = intrinsics_torch.cpu().numpy() - depth_map = depthmap_torch.cpu().numpy() - depth_conf = pred["conf"][0].cpu().numpy() - pts3d = pts3d.cpu().numpy() - img_no_norm = pred["img_no_norm"][0].cpu().numpy() # Denormalized image - - # Collect results - all_extrinsics.append(extrinsic) - all_intrinsics.append(intrinsic) - all_depth_maps.append(depth_map) - all_depth_confs.append(depth_conf) - all_pts3d.append(pts3d) - all_img_no_norm.append(img_no_norm) - all_masks.append(mask) - - # Stack results into arrays - all_extrinsics = np.stack(all_extrinsics) - all_intrinsics = np.stack(all_intrinsics) - all_depth_maps = np.stack(all_depth_maps) - all_depth_confs = np.stack(all_depth_confs) - all_pts3d = np.stack(all_pts3d) - all_img_no_norm = np.stack(all_img_no_norm) - all_masks = np.stack(all_masks) - - return ( - all_extrinsics, - all_intrinsics, - all_depth_maps, - all_depth_confs, - all_pts3d, - all_img_no_norm, - all_masks, - ) + sparse_reconstruction_dir = os.path.join(output_dir, "sparse") + if os.path.exists(sparse_reconstruction_dir): + print(f"Reconstruction already exists at {sparse_reconstruction_dir}, skipping...") + return True -def demo_fn(args): - # Print configuration - print("Arguments:", vars(args)) + # Get image paths and preprocess them + image_path_list = glob.glob(os.path.join(images_dir, "*")) + if len(image_path_list) == 0: + raise ValueError(f"No images found in {images_dir}") + base_image_path_list = [os.path.basename(path) for path in image_path_list] - # Set seed for reproducibility - seed_everything(args.seed) + images = load_images(image_path_list) - # Set device and dtype - dtype = ( - torch.bfloat16 if 
torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + outputs = model.infer( + images, memory_efficient_inference=args.memory_efficient_inference, confidence_percentile=args.conf_percentile ) - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") - print(f"Using dtype: {dtype}") - # Init model - print("Loading MapAnything model from huggingface ...") - model = MapAnything.from_pretrained("facebook/map-anything").to(device) - model.eval() + intrinsic = np.stack([outputs[i]["intrinsics"][0].cpu().numpy() for i in range(len(outputs))]) + extrinsic = np.stack([closed_form_pose_inverse(outputs[i]["camera_poses"])[0].cpu().numpy() for i in range(len(outputs))]) + points_3d = np.stack([outputs[i]["pts3d"][0].cpu().numpy() for i in range(len(outputs))]) + images = np.stack([images[i]["img"][0].cpu().numpy() for i in range(len(images))]) - # Get image paths and preprocess them - image_dir = os.path.join(args.scene_dir, "images") - image_path_list = glob.glob(os.path.join(image_dir, "*")) - if len(image_path_list) == 0: - raise ValueError(f"No images found in {image_dir}") - base_image_path_list = [os.path.basename(path) for path in image_path_list] + shared_camera = ( + False # in the feedforward manner, we do not support shared camera + ) + camera_type = ( + "PINHOLE" # in the feedforward manner, we only support PINHOLE camera + ) - # Load images and original coordinates - # Load Image in 1024, while running MapAnything with 518 - mapanything_fixed_resolution = 518 - img_load_resolution = 1024 + num_frames, height, width, _ = points_3d.shape + + image_size = np.array([width, height]) - images, original_coords = load_and_preprocess_images_square( - image_path_list, img_load_resolution, model.encoder.data_norm_type + # Denormalize images before computing RGB values + points_rgb_images = F.interpolate( + torch.from_numpy(images).to(torch.float32), + size=(height, width), + mode="bilinear", + align_corners=False, ) - images = images.to(device) - original_coords = original_coords.to(device) - print(f"Loaded {len(images)} images from {image_dir}") - - # Run MapAnything to estimate camera and depth - # Run with 518 x 518 images - extrinsic, intrinsic, depth_map, depth_conf, points_3d, img_no_norm, masks = ( - run_mapanything( - model, - images, - dtype, - mapanything_fixed_resolution, - model.encoder.data_norm_type, - memory_efficient_inference=args.memory_efficient_inference, - ) + + # Convert normalized images back to RGB [0,1] range using the rgb function + points_rgb_list = [] + for i in range(points_rgb_images.shape[0]): + # rgb function expects single image tensor and returns numpy array in [0,1] range + rgb_img = rgb(points_rgb_images[i], model.encoder.data_norm_type) + points_rgb_list.append(rgb_img) + + # Stack and convert to uint8 + points_rgb = np.stack(points_rgb_list) # Shape: (N, H, W, 3) + points_rgb = (points_rgb * 255).astype(np.uint8) + + # (S, H, W, 3), with x, y coordinates and frame indices + points_xyf = create_pixel_coordinate_grid(num_frames, height, width) + + # Filter points based on zero depth + valid_mask = points_3d[..., 2] > 0 + + points_3d = points_3d[valid_mask] + points_xyf = points_xyf[valid_mask] + points_rgb = points_rgb[valid_mask] + + print("Converting to COLMAP format") + reconstruction = batch_np_matrix_to_pycolmap_wo_track( + points_3d, + points_xyf, + points_rgb, + extrinsic, + intrinsic, + image_size, + shared_camera=shared_camera, + camera_type=camera_type, ) - # Prepare lists for GLB export if needed - 
world_points_list = [] - images_list = [] - masks_list = [] - - if args.save_glb: - for i in range(img_no_norm.shape[0]): - # Use the already denormalized images from predictions - images_list.append(img_no_norm[i]) - - # Add world points and masks from predictions - world_points_list.append(points_3d[i]) - masks_list.append(masks[i]) # Use masks from predictions - - if args.use_ba: - image_size = np.array(images.shape[-2:]) - scale = img_load_resolution / mapanything_fixed_resolution - shared_camera = args.shared_camera - - with torch.amp.autocast("cuda", dtype=dtype): - # Predicting Tracks - # Uses VGGSfM tracker - # You can also change the pred_tracks to tracks from any other methods - # e.g., from COLMAP, from CoTracker, or by chaining 2D matches from Lightglue/LoFTR. - pred_tracks, pred_vis_scores, pred_confs, points_3d, points_rgb = ( - predict_tracks( - images, - conf=depth_conf, - points_3d=points_3d, - max_query_pts=args.max_query_pts, - query_frame_num=args.query_frame_num, - keypoint_extractor="aliked+sp", - fine_tracking=args.fine_tracking, - ) - ) - - torch.cuda.empty_cache() - - # Rescale the intrinsic matrix from 518 to 1024 - intrinsic[:, :2, :] *= scale - track_mask = pred_vis_scores > args.vis_thresh - - # Init pycolmap reconstruction - reconstruction, valid_track_mask = batch_np_matrix_to_pycolmap( - points_3d, - extrinsic, - intrinsic, - pred_tracks, - image_size, - masks=track_mask, - max_reproj_error=args.max_reproj_error, - shared_camera=shared_camera, - camera_type=args.camera_type, - points_rgb=points_rgb, - ) - - if reconstruction is None: - raise ValueError("No reconstruction can be built with BA") - - # Bundle Adjustment - ba_options = pycolmap.BundleAdjustmentOptions() - pycolmap.bundle_adjustment(reconstruction, ba_options) - - reconstruction_resolution = img_load_resolution - else: - conf_thres_value = args.conf_thres_value - max_points_for_colmap = 100000 # randomly sample 3D points - shared_camera = ( - False # in the feedforward manner, we do not support shared camera - ) - camera_type = ( - "PINHOLE" # in the feedforward manner, we only support PINHOLE camera - ) - - image_size = np.array( - [mapanything_fixed_resolution, mapanything_fixed_resolution] - ) - num_frames, height, width, _ = points_3d.shape - - # Denormalize images before computing RGB values - points_rgb_images = F.interpolate( - images, - size=(mapanything_fixed_resolution, mapanything_fixed_resolution), - mode="bilinear", - align_corners=False, - ) - - # Convert normalized images back to RGB [0,1] range using the rgb function - points_rgb_list = [] - for i in range(points_rgb_images.shape[0]): - # rgb function expects single image tensor and returns numpy array in [0,1] range - rgb_img = rgb(points_rgb_images[i], model.encoder.data_norm_type) - points_rgb_list.append(rgb_img) - - # Stack and convert to uint8 - points_rgb = np.stack(points_rgb_list) # Shape: (N, H, W, 3) - points_rgb = (points_rgb * 255).astype(np.uint8) - - # (S, H, W, 3), with x, y coordinates and frame indices - points_xyf = create_pixel_coordinate_grid(num_frames, height, width) - - conf_mask = depth_conf >= conf_thres_value - # At most writing 100000 3d points to colmap reconstruction object - conf_mask = randomly_limit_trues(conf_mask, max_points_for_colmap) - - points_3d = points_3d[conf_mask] - points_xyf = points_xyf[conf_mask] - points_rgb = points_rgb[conf_mask] - - print("Converting to COLMAP format") - reconstruction = batch_np_matrix_to_pycolmap_wo_track( - points_3d, - points_xyf, - points_rgb, - extrinsic, - 
intrinsic, - image_size, - shared_camera=shared_camera, - camera_type=camera_type, - ) - - reconstruction_resolution = mapanything_fixed_resolution - - reconstruction = rename_colmap_recons_and_rescale_camera( + reconstruction = rename_colmap_recons( reconstruction, base_image_path_list, - original_coords.cpu().numpy(), - img_size=reconstruction_resolution, - shift_point2d_to_original_res=True, - shared_camera=shared_camera, ) - print(f"Saving reconstruction to {args.scene_dir}/sparse") - sparse_reconstruction_dir = os.path.join(args.scene_dir, "sparse") + print(f"Saving reconstruction to {output_dir}/sparse") + os.makedirs(sparse_reconstruction_dir, exist_ok=True) reconstruction.write(sparse_reconstruction_dir) # Save point cloud for fast visualization trimesh.PointCloud(points_3d, colors=points_rgb).export( - os.path.join(args.scene_dir, "sparse/points.ply") + os.path.join(output_dir, "sparse/points.ply") ) - - # Export GLB if requested - if args.save_glb: - glb_output_path = os.path.join(args.scene_dir, "dense_mesh.glb") - print(f"Saving GLB file to: {glb_output_path}") - - # Stack all views - world_points = np.stack(world_points_list, axis=0) - images = np.stack(images_list, axis=0) - final_masks = np.stack(masks_list, axis=0) - - # Create predictions dict for GLB export - predictions = { - "world_points": world_points, - "images": images, - "final_masks": final_masks, - } - - # Convert to GLB scene - scene_3d = predictions_to_glb(predictions, as_mesh=True) - - # Save GLB file - scene_3d.export(glb_output_path) - print(f"Successfully saved GLB file: {glb_output_path}") - return True -def rename_colmap_recons_and_rescale_camera( +def rename_colmap_recons( reconstruction, image_paths, - original_coords, - img_size, - shift_point2d_to_original_res=False, - shared_camera=False, -): - rescale_camera = True - +): + """Rename COLMAP reconstruction images to original names.""" for pyimageid in reconstruction.images: # Reshaped the padded & resized image to the original size # Rename the images to the original names pyimage = reconstruction.images[pyimageid] - pycamera = reconstruction.cameras[pyimage.camera_id] pyimage.name = image_paths[pyimageid - 1] - if rescale_camera: - # Rescale the camera parameters - pred_params = copy.deepcopy(pycamera.params) - - real_image_size = original_coords[pyimageid - 1, -2:] - resize_ratio = max(real_image_size) / img_size - pred_params = pred_params * resize_ratio - real_pp = real_image_size / 2 - pred_params[-2:] = real_pp # center of the image - - pycamera.params = pred_params - pycamera.width = real_image_size[0] - pycamera.height = real_image_size[1] + return reconstruction - if shift_point2d_to_original_res: - # Also shift the point2D to original resolution - top_left = original_coords[pyimageid - 1, :2] - for point2D in pyimage.points2D: - point2D.xy = (point2D.xy - top_left) * resize_ratio +if __name__ == "__main__": + args = parse_args() - if shared_camera: - # If shared_camera, all images share the same camera - # No need to rescale any more - rescale_camera = False + # Set seed for reproducibility + seed_everything(args.seed) - return reconstruction + # Set device and dtype + dtype = ( + torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + print(f"Using dtype: {dtype}") + # Init model + print("Loading MapAnything model from huggingface ...") + if args.apache: + model_name = "facebook/map-anything-apache" + print("Loading 
Apache 2.0 licensed MapAnything model...") + else: + model_name = "facebook/map-anything" + print("Loading CC-BY-NC 4.0 licensed MapAnything model...") + + model = MapAnything.from_pretrained(model_name).to(device) + model.eval() -if __name__ == "__main__": - args = parse_args() - with torch.no_grad(): - demo_fn(args) + # Validate images directory + if not os.path.isdir(args.images_dir): + raise ValueError(f"Images directory not found: {args.images_dir}") + + # Set output directory + if args.output_dir is None: + args.output_dir = os.path.dirname(args.images_dir) + + os.makedirs(args.output_dir, exist_ok=True) + + # Process images + try: + with torch.no_grad(): + success = demo_fn(model, device, dtype, args, args.images_dir, args.output_dir) + if success: + print(f"\n✅ Successfully processed images from {args.images_dir}") + else: + print(f"\n❌ Processing failed") + except Exception as e: + print(f"\n❌ Error processing images: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/scripts/demo_inference_on_colmap_outputs.py b/scripts/demo_inference_on_colmap_outputs.py index 60875218..1af216a9 100644 --- a/scripts/demo_inference_on_colmap_outputs.py +++ b/scripts/demo_inference_on_colmap_outputs.py @@ -21,23 +21,28 @@ """ import argparse +import glob import os +import torch.nn.functional as F os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import numpy as np +import pycolmap import rerun as rr import torch +import trimesh from PIL import Image from mapanything.models import MapAnything from mapanything.utils.colmap import get_camera_matrix, qvec2rotmat, read_model from mapanything.utils.geometry import closed_form_pose_inverse, depthmap_to_world_frame -from mapanything.utils.image import preprocess_inputs +from mapanything.utils.image import preprocess_inputs, rgb from mapanything.utils.viz import predictions_to_glb, script_add_rerun_args +from mapanything.third_party.np_to_pycolmap import batch_np_matrix_to_pycolmap_wo_track -def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): +def load_colmap_data(images_path, sparse_path, stride=1, verbose=False, ext=".bin"): """ Load COLMAP format data for MapAnything inference. 
@@ -53,35 +58,31 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): points3D.bin/txt Args: - colmap_path (str): Path to the main folder containing images/ and sparse/ subfolders - stride (int): Load every nth image (default: 50) + images_path (str): Path to the images folder + sparse_path (str): Path to the sparse COLMAP folder + stride (int): Load every nth image (default: 1) verbose (bool): Print progress messages ext (str): COLMAP file extension (".bin" or ".txt") Returns: list: List of view dictionaries for MapAnything inference """ - # Define paths - images_folder = os.path.join(colmap_path, "images") - sparse_folder = os.path.join(colmap_path, "sparse") - # Check that required folders exist - if not os.path.exists(images_folder): - raise ValueError(f"Required folder 'images' not found at: {images_folder}") - if not os.path.exists(sparse_folder): - raise ValueError(f"Required folder 'sparse' not found at: {sparse_folder}") + if not os.path.exists(images_path): + raise ValueError(f"Images folder not found at: {images_path}") + if not os.path.exists(sparse_path): + raise ValueError(f"Sparse folder not found at: {sparse_path}") if verbose: - print(f"Loading COLMAP data from: {colmap_path}") - print(f"Images folder: {images_folder}") - print(f"Sparse folder: {sparse_folder}") + print(f"Images folder: {images_path}") + print(f"Sparse folder: {sparse_path}") print(f"Using COLMAP file extension: {ext}") # Read COLMAP model try: - cameras, images_colmap, points3D = read_model(sparse_folder, ext=ext) + cameras, images_colmap, points3D = read_model(sparse_path, ext=ext) except Exception as e: - raise ValueError(f"Failed to read COLMAP model from {sparse_folder}: {e}") + raise ValueError(f"Failed to read COLMAP model from {sparse_path}: {e}") if verbose: print( @@ -90,26 +91,28 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): # Get list of available image files available_images = set() - for f in os.listdir(images_folder): + for f in os.listdir(images_path): if f.lower().endswith((".jpg", ".jpeg", ".png")): available_images.add(f) if not available_images: - raise ValueError(f"No image files found in {images_folder}") + raise ValueError(f"No image files found in {images_path}") views_example = [] + image_names_in_order = [] # Track image names in the order they're added to views processed_count = 0 # Get a list of all colmap image names colmap_image_names = set(img_info.name for img_info in images_colmap.values()) # Find the unposed images (in images/ but not in colmap) unposed_images = available_images - colmap_image_names + unposed_images = sorted(list(unposed_images)) if verbose: print(f"Found {len(unposed_images)} images without COLMAP poses") # Process images in COLMAP order - for img_id, img_info in images_colmap.items(): + for img_id, img_info in sorted(images_colmap.items()): # Apply stride if processed_count % stride != 0: processed_count += 1 @@ -118,7 +121,7 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): img_name = img_info.name # Check if image file exists - image_path = os.path.join(images_folder, img_name) + image_path = os.path.join(images_path, img_name) if not os.path.exists(image_path): if verbose: print(f"Warning: Image file not found for {img_name}, skipping") @@ -165,6 +168,7 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): } views_example.append(view) + image_names_in_order.append(img_name) processed_count += 1 if verbose: @@ -188,7 +192,7 @@ def 
load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): processed_count += 1 continue - image_path = os.path.join(images_folder, img_name) + image_path = os.path.join(images_path, img_name) try: # Load image @@ -204,6 +208,7 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): } views_example.append(view) + image_names_in_order.append(img_name) processed_count += 1 if verbose: @@ -224,7 +229,7 @@ def load_colmap_data(colmap_path, stride=1, verbose=False, ext=".bin"): if verbose: print(f"Successfully loaded {len(views_example)} views with stride={stride}") - return views_example + return views_example, image_names_in_order def log_data_to_rerun( @@ -274,10 +279,22 @@ def get_parser(): description="MapAnything demo using COLMAP reconstructions as input" ) parser.add_argument( - "--colmap_path", + "--images_dir", + type=str, + required=True, + help="Path to directory containing input images", + ) + parser.add_argument( + "--sparse_dir", type=str, required=True, - help="Path to folder containing images/ and sparse/ subfolders", + help="Path to COLMAP sparse reconstruction directory", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to directory where outputs will be saved", ) parser.add_argument( "--stride", @@ -321,18 +338,18 @@ def get_parser(): default=False, help="Save reconstruction as GLB file", ) - parser.add_argument( - "--output_directory", - type=str, - default="colmap_mapanything_output", - help="Output directory for GLB file and input images", - ) parser.add_argument( "--save_input_images", action="store_true", default=False, help="Save input images alongside GLB output (requires --save_glb)", ) + parser.add_argument( + "--save_colmap", + action="store_true", + default=True, + help="Save reconstruction in COLMAP format", + ) parser.add_argument( "--ignore_calibration_inputs", action="store_true", @@ -348,177 +365,217 @@ def get_parser(): return parser +def create_pixel_coordinate_grid(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices for all frames. 
+ Returns: + tuple: A tuple containing: + - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) + with x, y coordinates and frame indices + - y_coords (numpy.ndarray): Array of y coordinates for all frames + - x_coords (numpy.ndarray): Array of x coordinates for all frames + - f_coords (numpy.ndarray): Array of frame indices for all frames + """ + # Create coordinate grids for a single frame + y_grid, x_grid = np.indices((height, width), dtype=np.float32) + x_grid = x_grid[np.newaxis, :, :] + y_grid = y_grid[np.newaxis, :, :] -def main(): - # Parser for arguments and Rerun - parser = get_parser() - script_add_rerun_args( - parser - ) # Options: --headless, --connect, --serve, --addr, --save, --stdout - args = parser.parse_args() + # Broadcast to all frames + x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) + y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) - # Get inference device - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") + # Create frame indices and broadcast + f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] + f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) - # Initialize model from HuggingFace - if args.apache: - model_name = "facebook/map-anything-apache" - print("Loading Apache 2.0 licensed MapAnything model...") - else: - model_name = "facebook/map-anything" - print("Loading CC-BY-NC 4.0 licensed MapAnything model...") - model = MapAnything.from_pretrained(model_name).to(device) + # Stack coordinates and frame indices + points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) - # Load COLMAP data - print(f"Loading COLMAP data from: {args.colmap_path}") - views_example = load_colmap_data( - args.colmap_path, - stride=args.stride, - verbose=args.verbose, - ext=args.ext, - ) - print(f"Loaded {len(views_example)} views") + return points_xyf + +def rename_colmap_recons( + reconstruction, + image_paths, +): + """Rename COLMAP reconstruction images to original names.""" + for pyimageid in reconstruction.images: + # Reshaped the padded & resized image to the original size + # Rename the images to the original names + pyimage = reconstruction.images[pyimageid] + pyimage.name = image_paths[pyimageid - 1] + + return reconstruction + + +def process_scene(model, images_path, sparse_path, output_dir, args, scene_name=None): + """Process a single scene and save outputs.""" + print(f"\n{'='*60}") + print(f"Processing scene: {scene_name}") + print(f"{'='*60}") - # Preprocess inputs to the expected format + try: + views_example, image_names_in_order = load_colmap_data( + images_path, + sparse_path, + stride=args.stride, + verbose=args.verbose, + ext=args.ext, + ) + print(f"Loaded {len(views_example)} views") + except Exception as e: + print(f"ERROR - Failed to load COLMAP data: {e}") + return False + + # Preprocess inputs print("Preprocessing COLMAP inputs...") processed_views = preprocess_inputs(views_example, verbose=False) # Run model inference - print("Running MapAnything inference on COLMAP data...") + print("Running MapAnything inference...") outputs = model.infer( processed_views, memory_efficient_inference=args.memory_efficient_inference, - # Control which COLMAP inputs to use/ignore - ignore_calibration_inputs=args.ignore_calibration_inputs, # Whether to use COLMAP calibration or not - ignore_depth_inputs=True, # COLMAP doesn't provide depth (can recover from sparse points but convoluted) - ignore_pose_inputs=args.ignore_pose_inputs, # Whether 
to use COLMAP poses or not - ignore_depth_scale_inputs=True, # No depth data - ignore_pose_scale_inputs=True, # COLMAP poses are non-metric - # Use amp for better performance + ignore_calibration_inputs=args.ignore_calibration_inputs, + ignore_depth_inputs=True, + ignore_pose_inputs=args.ignore_pose_inputs, + ignore_depth_scale_inputs=True, + ignore_pose_scale_inputs=True, use_amp=True, amp_dtype="bf16", apply_mask=True, mask_edges=True, ) - print("COLMAP inference complete!") - - # Prepare lists for GLB export if needed - world_points_list = [] - images_list = [] - masks_list = [] - - # Initialize Rerun if visualization is enabled - if args.viz: - print("Starting visualization...") - viz_string = "MapAnything_COLMAP_Inference_Visualization" - rr.script_setup(args, viz_string) - rr.set_time("stable_time", sequence=0) - rr.log("mapanything", rr.ViewCoordinates.RDF, static=True) - - # Loop through the outputs - for view_idx, pred in enumerate(outputs): - # Extract data from predictions - depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W) - intrinsics_torch = pred["intrinsics"][0] # (3, 3) - camera_pose_torch = pred["camera_poses"][0] # (4, 4) - - # Compute new pts3d using depth, intrinsics, and camera pose - pts3d_computed, valid_mask = depthmap_to_world_frame( - depthmap_torch, intrinsics_torch, camera_pose_torch + print("Inference complete!") + + if args.save_colmap: + intrinsic = np.stack([outputs[i]["intrinsics"][0].cpu().numpy() for i in range(len(outputs))]) + extrinsic = np.stack([closed_form_pose_inverse(outputs[i]["camera_poses"])[0].cpu().numpy() for i in range(len(outputs))]) + points_3d = np.stack([outputs[i]["pts3d"][0].cpu().numpy() for i in range(len(outputs))]) + depth_conf = np.stack([outputs[i]["conf"][0].cpu().numpy() for i in range(len(outputs))]) + images = np.stack([processed_views[i]["img"][0].cpu().numpy() for i in range(len(processed_views))]) + + shared_camera = ( + False # in the feedforward manner, we do not support shared camera + ) + camera_type = ( + "PINHOLE" # in the feedforward manner, we only support PINHOLE camera ) - # Convert to numpy arrays - mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) - mask = mask & valid_mask.cpu().numpy() # Combine with valid depth mask - pts3d_np = pts3d_computed.cpu().numpy() - image_np = pred["img_no_norm"][0].cpu().numpy() - - # Store data for GLB export if needed - if args.save_glb: - world_points_list.append(pts3d_np) - images_list.append(image_np) - masks_list.append(mask) - - # Log to Rerun if visualization is enabled - if args.viz: - log_data_to_rerun( - image=image_np, - depthmap=depthmap_torch.cpu().numpy(), - pose=camera_pose_torch.cpu().numpy(), - intrinsics=intrinsics_torch.cpu().numpy(), - pts3d=pts3d_np, - mask=mask, - base_name=f"mapanything/view_{view_idx}", - pts_name=f"mapanything/pointcloud_view_{view_idx}", - viz_mask=mask, - ) + num_frames, height, width, _ = points_3d.shape + + image_size = np.array([width, height]) + + # Denormalize images before computing RGB values + points_rgb_images = F.interpolate( + torch.from_numpy(images).to(torch.float32), + size=(height, width), + mode="bilinear", + align_corners=False, + ) - # Convey that the visualization is complete - if args.viz: - print("Visualization complete! 
Check the Rerun viewer.") + # Convert normalized images back to RGB [0,1] range using the rgb function + points_rgb_list = [] + for i in range(points_rgb_images.shape[0]): + # rgb function expects single image tensor and returns numpy array in [0,1] range + rgb_img = rgb(points_rgb_images[i], model.encoder.data_norm_type) + points_rgb_list.append(rgb_img) + + # Stack and convert to uint8 + points_rgb = np.stack(points_rgb_list) # Shape: (N, H, W, 3) + points_rgb = (points_rgb * 255).astype(np.uint8) + + # (S, H, W, 3), with x, y coordinates and frame indices + points_xyf = create_pixel_coordinate_grid(num_frames, height, width) + + # Filter points based on depth validity and confidence threshold + # Use confidence threshold to filter out low-quality predictions + # conf_threshold = 1.5 # Adjust this value based on quality needs (typical range: 1.0-3.0) + # valid_mask = (points_3d[..., 2] > 0) & (depth_conf >= conf_threshold) + valid_mask = (points_3d[..., 2] > 0) + + points_3d = points_3d[valid_mask] + points_xyf = points_xyf[valid_mask] + points_rgb = points_rgb[valid_mask] + + print("Converting to COLMAP format") + reconstruction = batch_np_matrix_to_pycolmap_wo_track( + points_3d, + points_xyf, + points_rgb, + extrinsic, + intrinsic, + image_size, + shared_camera=shared_camera, + camera_type=camera_type, + ) - # Export GLB if requested - if args.save_glb: - # Create output directory structure - scene_output_dir = args.output_directory - os.makedirs(scene_output_dir, exist_ok=True) - scene_prefix = os.path.basename(scene_output_dir) + reconstruction = rename_colmap_recons( + reconstruction, + image_names_in_order, + ) - glb_output_path = os.path.join( - scene_output_dir, f"{scene_prefix}_mapanything_colmap_output.glb" + print(f"Saving reconstruction to {output_dir}/sparse") + sparse_reconstruction_dir = os.path.join(output_dir, "sparse") + os.makedirs(sparse_reconstruction_dir, exist_ok=True) + reconstruction.write(sparse_reconstruction_dir) + + # Save point cloud for fast visualization + trimesh.PointCloud(points_3d, colors=points_rgb).export( + os.path.join(output_dir, "sparse/points.ply") ) - print(f"Saving GLB file to: {glb_output_path}") + return True - # Save processed input images if requested - if args.save_input_images: - # Create processed images directory - processed_images_dir = os.path.join( - scene_output_dir, f"{scene_prefix}_input_images" - ) - os.makedirs(processed_images_dir, exist_ok=True) - print(f"Saving processed input images to: {processed_images_dir}") - - # Save each processed input image from outputs - for view_idx, pred in enumerate(outputs): - # Get processed image (RGB, 0-255) - processed_image = ( - pred["img_no_norm"][0].cpu().numpy() * 255 - ) # (H, W, 3) - - # Convert to PIL Image and save as PNG - img_pil = Image.fromarray(processed_image.astype(np.uint8)) - img_path = os.path.join(processed_images_dir, f"view_{view_idx}.png") - img_pil.save(img_path) - - print( - f"Saved {len(outputs)} processed input images to: {processed_images_dir}" - ) - # Stack all views - world_points = np.stack(world_points_list, axis=0) - images = np.stack(images_list, axis=0) - final_masks = np.stack(masks_list, axis=0) - - # Create predictions dict for GLB export - predictions = { - "world_points": world_points, - "images": images, - "final_masks": final_masks, - } - - # Convert to GLB scene - scene_3d = predictions_to_glb(predictions, as_mesh=True) - - # Save GLB file - scene_3d.export(glb_output_path) - print(f"Successfully saved GLB file: {glb_output_path}") - print(f"All 
outputs saved to: {scene_output_dir}")
-    else:
-        print("Skipping GLB export (--save_glb not specified)")
-        if args.save_input_images:
-            print("Warning: --save_input_images has no effect without --save_glb")
+def main():
+    # Parser for arguments and Rerun
+    parser = get_parser()
+    script_add_rerun_args(parser)  # must be registered before parse_args()
+    args = parser.parse_args()
+
+    # Get inference device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+    # Initialize model from HuggingFace
+    if args.apache:
+        model_name = "facebook/map-anything-apache"
+        print("Loading Apache 2.0 licensed MapAnything model...")
+    else:
+        model_name = "facebook/map-anything"
+        print("Loading CC-BY-NC 4.0 licensed MapAnything model...")
+    model = MapAnything.from_pretrained(model_name).to(device)
+    model.eval()
+    print("✅ Successfully loaded model")
+
+    # Validate input directories
+    if not os.path.isdir(args.images_dir):
+        raise ValueError(f"Images directory not found: {args.images_dir}")
+    if not os.path.isdir(args.sparse_dir):
+        raise ValueError(f"Sparse directory not found: {args.sparse_dir}")
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Process the scene
+    try:
+        with torch.no_grad():
+            success = process_scene(
+                model,
+                args.images_dir,
+                args.sparse_dir,
+                args.output_dir,
+                args,
+                scene_name=os.path.basename(os.path.normpath(args.images_dir))
+            )
+        if success:
+            print("\n✅ Successfully processed scene")
+        else:
+            print("\n❌ Processing failed")
+    except Exception as e:
+        print(f"\n❌ Error processing scene: {e}")
+        import traceback
+        traceback.print_exc()
 
 
 if __name__ == "__main__":
     main()
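
Usage sketch (illustrative, not part of the diff): how the updated scripts/demo_inference_on_colmap_outputs.py interface might be exercised. The CLI paths below are placeholders, the direct imports from the script are hypothetical (main() normally drives this via process_scene()), and only keyword arguments that appear in the patch are used.

    # Command line (placeholder paths):
    #   python scripts/demo_inference_on_colmap_outputs.py \
    #       --images_dir /data/scene/images \
    #       --sparse_dir /data/scene/sparse/0 \
    #       --output_dir /data/scene/mapanything_out \
    #       --stride 2

    import torch

    from mapanything.models import MapAnything
    from mapanything.utils.image import preprocess_inputs
    from demo_inference_on_colmap_outputs import load_colmap_data  # hypothetical direct import

    # load_colmap_data now returns two values: the per-view dicts (posed views carry the
    # COLMAP-derived inputs, unposed views only "img") and the image names in the same
    # order as the views.
    views, image_names = load_colmap_data(
        "/data/scene/images",    # images_path (placeholder)
        "/data/scene/sparse/0",  # sparse_path (placeholder)
        stride=2,
        verbose=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MapAnything.from_pretrained("facebook/map-anything").to(device)
    model.eval()

    with torch.no_grad():
        outputs = model.infer(
            preprocess_inputs(views, verbose=False),
            ignore_depth_inputs=True,        # COLMAP provides no dense depth
            ignore_depth_scale_inputs=True,
            ignore_pose_scale_inputs=True,   # COLMAP poses are non-metric
            use_amp=True,
            amp_dtype="bf16",
        )
    # outputs[i] corresponds to image_names[i]; images without a COLMAP pose are
    # appended after the posed ones, in sorted order, as in load_colmap_data.

A quick shape check for the create_pixel_coordinate_grid helper used by the COLMAP export path (same assumptions as above):

    import numpy as np
    from demo_inference_on_colmap_outputs import create_pixel_coordinate_grid  # hypothetical direct import

    pts_xyf = create_pixel_coordinate_grid(num_frames=2, height=4, width=3)
    assert pts_xyf.shape == (2, 4, 3, 3)                    # (S, H, W, 3)
    assert np.allclose(pts_xyf[1, 2, 1], [1.0, 2.0, 1.0])   # last axis holds (x, y, frame index)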