From 8b30dfae54111272143289a48ab7904548bfe69d Mon Sep 17 00:00:00 2001 From: Samrat Thapa Date: Fri, 22 Aug 2025 14:20:45 +0900 Subject: [PATCH 1/6] fixed deployment logic --- projects/BEVFusion/README.md | 4 +- projects/BEVFusion/bevfusion/bevfusion.py | 45 +++++---- projects/BEVFusion/bevfusion/depth_lss.py | 28 +++--- ...fusion_camera_backbone_tensorrt_dynamic.py | 85 +---------------- ...n_main_body_lidar_only_tensorrt_dynamic.py | 4 +- ...n_main_body_with_image_tensorrt_dynamic.py | 70 +++++++++++++- projects/BEVFusion/deploy/containers.py | 93 +++++++++---------- projects/BEVFusion/deploy/torch2onnx.py | 76 +++++++-------- 8 files changed, 195 insertions(+), 210 deletions(-) diff --git a/projects/BEVFusion/README.md b/projects/BEVFusion/README.md index 935bbe68..a196d312 100644 --- a/projects/BEVFusion/README.md +++ b/projects/BEVFusion/README.md @@ -131,7 +131,7 @@ python projects/BEVFusion/deploy/torch2onnx.py \ ${MODEL_CFG} \ ${CHECKPOINT_PATH} \ --device cuda:0 \ - --work-dir ${WORK_DIR} + --work-dir ${WORK_DIR} \ --module main_body @@ -140,7 +140,7 @@ python projects/BEVFusion/deploy/torch2onnx.py \ ${MODEL_CFG} \ ${CHECKPOINT_PATH} \ --device cuda:0 \ - --work-dir ${WORK_DIR} + --work-dir ${WORK_DIR} \ --module image_backbone ``` diff --git a/projects/BEVFusion/bevfusion/bevfusion.py b/projects/BEVFusion/bevfusion/bevfusion.py index 57b962fc..7b9bc7b6 100644 --- a/projects/BEVFusion/bevfusion/bevfusion.py +++ b/projects/BEVFusion/bevfusion/bevfusion.py @@ -57,7 +57,7 @@ def __init__( self.init_weights() - def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = None, **kwargs): + def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = [], using_image_features=False,**kwargs): """Network forward process. 
Usually includes backbone, neck and head forward without any post- @@ -66,7 +66,7 @@ def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList # NOTE(knzo25): this is used during onnx export batch_input_metas = [item.metainfo for item in batch_data_samples] - feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + feats = self.extract_feat(batch_inputs_dict, batch_input_metas,using_image_features) if self.with_bbox_head: outputs = self.bbox_head(feats, batch_input_metas) @@ -122,6 +122,21 @@ def with_seg_head(self): """bool: Whether the detector has a segmentation head.""" return hasattr(self, "seg_head") and self.seg_head is not None + def get_image_backbone_features(self, x: torch.Tensor) -> torch.Tensor: + B, N, C, H, W = x.size() + x = x.view(B * N, C, H, W).contiguous() + + x = self.img_backbone(x) + x = self.img_neck(x) + + if not isinstance(x, torch.Tensor): + x = x[0] + + BN, C, H, W = x.size() + assert BN == B * N, (BN, B * N) + x = x.view(B, N, C, H, W) + return x + def extract_img_feat( self, x, @@ -136,19 +151,11 @@ def extract_img_feat( img_aug_matrix_inverse=None, lidar_aug_matrix_inverse=None, geom_feats=None, + using_image_features=False, ) -> torch.Tensor: - B, N, C, H, W = x.size() - x = x.view(B * N, C, H, W).contiguous() - - x = self.img_backbone(x) - x = self.img_neck(x) - if not isinstance(x, torch.Tensor): - x = x[0] - - BN, C, H, W = x.size() - assert BN == B * N, (BN, B * N) - x = x.view(B, N, C, H, W) + if not using_image_features: + x = self.get_image_backbone_features(x) with torch.cuda.amp.autocast(enabled=False): # with torch.autocast(device_type='cuda', dtype=torch.float32): @@ -219,7 +226,8 @@ def voxelize(self, points): return feats, coords, sizes def predict( - self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], **kwargs + self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], + using_image_features=False, **kwargs ) -> List[Det3DDataSample]: """Forward of testing. @@ -246,7 +254,7 @@ def predict( contains a tensor with shape (num_instances, 7). 
""" batch_input_metas = [item.metainfo for item in batch_data_samples] - feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features) if self.with_bbox_head: outputs = self.bbox_head.predict(feats, batch_input_metas) @@ -259,6 +267,7 @@ def extract_feat( self, batch_inputs_dict, batch_input_metas, + using_image_features, **kwargs, ): imgs = batch_inputs_dict.get("imgs", None) @@ -290,6 +299,7 @@ def extract_feat( img_aug_matrix, lidar_aug_matrix, batch_input_metas, + using_image_features=using_image_features ) features.append(img_feature) elif imgs is not None: @@ -323,6 +333,7 @@ def extract_feat( lidar_aug_matrix, batch_input_metas, geom_feats=geom_feats, + using_image_features=using_image_features ) features.append(img_feature) @@ -346,10 +357,10 @@ def extract_feat( return x def loss( - self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], **kwargs + self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], using_image_features:bool = False, **kwargs ) -> List[Det3DDataSample]: batch_input_metas = [item.metainfo for item in batch_data_samples] - feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + feats = self.extract_feat(batch_inputs_dict, batch_input_metas,using_image_features) losses = dict() if self.with_bbox_head: diff --git a/projects/BEVFusion/bevfusion/depth_lss.py b/projects/BEVFusion/bevfusion/depth_lss.py index 8b7afe61..cf74bf74 100644 --- a/projects/BEVFusion/bevfusion/depth_lss.py +++ b/projects/BEVFusion/bevfusion/depth_lss.py @@ -314,20 +314,6 @@ def forward( lidar_aug_matrix_inverse, geom_feats_precomputed, ): - post_trans = img_aug_matrix[..., :3, 3] - camera2lidar_rots = camera2lidar[..., :3, :3] - camera2lidar_trans = camera2lidar[..., :3, 3] - - if camera_intrinsics_inverse is None: - intrins_inverse = torch.inverse(cam_intrinsic)[..., :3, :3] - else: - intrins_inverse = camera_intrinsics_inverse[..., :3, :3] - - if img_aug_matrix_inverse is None: - post_rots_inverse = torch.inverse(img_aug_matrix)[..., :3, :3] - else: - img_aug_matrix_inverse = img_aug_matrix_inverse[..., :3, :3] - if lidar_aug_matrix_inverse is None: lidar_aug_matrix_inverse = torch.inverse(lidar_aug_matrix) @@ -406,6 +392,20 @@ def forward( x = self.bev_pool_precomputed(x, geom_feats, kept, ranks, indices) else: + post_trans = img_aug_matrix[..., :3, 3] + camera2lidar_rots = camera2lidar[..., :3, :3] + camera2lidar_trans = camera2lidar[..., :3, 3] + + if camera_intrinsics_inverse is None: + intrins_inverse = torch.inverse(cam_intrinsic)[..., :3, :3] + else: + intrins_inverse = camera_intrinsics_inverse[..., :3, :3] + + if img_aug_matrix_inverse is None: + post_rots_inverse = torch.inverse(img_aug_matrix)[..., :3, :3] + else: + post_rots_inverse = img_aug_matrix_inverse[..., :3, :3] + geom = self.get_geometry( camera2lidar_rots, camera2lidar_trans, diff --git a/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py index f23aa97b..6e1dacbe 100644 --- a/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py +++ b/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py @@ -9,41 +9,15 @@ allow_failed_imports=False, ) +image_dims = (384,576) + backend_config = dict( type="tensorrt", common_config=dict(max_workspace_size=1 << 32), model_inputs=[ dict( 
input_shapes=dict( - imgs=dict(min_shape=[1, 3, 256, 704], opt_shape=[6, 3, 256, 704], max_shape=[6, 3, 256, 704]), - points=dict(min_shape=[5000, 5], opt_shape=[50000, 5], max_shape=[200000, 5]), - lidar2image=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - cam2image_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - camera2lidar=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - img_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - img_aug_matrix_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - lidar_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - lidar_aug_matrix_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), - geom_feats=dict( - min_shape=[0 * 118 * 32 * 88, 4], - opt_shape=[6 * 118 * 32 * 88 // 2, 4], - max_shape=[6 * 118 * 32 * 88, 4], - ), - kept=dict( - min_shape=[0 * 118 * 32 * 88], - opt_shape=[6 * 118 * 32 * 88], - max_shape=[6 * 118 * 32 * 88], - ), - ranks=dict( - min_shape=[0 * 118 * 32 * 88], - opt_shape=[6 * 118 * 32 * 88 // 2], - max_shape=[6 * 118 * 32 * 88], - ), - indices=dict( - min_shape=[0 * 118 * 32 * 88], - opt_shape=[6 * 118 * 32 * 88 // 2], - max_shape=[6 * 118 * 32 * 88], - ), + imgs=dict(min_shape=[1, 3, image_dims[0], image_dims[1]], opt_shape=[6, 3, image_dims[0], image_dims[1]], max_shape=[6, 3, image_dims[0], image_dims[1]]), ) ) ], @@ -57,64 +31,13 @@ save_file="image_backbone.onnx", input_names=[ "imgs", - "points", - "lidar2image", - "cam2image", - "cam2image_inverse", - "camera2lidar", - "img_aug_matrix", - "img_aug_matrix_inverse", - "lidar_aug_matrix", - "lidar_aug_matrix_inverse", - "geom_feats", - "kept", - "ranks", - "indices", ], output_names=["image_feats"], dynamic_axes={ "imgs": { 0: "num_imgs", }, - "points": { - 0: "num_points", - }, - "lidar2image": { - 0: "num_imgs", - }, - "cam2image": { - 0: "num_imgs", - }, - "cam2image_inverse": { - 0: "num_imgs", - }, - "camera2lidar": { - 0: "num_imgs", - }, - "img_aug_matrix": { - 0: "num_imgs", - }, - "img_aug_matrix_inverse": { - 0: "num_imgs", - }, - "lidar_aug_matrix": { - 0: "num_imgs", - }, - "lidar_aug_matrix_inverse": { - 0: "num_imgs", - }, - "geom_feats": { - 0: "num_kept", - }, - "kept": { - 0: "num_geom_feats", - }, - "ranks": { - 0: "num_kept", - }, - "indices": { - 0: "num_kept", - }, + }, input_shape=None, verbose=True, diff --git a/projects/BEVFusion/configs/deploy/bevfusion_main_body_lidar_only_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_main_body_lidar_only_tensorrt_dynamic.py index f15e586c..4524ff2d 100644 --- a/projects/BEVFusion/configs/deploy/bevfusion_main_body_lidar_only_tensorrt_dynamic.py +++ b/projects/BEVFusion/configs/deploy/bevfusion_main_body_lidar_only_tensorrt_dynamic.py @@ -15,8 +15,8 @@ model_inputs=[ dict( input_shapes=dict( - voxels=dict(min_shape=[1, 5], opt_shape=[64000, 5], max_shape=[256000, 5]), - coors=dict(min_shape=[1, 4], opt_shape=[64000, 4], max_shape=[256000, 4]), + voxels=dict(min_shape=[1, 10, 4], opt_shape=[64000, 10, 4], max_shape=[256000, 10, 4]), + coors=dict(min_shape=[1, 3], opt_shape=[64000, 3], max_shape=[256000, 3]), num_points_per_voxel=dict(min_shape=[1], opt_shape=[64000], max_shape=[256000]), ) ) diff --git a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py index 1f312fe7..4a519cb8 100644 --- 
a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py +++ b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py @@ -9,6 +9,9 @@ allow_failed_imports=False, ) +depth_bins = 118 +feature_dims = (48,72) + backend_config = dict( type="tensorrt", common_config=dict(max_workspace_size=1 << 32), @@ -18,7 +21,34 @@ voxels=dict(min_shape=[1, 10, 4], opt_shape=[64000, 10, 4], max_shape=[256000, 10, 4]), coors=dict(min_shape=[1, 3], opt_shape=[64000, 3], max_shape=[256000, 3]), num_points_per_voxel=dict(min_shape=[1], opt_shape=[64000], max_shape=[256000]), - image_feats=dict(min_shape=[80, 180, 180], opt_shape=[80, 180, 180], max_shape=[80, 180, 180]), + # TODO(TIERIV): Optimize. Now, using points will increase latency significantly + # points=dict(min_shape=[5000, 4], opt_shape=[50000, 4], max_shape=[200000, 4]), + lidar2image=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), + img_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), + geom_feats=dict( + min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1], 4], + opt_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1] // 2, 4], + max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1], 4], + ), + kept=dict( + min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1]], + opt_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1]], + max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1]], + ), + ranks=dict( + min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1]], + opt_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1] // 2], + max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1]], + ), + indices=dict( + min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1]], + opt_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1] // 2], + max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1]], + ), + image_feats=dict( + min_shape=[0, 256, feature_dims[0], feature_dims[1]], + opt_shape=[6, 256, feature_dims[0], feature_dims[1]], + max_shape=[6, 256, feature_dims[0], feature_dims[1]]), ) ) ], @@ -30,7 +60,19 @@ keep_initializers_as_inputs=False, opset_version=17, save_file="main_body.onnx", - input_names=["voxels", "coors", "num_points_per_voxel", "image_feats"], + input_names=[ + "voxels", + "coors", + "num_points_per_voxel", + # "points", + "lidar2image", + "img_aug_matrix", + "geom_feats", + "kept", + "ranks", + "indices", + "image_feats", + ], output_names=["bbox_pred", "score", "label_pred"], dynamic_axes={ "voxels": { @@ -42,6 +84,30 @@ "num_points_per_voxel": { 0: "voxels_num", }, + # "points": { + # 0: "num_points", + # }, + "lidar2image": { + 0: "num_imgs", + }, + "img_aug_matrix": { + 0: "num_imgs", + }, + "geom_feats": { + 0: "num_kept", + }, + "kept": { + 0: "num_geom_feats", + }, + "ranks": { + 0: "num_kept", + }, + "indices": { + 0: "num_kept", + }, + "image_feats": { + 0: "num_imgs", + }, }, input_shape=None, verbose=True, diff --git a/projects/BEVFusion/deploy/containers.py b/projects/BEVFusion/deploy/containers.py index 97c6e6ff..c67908d8 100644 --- a/projects/BEVFusion/deploy/containers.py +++ b/projects/BEVFusion/deploy/containers.py @@ -11,41 +11,13 @@ def __init__(self, mod, mean, std) -> None: self.images_mean = mean self.images_std = std - def forward( - self, - imgs, - points, - lidar2image, - cam2image, - cam2image_inverse, - camera2lidar, - img_aug_matrix, - img_aug_matrix_inverse, - lidar_aug_matrix, - lidar_aug_matrix_inverse, - geom_feats, - 
kept, - ranks, - indices, - ): + def forward(self,imgs): mod = self.mod imgs = (imgs - self.images_mean) / self.images_std - return mod.extract_img_feat( - imgs, - points, - lidar2image, - cam2image, - camera2lidar, - img_aug_matrix, - lidar_aug_matrix, - img_metas=None, - img_aug_matrix_inverse=img_aug_matrix_inverse, - camera_intrinsics_inverse=cam2image_inverse, - lidar_aug_matrix_inverse=lidar_aug_matrix_inverse, - geom_feats=(geom_feats, kept, ranks, indices), - ) + # No lidar augmentations expected during inference. + return mod.get_image_backbone_features(imgs) class TrtBevFusionMainContainer(torch.nn.Module): @@ -53,13 +25,20 @@ def __init__(self, mod, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.mod = mod - def forward(self, voxels, coors, num_points_per_voxel, image_feats=None): - mod = self.mod - - features = [] + def forward(self, voxels, + coors, + num_points_per_voxel, + # points = None, + lidar2img = None, + img_aug_matrix = None, + geom_feats = None, + kept = None, + ranks = None, + indices = None, + image_feats = None, + ): - if image_feats is not None: - features.append(image_feats) + mod = self.mod if coors.shape[1] == 3: num_points = coors.shape[0] @@ -67,26 +46,40 @@ def forward(self, voxels, coors, num_points_per_voxel, image_feats=None): batch_coors = torch.zeros(num_points, 1).to(coors.device) coors = torch.cat([batch_coors, coors], dim=1).contiguous() - pts_feature = mod.extract_pts_feat(voxels, coors, num_points_per_voxel) - features.append(pts_feature) - - if mod.fusion_layer is not None: - x = mod.fusion_layer(features) - else: - assert len(features) == 1, features - x = features[0] - x = mod.pts_backbone(x) - x = mod.pts_neck(x) + batch_inputs_dict = { + "voxels": {"voxels": voxels, "coors": coors, "num_points_per_voxel": num_points_per_voxel}, + } - outputs = mod.bbox_head(x, None)[0][0] + if image_feats is not None: + lidar_aug_matrix = torch.eye(4).unsqueeze(0).to(image_feats.device) + + batch_inputs_dict.update( + { + "imgs": image_feats, + "lidar2img": lidar2img, + "cam2img": None, + "cam2lidar": None, + "img_aug_matrix": img_aug_matrix, + "img_aug_matrix_inverse": None, + "lidar_aug_matrix": lidar_aug_matrix, + "lidar_aug_matrix_inverse": lidar_aug_matrix, + "geom_feats": (geom_feats, kept, ranks, indices), + } + ) + + outputs = mod._forward(batch_inputs_dict,using_image_features=True) + + # The following code is taken from + # projects/BEVFusion/bevfusion/bevfusion_head.py + # It is used to simplify the post process in deployment score = outputs["heatmap"].sigmoid() one_hot = F.one_hot(outputs["query_labels"], num_classes=score.size(1)).permute(0, 2, 1) score = score * outputs["query_heatmap_score"] * one_hot score = score[0].max(dim=0)[0] bbox_pred = torch.cat( - [outputs["center"][0], outputs["height"][0], outputs["dim"][0], outputs["rot"][0], outputs["vel"][0]], - dim=0, + [outputs["center"][0], outputs["height"][0], outputs["dim"][0], outputs["rot"][0], outputs["vel"][0]], dim=0 ) + return bbox_pred, score, outputs["query_labels"][0] diff --git a/projects/BEVFusion/deploy/torch2onnx.py b/projects/BEVFusion/deploy/torch2onnx.py index 745dc390..4eac4c9e 100644 --- a/projects/BEVFusion/deploy/torch2onnx.py +++ b/projects/BEVFusion/deploy/torch2onnx.py @@ -101,7 +101,7 @@ def parse_args(): points, camera_mask, imgs, - lidar2image, + lidar2img, cam2image, camera2lidar, geom_feats, @@ -178,29 +178,11 @@ def _add_or_update(cfg: dict, key: str, val: Any): if "img_backbone" in model_cfg.model: img_aug_matrix = 
imgs.new_tensor(np.stack(data_samples[0].img_aug_matrix)) - lidar_aug_matrix = torch.eye(4).to(imgs.device) images_mean = data_preprocessor.mean.to(device) images_std = data_preprocessor.std.to(device) image_backbone_container = TrtBevFusionImageBackboneContainer(patched_model, images_mean, images_std) - model_inputs = ( - imgs.unsqueeze(0).to(device).float(), - points.unsqueeze(0).to(device).float(), - lidar2image.unsqueeze(0).to(device).float(), - cam2image.unsqueeze(0).to(device).float(), - torch.inverse(cam2image).unsqueeze(0).to(device).float(), - camera2lidar.unsqueeze(0).to(device).float(), - img_aug_matrix.unsqueeze(0).to(device).float(), - imgs.new_tensor(np.stack([np.linalg.inv(x) for x in data_samples[0].img_aug_matrix])) - .unsqueeze(0) - .to(device) - .float(), - lidar_aug_matrix.unsqueeze(0).to(device).float(), - torch.inverse(lidar_aug_matrix).unsqueeze(0).to(device).float(), - geom_feats.to(device).float(), - kept.to(device), - ranks.to(device).long(), - indices.to(device).long(), - ) + model_inputs = (imgs.unsqueeze(0).to(device).float(),) + if args.module == "image_backbone": return_value = torch.onnx.export( image_backbone_container, @@ -225,11 +207,20 @@ def _add_or_update(cfg: dict, key: str, val: Any): num_points_per_voxel.to(device), ) if image_feats is not None: - model_inputs += (image_feats,) + model_inputs += ( + # points.unsqueeze(0).to(device).float(), # TODO(TIERIV): Optimize. Now, using points will increase latency significantly + lidar2img.unsqueeze(0).to(device).float(), + img_aug_matrix.unsqueeze(0).to(device).float(), + geom_feats.to(device).float(), + kept.to(device), + ranks.to(device).long(), + indices.to(device).long(), + image_feats, + ) torch.onnx.export( main_container, model_inputs, - output_path, + output_path.replace(".onnx", "_tofix.onnx"), export_params=True, input_names=input_names, output_names=output_names, @@ -237,29 +228,30 @@ def _add_or_update(cfg: dict, key: str, val: Any): dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, verbose=verbose, + do_constant_folding=False, ) - logger.info(f"ONNX exported to {output_path}") - logger.info("Attempting to fix the graph (TopK's K becoming a tensor)") + logger.info("Attempting to fix the graph (TopK's K becoming a tensor)") - import onnx_graphsurgeon as gs + import onnx_graphsurgeon as gs - model = onnx.load(output_path) - graph = gs.import_onnx(model) + model = onnx.load(output_path.replace(".onnx", "_tofix.onnx")) + graph = gs.import_onnx(model) - # Fix TopK - topk_nodes = [node for node in graph.nodes if node.op == "TopK"] - assert len(topk_nodes) == 1 - topk = topk_nodes[0] - k = model_cfg.num_proposals - topk.inputs[1] = gs.Constant("K", values=np.array([k], dtype=np.int64)) - topk.outputs[0].shape = [1, k] - topk.outputs[0].dtype = topk.inputs[0].dtype if topk.inputs[0].dtype else np.float32 - topk.outputs[1].shape = [1, k] - topk.outputs[1].dtype = np.int64 + # Fix TopK + topk_nodes = [node for node in graph.nodes if node.op == "TopK"] + assert len(topk_nodes) == 1 + topk = topk_nodes[0] + k = model_cfg.num_proposals + topk.inputs[1] = gs.Constant("K", values=np.array([k], dtype=np.int64)) + topk.outputs[0].shape = [1, k] + topk.outputs[0].dtype = topk.inputs[0].dtype if topk.inputs[0].dtype else np.float32 + topk.outputs[1].shape = [1, k] + topk.outputs[1].dtype = np.int64 - graph.cleanup().toposort() - output_path = output_path.replace(".onnx", "_fixed.onnx") - onnx.save_model(gs.export_onnx(graph), output_path) + graph.cleanup().toposort() + 
onnx.save_model(gs.export_onnx(graph), output_path)
 
-    logger.info(f"(Fixed) ONNX exported to {output_path}")
+        logger.info(f"(Fixed) ONNX exported to {output_path}")
+
+    logger.info(f"ONNX exported to {output_path}")

From 58dbec6e09ede60d9e62bea9edbc47eb52eeab80 Mon Sep 17 00:00:00 2001
From: Samrat Thapa
Date: Fri, 22 Aug 2025 15:49:55 +0900
Subject: [PATCH 2/6] added points to camera-lidar inference

---
 projects/BEVFusion/bevfusion/bevfusion.py     | 24 +++++++++----------
 ...n_main_body_with_image_tensorrt_dynamic.py | 10 ++++----
 projects/BEVFusion/deploy/containers.py       |  7 +++---
 projects/BEVFusion/deploy/torch2onnx.py       |  8 +++----
 4 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/projects/BEVFusion/bevfusion/bevfusion.py b/projects/BEVFusion/bevfusion/bevfusion.py
index 7b9bc7b6..699d3ed2 100644
--- a/projects/BEVFusion/bevfusion/bevfusion.py
+++ b/projects/BEVFusion/bevfusion/bevfusion.py
@@ -273,6 +273,8 @@ def extract_feat(
         imgs = batch_inputs_dict.get("imgs", None)
         points = batch_inputs_dict.get("points", None)
         features = []
+
+        is_onnx_inference = False
         if imgs is not None and "lidar2img" not in batch_inputs_dict:
             # NOTE(knzo25): normal training and testing
             imgs = imgs.contiguous()
@@ -304,28 +306,24 @@ def extract_feat(
             features.append(img_feature)
         elif imgs is not None:
             # NOTE(knzo25): onnx inference
+            is_onnx_inference = True
             lidar2image = batch_inputs_dict["lidar2img"]
             camera_intrinsics = batch_inputs_dict["cam2img"]
             camera2lidar = batch_inputs_dict["cam2lidar"]
             img_aug_matrix = batch_inputs_dict["img_aug_matrix"]
             lidar_aug_matrix = batch_inputs_dict["lidar_aug_matrix"]
 
-            # NOTE(knzo25): originally BEVFusion uses all the points
-            # which could be a bit slow. For now I am using only
-            # the centroids, which is also suboptimal, but using
-            # all the voxels produce errors in TensorRT,
-            # so this will be fixed for the next version
-            # (ScatterElements bug, or simply null voxels break the equation)
-            feats = batch_inputs_dict["voxels"]["voxels"]
-            sizes = batch_inputs_dict["voxels"]["num_points_per_voxel"]
+            geom_feats = batch_inputs_dict["geom_feats"]
+
+
+            # feats = batch_inputs_dict["voxels"]["voxels"]
+            # sizes = batch_inputs_dict["voxels"]["num_points_per_voxel"]
+            # feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1)
 
-            feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1)
 
-            geom_feats = batch_inputs_dict["geom_feats"]
             img_feature = self.extract_img_feat(
                 imgs,
-                [feats],
-                # points,
+                points,
                 lidar2image,
                 camera_intrinsics,
                 camera2lidar,
@@ -341,7 +339,7 @@ def extract_feat(
                 batch_inputs_dict.get("voxels", {}).get("voxels", None),
                 batch_inputs_dict.get("voxels", {}).get("coors", None),
                 batch_inputs_dict.get("voxels", {}).get("num_points_per_voxel", None),
-                points=points,
+                points=points if not is_onnx_inference else None,
             )
             features.append(pts_feature)
 
diff --git a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py
index 4a519cb8..f122cbf2 100644
--- a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py
+++ b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py
@@ -22,7 +22,7 @@
                 coors=dict(min_shape=[1, 3], opt_shape=[64000, 3], max_shape=[256000, 3]),
                 num_points_per_voxel=dict(min_shape=[1], opt_shape=[64000], max_shape=[256000]),
                 # TODO(TIERIV): Optimize. 
Now, using points will increase latency significantly - # points=dict(min_shape=[5000, 4], opt_shape=[50000, 4], max_shape=[200000, 4]), + points=dict(min_shape=[5000, 4], opt_shape=[50000, 4], max_shape=[200000, 4]), lidar2image=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), img_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), geom_feats=dict( @@ -64,7 +64,7 @@ "voxels", "coors", "num_points_per_voxel", - # "points", + "points", "lidar2image", "img_aug_matrix", "geom_feats", @@ -84,9 +84,9 @@ "num_points_per_voxel": { 0: "voxels_num", }, - # "points": { - # 0: "num_points", - # }, + "points": { + 0: "num_points", + }, "lidar2image": { 0: "num_imgs", }, diff --git a/projects/BEVFusion/deploy/containers.py b/projects/BEVFusion/deploy/containers.py index c67908d8..f541d056 100644 --- a/projects/BEVFusion/deploy/containers.py +++ b/projects/BEVFusion/deploy/containers.py @@ -28,7 +28,7 @@ def __init__(self, mod, *args, **kwargs) -> None: def forward(self, voxels, coors, num_points_per_voxel, - # points = None, + points = None, lidar2img = None, img_aug_matrix = None, geom_feats = None, @@ -37,9 +37,7 @@ def forward(self, voxels, indices = None, image_feats = None, ): - mod = self.mod - if coors.shape[1] == 3: num_points = coors.shape[0] coors = coors.flip(dims=[-1]).contiguous() # [x, y, z] @@ -50,6 +48,9 @@ def forward(self, voxels, "voxels": {"voxels": voxels, "coors": coors, "num_points_per_voxel": num_points_per_voxel}, } + if points is not None: + batch_inputs_dict["points"] = points + if image_feats is not None: lidar_aug_matrix = torch.eye(4).unsqueeze(0).to(image_feats.device) diff --git a/projects/BEVFusion/deploy/torch2onnx.py b/projects/BEVFusion/deploy/torch2onnx.py index 4eac4c9e..a656abdf 100644 --- a/projects/BEVFusion/deploy/torch2onnx.py +++ b/projects/BEVFusion/deploy/torch2onnx.py @@ -208,19 +208,19 @@ def _add_or_update(cfg: dict, key: str, val: Any): ) if image_feats is not None: model_inputs += ( - # points.unsqueeze(0).to(device).float(), # TODO(TIERIV): Optimize. 
Now, using points will increase latency significantly + points.unsqueeze(0).to(device).float(), lidar2img.unsqueeze(0).to(device).float(), img_aug_matrix.unsqueeze(0).to(device).float(), geom_feats.to(device).float(), kept.to(device), ranks.to(device).long(), indices.to(device).long(), - image_feats, + image_feats ) torch.onnx.export( main_container, model_inputs, - output_path.replace(".onnx", "_tofix.onnx"), + output_path.replace(".onnx", "_temp_to_be_fixed.onnx"), export_params=True, input_names=input_names, output_names=output_names, @@ -235,7 +235,7 @@ def _add_or_update(cfg: dict, key: str, val: Any): import onnx_graphsurgeon as gs - model = onnx.load(output_path.replace(".onnx", "_tofix.onnx")) + model = onnx.load(output_path.replace(".onnx", "_temp_to_be_fixed.onnx")) graph = gs.import_onnx(model) # Fix TopK From dcf65027b2b805538994da5a5d2be0a6432140ca Mon Sep 17 00:00:00 2001 From: Samrat Thapa Date: Sat, 23 Aug 2025 11:27:22 +0900 Subject: [PATCH 3/6] add points and fix shapes --- projects/BEVFusion/deploy/containers.py | 12 ++++++------ projects/BEVFusion/deploy/torch2onnx.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/projects/BEVFusion/deploy/containers.py b/projects/BEVFusion/deploy/containers.py index f541d056..b9977a2a 100644 --- a/projects/BEVFusion/deploy/containers.py +++ b/projects/BEVFusion/deploy/containers.py @@ -14,10 +14,10 @@ def __init__(self, mod, mean, std) -> None: def forward(self,imgs): mod = self.mod - imgs = (imgs - self.images_mean) / self.images_std + imgs = (imgs.unsqueeze(0) - self.images_mean) / self.images_std # No lidar augmentations expected during inference. - return mod.get_image_backbone_features(imgs) + return mod.get_image_backbone_features(imgs)[0] class TrtBevFusionMainContainer(torch.nn.Module): @@ -49,7 +49,7 @@ def forward(self, voxels, } if points is not None: - batch_inputs_dict["points"] = points + batch_inputs_dict["points"] = [points] if image_feats is not None: @@ -57,11 +57,11 @@ def forward(self, voxels, batch_inputs_dict.update( { - "imgs": image_feats, - "lidar2img": lidar2img, + "imgs": image_feats.unsqueeze(0), + "lidar2img": lidar2img.unsqueeze(0), "cam2img": None, "cam2lidar": None, - "img_aug_matrix": img_aug_matrix, + "img_aug_matrix": img_aug_matrix.unsqueeze(0), "img_aug_matrix_inverse": None, "lidar_aug_matrix": lidar_aug_matrix, "lidar_aug_matrix_inverse": lidar_aug_matrix, diff --git a/projects/BEVFusion/deploy/torch2onnx.py b/projects/BEVFusion/deploy/torch2onnx.py index a656abdf..a0aa3db4 100644 --- a/projects/BEVFusion/deploy/torch2onnx.py +++ b/projects/BEVFusion/deploy/torch2onnx.py @@ -181,7 +181,7 @@ def _add_or_update(cfg: dict, key: str, val: Any): images_mean = data_preprocessor.mean.to(device) images_std = data_preprocessor.std.to(device) image_backbone_container = TrtBevFusionImageBackboneContainer(patched_model, images_mean, images_std) - model_inputs = (imgs.unsqueeze(0).to(device).float(),) + model_inputs = (imgs.to(device).float(),) if args.module == "image_backbone": return_value = torch.onnx.export( @@ -208,9 +208,9 @@ def _add_or_update(cfg: dict, key: str, val: Any): ) if image_feats is not None: model_inputs += ( - points.unsqueeze(0).to(device).float(), - lidar2img.unsqueeze(0).to(device).float(), - img_aug_matrix.unsqueeze(0).to(device).float(), + points.to(device).float(), + lidar2img.to(device).float(), + img_aug_matrix.to(device).float(), geom_feats.to(device).float(), kept.to(device), ranks.to(device).long(), From 583d019a0b21a2bbc8176ec450183eb9957819aa Mon 
Sep 17 00:00:00 2001 From: Samrat Thapa Date: Tue, 26 Aug 2025 14:26:59 +0900 Subject: [PATCH 4/6] updated data types Signed-off-by: Samrat Thapa --- projects/BEVFusion/deploy/containers.py | 2 +- projects/BEVFusion/deploy/torch2onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/BEVFusion/deploy/containers.py b/projects/BEVFusion/deploy/containers.py index b9977a2a..feb27c92 100644 --- a/projects/BEVFusion/deploy/containers.py +++ b/projects/BEVFusion/deploy/containers.py @@ -14,7 +14,7 @@ def __init__(self, mod, mean, std) -> None: def forward(self,imgs): mod = self.mod - imgs = (imgs.unsqueeze(0) - self.images_mean) / self.images_std + imgs = (imgs.float().unsqueeze(0) - self.images_mean) / self.images_std # No lidar augmentations expected during inference. return mod.get_image_backbone_features(imgs)[0] diff --git a/projects/BEVFusion/deploy/torch2onnx.py b/projects/BEVFusion/deploy/torch2onnx.py index a0aa3db4..06436e45 100644 --- a/projects/BEVFusion/deploy/torch2onnx.py +++ b/projects/BEVFusion/deploy/torch2onnx.py @@ -181,7 +181,7 @@ def _add_or_update(cfg: dict, key: str, val: Any): images_mean = data_preprocessor.mean.to(device) images_std = data_preprocessor.std.to(device) image_backbone_container = TrtBevFusionImageBackboneContainer(patched_model, images_mean, images_std) - model_inputs = (imgs.to(device).float(),) + model_inputs = (imgs.to(device=device,dtype=torch.uint8),) if args.module == "image_backbone": return_value = torch.onnx.export( From c890ee15f9a303b10886afa16c5f8c9a7d7ece6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Sep 2025 01:09:05 +0000 Subject: [PATCH 5/6] ci(pre-commit): autofix --- projects/BEVFusion/bevfusion/bevfusion.py | 27 ++++++++++------- projects/BEVFusion/bevfusion/depth_lss.py | 4 +-- ...fusion_camera_backbone_tensorrt_dynamic.py | 9 ++++-- ...n_main_body_with_image_tensorrt_dynamic.py | 11 +++---- projects/BEVFusion/deploy/containers.py | 29 ++++++++++--------- projects/BEVFusion/deploy/torch2onnx.py | 8 ++--- 6 files changed, 51 insertions(+), 37 deletions(-) diff --git a/projects/BEVFusion/bevfusion/bevfusion.py b/projects/BEVFusion/bevfusion/bevfusion.py index 699d3ed2..f59cbacf 100644 --- a/projects/BEVFusion/bevfusion/bevfusion.py +++ b/projects/BEVFusion/bevfusion/bevfusion.py @@ -57,7 +57,9 @@ def __init__( self.init_weights() - def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = [], using_image_features=False,**kwargs): + def _forward( + self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = [], using_image_features=False, **kwargs + ): """Network forward process. 
Usually includes backbone, neck and head forward without any post- @@ -66,7 +68,7 @@ def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList # NOTE(knzo25): this is used during onnx export batch_input_metas = [item.metainfo for item in batch_data_samples] - feats = self.extract_feat(batch_inputs_dict, batch_input_metas,using_image_features) + feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features) if self.with_bbox_head: outputs = self.bbox_head(feats, batch_input_metas) @@ -226,8 +228,11 @@ def voxelize(self, points): return feats, coords, sizes def predict( - self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], - using_image_features=False, **kwargs + self, + batch_inputs_dict: Dict[str, Optional[Tensor]], + batch_data_samples: List[Det3DDataSample], + using_image_features=False, + **kwargs, ) -> List[Det3DDataSample]: """Forward of testing. @@ -301,7 +306,7 @@ def extract_feat( img_aug_matrix, lidar_aug_matrix, batch_input_metas, - using_image_features=using_image_features + using_image_features=using_image_features, ) features.append(img_feature) elif imgs is not None: @@ -314,13 +319,11 @@ def extract_feat( lidar_aug_matrix = batch_inputs_dict["lidar_aug_matrix"] geom_feats = batch_inputs_dict["geom_feats"] - # feats = batch_inputs_dict["voxels"]["voxels"] # sizes = batch_inputs_dict["voxels"]["num_points_per_voxel"] # feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1) - img_feature = self.extract_img_feat( imgs, points, @@ -331,7 +334,7 @@ def extract_feat( lidar_aug_matrix, batch_input_metas, geom_feats=geom_feats, - using_image_features=using_image_features + using_image_features=using_image_features, ) features.append(img_feature) @@ -355,10 +358,14 @@ def extract_feat( return x def loss( - self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], using_image_features:bool = False, **kwargs + self, + batch_inputs_dict: Dict[str, Optional[Tensor]], + batch_data_samples: List[Det3DDataSample], + using_image_features: bool = False, + **kwargs, ) -> List[Det3DDataSample]: batch_input_metas = [item.metainfo for item in batch_data_samples] - feats = self.extract_feat(batch_inputs_dict, batch_input_metas,using_image_features) + feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features) losses = dict() if self.with_bbox_head: diff --git a/projects/BEVFusion/bevfusion/depth_lss.py b/projects/BEVFusion/bevfusion/depth_lss.py index cf74bf74..10f84d8a 100644 --- a/projects/BEVFusion/bevfusion/depth_lss.py +++ b/projects/BEVFusion/bevfusion/depth_lss.py @@ -395,12 +395,12 @@ def forward( post_trans = img_aug_matrix[..., :3, 3] camera2lidar_rots = camera2lidar[..., :3, :3] camera2lidar_trans = camera2lidar[..., :3, 3] - + if camera_intrinsics_inverse is None: intrins_inverse = torch.inverse(cam_intrinsic)[..., :3, :3] else: intrins_inverse = camera_intrinsics_inverse[..., :3, :3] - + if img_aug_matrix_inverse is None: post_rots_inverse = torch.inverse(img_aug_matrix)[..., :3, :3] else: diff --git a/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py index 6e1dacbe..dadf80ab 100644 --- a/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py +++ b/projects/BEVFusion/configs/deploy/bevfusion_camera_backbone_tensorrt_dynamic.py @@ -9,7 +9,7 @@ allow_failed_imports=False, ) 
-image_dims = (384,576) +image_dims = (384, 576) backend_config = dict( type="tensorrt", @@ -17,7 +17,11 @@ model_inputs=[ dict( input_shapes=dict( - imgs=dict(min_shape=[1, 3, image_dims[0], image_dims[1]], opt_shape=[6, 3, image_dims[0], image_dims[1]], max_shape=[6, 3, image_dims[0], image_dims[1]]), + imgs=dict( + min_shape=[1, 3, image_dims[0], image_dims[1]], + opt_shape=[6, 3, image_dims[0], image_dims[1]], + max_shape=[6, 3, image_dims[0], image_dims[1]], + ), ) ) ], @@ -37,7 +41,6 @@ "imgs": { 0: "num_imgs", }, - }, input_shape=None, verbose=True, diff --git a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py index f122cbf2..094c667a 100644 --- a/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py +++ b/projects/BEVFusion/configs/deploy/bevfusion_main_body_with_image_tensorrt_dynamic.py @@ -10,7 +10,7 @@ ) depth_bins = 118 -feature_dims = (48,72) +feature_dims = (48, 72) backend_config = dict( type="tensorrt", @@ -26,7 +26,7 @@ lidar2image=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), img_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]), geom_feats=dict( - min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1], 4], + min_shape=[0 * depth_bins * feature_dims[0] * feature_dims[1], 4], opt_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1] // 2, 4], max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1], 4], ), @@ -46,9 +46,10 @@ max_shape=[6 * depth_bins * feature_dims[0] * feature_dims[1]], ), image_feats=dict( - min_shape=[0, 256, feature_dims[0], feature_dims[1]], - opt_shape=[6, 256, feature_dims[0], feature_dims[1]], - max_shape=[6, 256, feature_dims[0], feature_dims[1]]), + min_shape=[0, 256, feature_dims[0], feature_dims[1]], + opt_shape=[6, 256, feature_dims[0], feature_dims[1]], + max_shape=[6, 256, feature_dims[0], feature_dims[1]], + ), ) ) ], diff --git a/projects/BEVFusion/deploy/containers.py b/projects/BEVFusion/deploy/containers.py index feb27c92..ddc0b793 100644 --- a/projects/BEVFusion/deploy/containers.py +++ b/projects/BEVFusion/deploy/containers.py @@ -11,7 +11,7 @@ def __init__(self, mod, mean, std) -> None: self.images_mean = mean self.images_std = std - def forward(self,imgs): + def forward(self, imgs): mod = self.mod imgs = (imgs.float().unsqueeze(0) - self.images_mean) / self.images_std @@ -25,17 +25,19 @@ def __init__(self, mod, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.mod = mod - def forward(self, voxels, - coors, + def forward( + self, + voxels, + coors, num_points_per_voxel, - points = None, - lidar2img = None, - img_aug_matrix = None, - geom_feats = None, - kept = None, - ranks = None, - indices = None, - image_feats = None, + points=None, + lidar2img=None, + img_aug_matrix=None, + geom_feats=None, + kept=None, + ranks=None, + indices=None, + image_feats=None, ): mod = self.mod if coors.shape[1] == 3: @@ -69,7 +71,7 @@ def forward(self, voxels, } ) - outputs = mod._forward(batch_inputs_dict,using_image_features=True) + outputs = mod._forward(batch_inputs_dict, using_image_features=True) # The following code is taken from # projects/BEVFusion/bevfusion/bevfusion_head.py @@ -80,7 +82,8 @@ def forward(self, voxels, score = score[0].max(dim=0)[0] bbox_pred = torch.cat( - [outputs["center"][0], outputs["height"][0], outputs["dim"][0], outputs["rot"][0], outputs["vel"][0]], dim=0 + [outputs["center"][0], 
outputs["height"][0], outputs["dim"][0], outputs["rot"][0], outputs["vel"][0]], + dim=0, ) return bbox_pred, score, outputs["query_labels"][0] diff --git a/projects/BEVFusion/deploy/torch2onnx.py b/projects/BEVFusion/deploy/torch2onnx.py index 06436e45..f4360c3e 100644 --- a/projects/BEVFusion/deploy/torch2onnx.py +++ b/projects/BEVFusion/deploy/torch2onnx.py @@ -181,8 +181,8 @@ def _add_or_update(cfg: dict, key: str, val: Any): images_mean = data_preprocessor.mean.to(device) images_std = data_preprocessor.std.to(device) image_backbone_container = TrtBevFusionImageBackboneContainer(patched_model, images_mean, images_std) - model_inputs = (imgs.to(device=device,dtype=torch.uint8),) - + model_inputs = (imgs.to(device=device, dtype=torch.uint8),) + if args.module == "image_backbone": return_value = torch.onnx.export( image_backbone_container, @@ -215,7 +215,7 @@ def _add_or_update(cfg: dict, key: str, val: Any): kept.to(device), ranks.to(device).long(), indices.to(device).long(), - image_feats + image_feats, ) torch.onnx.export( main_container, @@ -253,5 +253,5 @@ def _add_or_update(cfg: dict, key: str, val: Any): onnx.save_model(gs.export_onnx(graph), output_path) logger.info(f"(Fixed) ONNX exported to {output_path}") - + logger.info(f"ONNX exported to {output_path}") From f21ec856d56ad62234ae80419c8f742e03bde4bc Mon Sep 17 00:00:00 2001 From: Samrat Thapa Date: Tue, 9 Sep 2025 10:16:37 +0900 Subject: [PATCH 6/6] removed old scripts --- projects/BEVFusion/deploy/base.py | 96 ------------- projects/BEVFusion/deploy/export.py | 212 ---------------------------- 2 files changed, 308 deletions(-) delete mode 100644 projects/BEVFusion/deploy/base.py delete mode 100644 projects/BEVFusion/deploy/export.py diff --git a/projects/BEVFusion/deploy/base.py b/projects/BEVFusion/deploy/base.py deleted file mode 100644 index fda02487..00000000 --- a/projects/BEVFusion/deploy/base.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -# Copied from -# https://github.com/open-mmlab/mmdeploy/blob/v1.3.1/ -# mmdeploy/codebase/mmdet3d/models/base.py -# NOTE(knzo25): This patches the forward method to -# control the inputs and outputs to the network in deployment - -from typing import List, Optional, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from mmdeploy.core import FUNCTION_REWRITER - - -@FUNCTION_REWRITER.register_rewriter("mmdet3d.models.detectors.Base3DDetector.forward") # noqa: E501 -def basedetector__forward( - self, - # NOTE(knzo25): BEVFusion originally uses the whole se of points - # for the camera branch. 
For now I will try to use only the voxels - # points: torch.Tensor, - voxels: Optional[torch.Tensor] = None, - coors: Optional[torch.Tensor] = None, - num_points_per_voxel: Optional[torch.Tensor] = None, - imgs: Optional[torch.Tensor] = None, - lidar2image: Optional[torch.Tensor] = None, - # NOTE(knzo25): not used during export - # but needed to comply with the signature - cam2image: Optional[torch.Tensor] = None, - # NOTE(knzo25): not used during export - # but needed to comply with the signature - camera2lidar: Optional[torch.Tensor] = None, - geom_feats: Optional[torch.Tensor] = None, - kept: Optional[torch.Tensor] = None, - ranks: Optional[torch.Tensor] = None, - indices: Optional[torch.Tensor] = None, - data_samples=None, - **kwargs -) -> Tuple[List[torch.Tensor]]: - - # Note(KokSeang): Convert coors from (z, y, z) to (b, x, y, z) - # Downstream sparse ecnoder expects coors in (b, x, y, z) format - if coors.shape[1] == 3: - num_points = coors.shape[0] - coors = coors.flip(dims=[-1]).contiguous() # [x, y, z] - batch_coors = torch.zeros(num_points, 1).to(coors.device) - coors = torch.cat([batch_coors, coors], dim=1).contiguous() - - batch_inputs_dict = { - # 'points': [points], - "voxels": {"voxels": voxels, "coors": coors, "num_points_per_voxel": num_points_per_voxel}, - } - - if imgs is not None: - - # NOTE(knzo25): moved image normalization to the graph - images_mean = kwargs["data_preprocessor"].mean.to(imgs.device) - images_std = kwargs["data_preprocessor"].std.to(imgs.device) - imgs = (imgs.float() - images_mean) / images_std - - # This is actually not used since we also use geom_feats as an input - # However, it is needed to comply with signatures - img_aug_matrix = imgs.new_tensor(np.stack(data_samples[0].img_aug_matrix)) - img_aug_matrix_inverse = imgs.new_tensor(np.stack([np.linalg.inv(x) for x in data_samples[0].img_aug_matrix])) - lidar_aug_matrix = torch.eye(4).to(imgs.device) - - batch_inputs_dict.update( - { - "imgs": imgs.unsqueeze(dim=0), - "lidar2img": lidar2image.unsqueeze(dim=0), - "cam2img": cam2image.unsqueeze(dim=0), - "cam2lidar": camera2lidar.unsqueeze(dim=0), - "img_aug_matrix": img_aug_matrix.unsqueeze(dim=0), - "img_aug_matrix_inverse": img_aug_matrix_inverse.unsqueeze(dim=0), - "lidar_aug_matrix": lidar_aug_matrix.unsqueeze(dim=0), - "lidar_aug_matrix_inverse": lidar_aug_matrix.unsqueeze(dim=0), - "geom_feats": (geom_feats, kept, ranks, indices), - } - ) - - outputs = self._forward(batch_inputs_dict, data_samples, **kwargs) - - # The following code is taken from - # projects/BEVFusion/bevfusion/bevfusion_head.py - # It is used to simplify the post process in deployment - score = outputs["heatmap"].sigmoid() - one_hot = F.one_hot(outputs["query_labels"], num_classes=score.size(1)).permute(0, 2, 1) - score = score * outputs["query_heatmap_score"] * one_hot - score = score[0].max(dim=0)[0] - - bbox_pred = torch.cat( - [outputs["center"][0], outputs["height"][0], outputs["dim"][0], outputs["rot"][0], outputs["vel"][0]], dim=0 - ) - - return bbox_pred, score, outputs["query_labels"][0] diff --git a/projects/BEVFusion/deploy/export.py b/projects/BEVFusion/deploy/export.py deleted file mode 100644 index 702884a1..00000000 --- a/projects/BEVFusion/deploy/export.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import argparse -import logging -import os -import os.path as osp -from copy import deepcopy -from functools import partial -from typing import Any - -import numpy as np -import onnx -import torch -from mmdeploy.apis import build_task_processor -from mmdeploy.apis.onnx.passes import optimize_onnx -from mmdeploy.core import RewriterContext, patch_model -from mmdeploy.utils import ( - IR, - Backend, - get_backend, - get_dynamic_axes, - get_ir_config, - get_onnx_config, - get_root_logger, - load_config, -) -from mmdet3d.registry import MODELS -from mmengine.registry import RUNNERS -from torch.multiprocessing import set_start_method - - -def parse_args(): - parser = argparse.ArgumentParser(description="Export model to onnx.") - parser.add_argument("deploy_cfg", help="deploy config path") - parser.add_argument("model_cfg", help="model config path") - parser.add_argument("checkpoint", help="model checkpoint path") - parser.add_argument("--work-dir", default=os.getcwd(), help="the dir to save logs and models") - parser.add_argument("--device", help="device used for conversion", default="cpu") - parser.add_argument("--log-level", help="set log level", default="INFO", choices=list(logging._nameToLevel.keys())) - parser.add_argument("--sample_idx", type=int, default=0, help="sample index to use during export") - args = parser.parse_args() - return args - - -if __name__ == "__main__": - args = parse_args() - set_start_method("spawn", force=True) - logger = get_root_logger() - log_level = logging.getLevelName(args.log_level) - logger.setLevel(log_level) - - deploy_cfg_path = args.deploy_cfg - model_cfg_path = args.model_cfg - checkpoint_path = args.checkpoint - device = args.device - work_dir = args.work_dir - - deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path) - model_cfg.launcher = "none" - - data_preprocessor_cfg = deepcopy(model_cfg.model.data_preprocessor) - - voxelize_cfg = data_preprocessor_cfg.pop("voxelize_cfg") - voxelize_cfg.pop("voxelize_reduce") - data_preprocessor_cfg["voxel_layer"] = voxelize_cfg - data_preprocessor_cfg.voxel = True - - data_preprocessor = MODELS.build(data_preprocessor_cfg) - - # load a sample - runner = RUNNERS.build(model_cfg) - runner.load_or_resume() - - data = runner.test_dataloader.dataset[args.sample_idx] - - # create model an inputs - task_processor = build_task_processor(model_cfg, deploy_cfg, device) - - torch_model = task_processor.build_pytorch_model(checkpoint_path) - data, model_inputs = task_processor.create_input(data, data_preprocessor=data_preprocessor, model=torch_model) - - if isinstance(model_inputs, list) and len(model_inputs) == 1: - model_inputs = model_inputs[0] - data_samples = data["data_samples"] - input_metas = {"data_samples": data_samples, "mode": "predict", "data_preprocessor": data_preprocessor} - - # export to onnx - context_info = dict() - context_info["deploy_cfg"] = deploy_cfg - output_prefix = osp.join(work_dir, osp.splitext(osp.basename(deploy_cfg.onnx_config.save_file))[0]) - backend = get_backend(deploy_cfg).value - - onnx_cfg = get_onnx_config(deploy_cfg) - opset_version = onnx_cfg.get("opset_version", 11) - - input_names = onnx_cfg["input_names"] - output_names = onnx_cfg["output_names"] - axis_names = input_names + output_names - dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) - verbose = not onnx_cfg.get("strip_doc_string", True) or onnx_cfg.get("verbose", False) - keep_initializers_as_inputs = onnx_cfg.get("keep_initializers_as_inputs", True) - optimize = onnx_cfg.get("optimize", False) - if 
backend == Backend.NCNN.value: - """NCNN backend needs a precise blob counts, while using onnx optimizer - will merge duplicate initilizers without reference count.""" - optimize = False - - output_path = output_prefix + ".onnx" - - logger = get_root_logger() - logger.info(f"Export PyTorch model to ONNX: {output_path}.") - - def _add_or_update(cfg: dict, key: str, val: Any): - if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict): - cfg[key].update(val) - else: - cfg[key] = val - - ir_config = dict( - type="onnx", - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - verbose=verbose, - keep_initializers_as_inputs=keep_initializers_as_inputs, - ) - _add_or_update(deploy_cfg, "ir_config", ir_config) - ir = IR.get(get_ir_config(deploy_cfg)["type"]) - if isinstance(backend, Backend): - backend = backend.value - backend_config = dict(type=backend) - _add_or_update(deploy_cfg, "backend_config", backend_config) - - context_info["cfg"] = deploy_cfg - context_info["ir"] = ir - if "backend" not in context_info: - context_info["backend"] = backend - if "opset" not in context_info: - context_info["opset"] = opset_version - - # patch model - patched_model = patch_model(torch_model, cfg=deploy_cfg, backend=backend, ir=ir) - - if "onnx_custom_passes" not in context_info: - onnx_custom_passes = optimize_onnx if optimize else None - context_info["onnx_custom_passes"] = onnx_custom_passes - with RewriterContext(**context_info), torch.no_grad(): - # patch input_metas - if input_metas is not None: - assert isinstance(input_metas, dict), f"Expect input_metas type is dict, get {type(input_metas)}." - model_forward = patched_model.forward - - def wrap_forward(forward): - - def wrapper(*arg, **kwargs): - return forward(*arg, **kwargs) - - return wrapper - - patched_model.forward = wrap_forward(patched_model.forward) - patched_model.forward = partial(patched_model.forward, **input_metas) - - # NOTE(knzo25): export on the selected device. 
- # the original code forced cpu - patched_model = patched_model.to(device) - if isinstance(model_inputs, torch.Tensor): - model_inputs = model_inputs.to(device) - elif isinstance(model_inputs, (tuple, list)): - model_inputs = tuple([_.to(device) for _ in model_inputs]) - else: - raise RuntimeError(f"Not supported model_inputs: {model_inputs}") - torch.onnx.export( - patched_model, - model_inputs, - output_path, - export_params=True, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - keep_initializers_as_inputs=keep_initializers_as_inputs, - verbose=verbose, - ) - - if input_metas is not None: - patched_model.forward = model_forward - - logger.info(f"ONNX exported to {output_path}") - - logger.info("Attempting to fix the graph (TopK's K becoming a tensor)") - - import onnx_graphsurgeon as gs - - model = onnx.load(output_path) - graph = gs.import_onnx(model) - - # Fix TopK - topk_nodes = [node for node in graph.nodes if node.op == "TopK"] - assert len(topk_nodes) == 1 - topk = topk_nodes[0] - k = model_cfg.num_proposals - topk.inputs[1] = gs.Constant("K", values=np.array([k], dtype=np.int64)) - topk.outputs[0].shape = [1, k] - topk.outputs[0].dtype = topk.inputs[0].dtype if topk.inputs[0].dtype else np.float32 - topk.outputs[1].shape = [1, k] - topk.outputs[1].dtype = np.int64 - - graph.cleanup().toposort() - output_path = output_path.replace(".onnx", "_fixed.onnx") - onnx.save_model(gs.export_onnx(graph), output_path) - - logger.info(f"(Fixed) ONNX exported to {output_path}")