4 changes: 2 additions & 2 deletions projects/BEVFusion/README.md
@@ -131,7 +131,7 @@ python projects/BEVFusion/deploy/torch2onnx.py \
${MODEL_CFG} \
${CHECKPOINT_PATH} \
--device cuda:0 \
--work-dir ${WORK_DIR}
--work-dir ${WORK_DIR} \
--module main_body


@@ -140,7 +140,7 @@ python projects/BEVFusion/deploy/torch2onnx.py \
${MODEL_CFG} \
${CHECKPOINT_PATH} \
--device cuda:0 \
--work-dir ${WORK_DIR}
--work-dir ${WORK_DIR} \
--module image_backbone

```
76 changes: 46 additions & 30 deletions projects/BEVFusion/bevfusion/bevfusion.py
@@ -57,7 +57,9 @@ def __init__(

self.init_weights()

def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = None, **kwargs):
def _forward(
self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList = [], using_image_features=False, **kwargs
):
"""Network forward process.

Usually includes backbone, neck and head forward without any post-
@@ -66,7 +68,7 @@ def _forward(self, batch_inputs_dict: Tensor, batch_data_samples: OptSampleList

# NOTE(knzo25): this is used during onnx export
batch_input_metas = [item.metainfo for item in batch_data_samples]
feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features)

if self.with_bbox_head:
outputs = self.bbox_head(feats, batch_input_metas)
@@ -122,6 +124,21 @@ def with_seg_head(self):
"""bool: Whether the detector has a segmentation head."""
return hasattr(self, "seg_head") and self.seg_head is not None

def get_image_backbone_features(self, x: torch.Tensor) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W).contiguous()

x = self.img_backbone(x)
x = self.img_neck(x)

if not isinstance(x, torch.Tensor):
x = x[0]

BN, C, H, W = x.size()
assert BN == B * N, (BN, B * N)
x = x.view(B, N, C, H, W)
return x

def extract_img_feat(
self,
x,
@@ -136,19 +153,11 @@ def extract_img_feat(
img_aug_matrix_inverse=None,
lidar_aug_matrix_inverse=None,
geom_feats=None,
using_image_features=False,
) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W).contiguous()

x = self.img_backbone(x)
x = self.img_neck(x)

if not isinstance(x, torch.Tensor):
x = x[0]

BN, C, H, W = x.size()
assert BN == B * N, (BN, B * N)
x = x.view(B, N, C, H, W)
if not using_image_features:
x = self.get_image_backbone_features(x)

with torch.cuda.amp.autocast(enabled=False):
# with torch.autocast(device_type='cuda', dtype=torch.float32):
@@ -219,7 +228,11 @@ def voxelize(self, points):
return feats, coords, sizes

def predict(
self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], **kwargs
self,
batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
using_image_features=False,
**kwargs,
) -> List[Det3DDataSample]:
"""Forward of testing.

@@ -246,7 +259,7 @@ def predict(
contains a tensor with shape (num_instances, 7).
"""
batch_input_metas = [item.metainfo for item in batch_data_samples]
feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features)

if self.with_bbox_head:
outputs = self.bbox_head.predict(feats, batch_input_metas)
@@ -259,11 +272,14 @@ def extract_feat(
self,
batch_inputs_dict,
batch_input_metas,
using_image_features,
**kwargs,
):
imgs = batch_inputs_dict.get("imgs", None)
points = batch_inputs_dict.get("points", None)
features = []

is_onnx_inference = False
if imgs is not None and "lidar2img" not in batch_inputs_dict:
# NOTE(knzo25): normal training and testing
imgs = imgs.contiguous()
@@ -290,47 +306,43 @@ def extract_feat(
img_aug_matrix,
lidar_aug_matrix,
batch_input_metas,
using_image_features=using_image_features,
)
features.append(img_feature)
elif imgs is not None:
# NOTE(knzo25): onnx inference
is_onnx_inference = True
lidar2image = batch_inputs_dict["lidar2img"]
camera_intrinsics = batch_inputs_dict["cam2img"]
camera2lidar = batch_inputs_dict["cam2lidar"]
img_aug_matrix = batch_inputs_dict["img_aug_matrix"]
lidar_aug_matrix = batch_inputs_dict["lidar_aug_matrix"]

# NOTE(knzo25): originally BEVFusion uses all the points
# which could be a bit slow. For now I am using only
# the centroids, which is also suboptimal, but using
# all the voxels produces errors in TensorRT,
# so this will be fixed for the next version
# (ScatterElements bug, or simply null voxels break the equation)
feats = batch_inputs_dict["voxels"]["voxels"]
sizes = batch_inputs_dict["voxels"]["num_points_per_voxel"]
geom_feats = batch_inputs_dict["geom_feats"]

feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1)
# feats = batch_inputs_dict["voxels"]["voxels"]
Collaborator comment: Remove
# sizes = batch_inputs_dict["voxels"]["num_points_per_voxel"]
# feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(-1, 1)

geom_feats = batch_inputs_dict["geom_feats"]
img_feature = self.extract_img_feat(
imgs,
[feats],
# points,
points,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
batch_input_metas,
geom_feats=geom_feats,
using_image_features=using_image_features,
)
features.append(img_feature)

pts_feature = self.extract_pts_feat(
batch_inputs_dict.get("voxels", {}).get("voxels", None),
batch_inputs_dict.get("voxels", {}).get("coors", None),
batch_inputs_dict.get("voxels", {}).get("num_points_per_voxel", None),
points=points,
points=points if not is_onnx_inference else None,
)
features.append(pts_feature)

@@ -346,10 +358,14 @@ def extract_feat(
return x

def loss(
self, batch_inputs_dict: Dict[str, Optional[Tensor]], batch_data_samples: List[Det3DDataSample], **kwargs
self,
batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
using_image_features: bool = False,
**kwargs,
) -> List[Det3DDataSample]:
batch_input_metas = [item.metainfo for item in batch_data_samples]
feats = self.extract_feat(batch_inputs_dict, batch_input_metas)
feats = self.extract_feat(batch_inputs_dict, batch_input_metas, using_image_features)

losses = dict()
if self.with_bbox_head:
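For context on the new `using_image_features` flag introduced above: the camera backbone and neck can now be run separately through `get_image_backbone_features` (e.g. as their own ONNX module), and the resulting features can be fed back through `predict`/`loss`/`_forward` in place of raw images. A minimal sketch of that flow, assuming a constructed BEVFusion `model` and a `batch_inputs_dict`/`batch_data_samples` pair prepared by the usual mmdet3d pipeline (not part of the PR):

```python
import torch

with torch.no_grad():
    imgs = batch_inputs_dict["imgs"]  # (B, N, C, H, W)
    # Run backbone + neck once; returns (B, N, C', H', W') per the new helper.
    img_feats = model.get_image_backbone_features(imgs)

    # Hand the precomputed features back in place of the raw images and skip
    # the backbone/neck inside extract_img_feat.
    batch_inputs_dict["imgs"] = img_feats
    results = model.predict(
        batch_inputs_dict, batch_data_samples, using_image_features=True
    )
```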
28 changes: 14 additions & 14 deletions projects/BEVFusion/bevfusion/depth_lss.py
@@ -314,20 +314,6 @@ def forward(
lidar_aug_matrix_inverse,
geom_feats_precomputed,
):
post_trans = img_aug_matrix[..., :3, 3]
camera2lidar_rots = camera2lidar[..., :3, :3]
camera2lidar_trans = camera2lidar[..., :3, 3]

if camera_intrinsics_inverse is None:
intrins_inverse = torch.inverse(cam_intrinsic)[..., :3, :3]
else:
intrins_inverse = camera_intrinsics_inverse[..., :3, :3]

if img_aug_matrix_inverse is None:
post_rots_inverse = torch.inverse(img_aug_matrix)[..., :3, :3]
else:
img_aug_matrix_inverse = img_aug_matrix_inverse[..., :3, :3]

if lidar_aug_matrix_inverse is None:
lidar_aug_matrix_inverse = torch.inverse(lidar_aug_matrix)

@@ -406,6 +392,20 @@ def forward(

x = self.bev_pool_precomputed(x, geom_feats, kept, ranks, indices)
else:
post_trans = img_aug_matrix[..., :3, 3]
camera2lidar_rots = camera2lidar[..., :3, :3]
camera2lidar_trans = camera2lidar[..., :3, 3]

if camera_intrinsics_inverse is None:
intrins_inverse = torch.inverse(cam_intrinsic)[..., :3, :3]
else:
intrins_inverse = camera_intrinsics_inverse[..., :3, :3]

if img_aug_matrix_inverse is None:
post_rots_inverse = torch.inverse(img_aug_matrix)[..., :3, :3]
else:
post_rots_inverse = img_aug_matrix_inverse[..., :3, :3]

geom = self.get_geometry(
camera2lidar_rots,
camera2lidar_trans,
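The depth_lss.py hunk above moves the inverse computations into the branch that actually consumes them and fixes the old assignment that stored the sliced inverse back into `img_aug_matrix_inverse` instead of `post_rots_inverse`. A condensed restatement of that fallback pattern, for illustration only (the helper name is made up and is not part of the PR):

```python
from typing import Optional

import torch


def rot_block_of_inverse(mat: torch.Tensor,
                         mat_inverse: Optional[torch.Tensor]) -> torch.Tensor:
    """Prefer a precomputed 4x4 inverse (e.g. supplied as an ONNX input);
    otherwise fall back to torch.inverse(), then return the top-left 3x3 block."""
    if mat_inverse is None:
        mat_inverse = torch.inverse(mat)
    return mat_inverse[..., :3, :3]


# Mirrors the new code:
#   intrins_inverse   = rot_block_of_inverse(cam_intrinsic, camera_intrinsics_inverse)
#   post_rots_inverse = rot_block_of_inverse(img_aug_matrix, img_aug_matrix_inverse)
```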
@@ -9,40 +9,18 @@
allow_failed_imports=False,
)

image_dims = (384, 576)

backend_config = dict(
type="tensorrt",
common_config=dict(max_workspace_size=1 << 32),
model_inputs=[
dict(
input_shapes=dict(
imgs=dict(min_shape=[1, 3, 256, 704], opt_shape=[6, 3, 256, 704], max_shape=[6, 3, 256, 704]),
points=dict(min_shape=[5000, 5], opt_shape=[50000, 5], max_shape=[200000, 5]),
lidar2image=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
cam2image_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
camera2lidar=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
img_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
img_aug_matrix_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
lidar_aug_matrix=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
lidar_aug_matrix_inverse=dict(min_shape=[1, 4, 4], opt_shape=[6, 4, 4], max_shape=[6, 4, 4]),
geom_feats=dict(
min_shape=[0 * 118 * 32 * 88, 4],
opt_shape=[6 * 118 * 32 * 88 // 2, 4],
max_shape=[6 * 118 * 32 * 88, 4],
),
kept=dict(
min_shape=[0 * 118 * 32 * 88],
opt_shape=[6 * 118 * 32 * 88],
max_shape=[6 * 118 * 32 * 88],
),
ranks=dict(
min_shape=[0 * 118 * 32 * 88],
opt_shape=[6 * 118 * 32 * 88 // 2],
max_shape=[6 * 118 * 32 * 88],
),
indices=dict(
min_shape=[0 * 118 * 32 * 88],
opt_shape=[6 * 118 * 32 * 88 // 2],
max_shape=[6 * 118 * 32 * 88],
imgs=dict(
min_shape=[1, 3, image_dims[0], image_dims[1]],
opt_shape=[6, 3, image_dims[0], image_dims[1]],
max_shape=[6, 3, image_dims[0], image_dims[1]],
),
)
)
@@ -57,64 +35,12 @@
save_file="image_backbone.onnx",
input_names=[
"imgs",
"points",
"lidar2image",
"cam2image",
"cam2image_inverse",
"camera2lidar",
"img_aug_matrix",
"img_aug_matrix_inverse",
"lidar_aug_matrix",
"lidar_aug_matrix_inverse",
"geom_feats",
"kept",
"ranks",
"indices",
],
output_names=["image_feats"],
dynamic_axes={
"imgs": {
0: "num_imgs",
},
"points": {
0: "num_points",
},
"lidar2image": {
0: "num_imgs",
},
"cam2image": {
0: "num_imgs",
},
"cam2image_inverse": {
0: "num_imgs",
},
"camera2lidar": {
0: "num_imgs",
},
"img_aug_matrix": {
0: "num_imgs",
},
"img_aug_matrix_inverse": {
0: "num_imgs",
},
"lidar_aug_matrix": {
0: "num_imgs",
},
"lidar_aug_matrix_inverse": {
0: "num_imgs",
},
"geom_feats": {
0: "num_kept",
},
"kept": {
0: "num_geom_feats",
},
"ranks": {
0: "num_kept",
},
"indices": {
0: "num_kept",
},
},
input_shape=None,
verbose=True,
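With the slimmed-down deploy config above, the image-backbone ONNX graph takes a single `imgs` input (1 to 6 cameras at `image_dims = (384, 576)`) and produces `image_feats`, with only the camera count left dynamic. A quick sanity check of the exported file with onnxruntime might look like the following; the output path and the exact output layout are assumptions, not taken from the PR:

```python
import numpy as np
import onnxruntime as ort

image_dims = (384, 576)  # matches the deploy config above

sess = ort.InferenceSession(
    "work_dir/image_backbone.onnx",  # assumed export location
    providers=["CPUExecutionProvider"],
)
dummy = np.random.rand(6, 3, *image_dims).astype(np.float32)  # 6 cameras
(image_feats,) = sess.run(None, {"imgs": dummy})
print(image_feats.shape)  # layout depends on how the exported module reshapes its output
```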
@@ -15,8 +15,8 @@
model_inputs=[
dict(
input_shapes=dict(
voxels=dict(min_shape=[1, 5], opt_shape=[64000, 5], max_shape=[256000, 5]),
coors=dict(min_shape=[1, 4], opt_shape=[64000, 4], max_shape=[256000, 4]),
voxels=dict(min_shape=[1, 10, 4], opt_shape=[64000, 10, 4], max_shape=[256000, 10, 4]),
Collaborator comment: The shape should be [M, maximum number of points, features], which would be [M, 10, 5] if we are using intensity, right?

coors=dict(min_shape=[1, 3], opt_shape=[64000, 3], max_shape=[256000, 3]),
num_points_per_voxel=dict(min_shape=[1], opt_shape=[64000], max_shape=[256000]),
)
)
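Regarding the review question above: the new `voxels` shape follows the hard-voxelization layout (num_voxels, max_points_per_voxel, point_feature_dim), so the last dimension is 4 or 5 depending on how many per-point features (e.g. intensity) are kept. A small illustration of how the three inputs relate, with sizes taken from the config's opt shapes and the rest assumed (not from the PR itself):

```python
import torch

num_voxels = 64000           # opt_shape in the config above
max_points_per_voxel = 10
point_feature_dim = 4        # 5 if an extra per-point feature such as intensity is kept

voxels = torch.zeros(num_voxels, max_points_per_voxel, point_feature_dim)  # (M, 10, 4)
coors = torch.zeros(num_voxels, 3, dtype=torch.int32)                      # (M, 3) voxel grid indices
num_points_per_voxel = torch.zeros(num_voxels, dtype=torch.int32)          # (M,)

# Per-voxel centroids (cf. the reduction commented out in extract_feat;
# a clamp is added here only to avoid dividing by zero in this standalone example):
centroids = voxels.sum(dim=1) / num_points_per_voxel.clamp(min=1).view(-1, 1).to(voxels.dtype)
```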