
Fix fp16 ONNX export for RT-DETR and RT-DETRv2 #36460

Open · wants to merge 10 commits into main
Conversation

@qubvel (Member) commented Feb 27, 2025

What does this PR do?

Fix fp16 ONNX export for RT-DETR and RT-DETRv2, related to

Comment on lines 1312 to 1313
grid_w = torch.arange(int(width), device=device).to(dtype)
grid_h = torch.arange(int(height), device=device).to(dtype)
Member Author

ONNX does not support the Range op in fp16, so we create the range in the default dtype and then cast it to the desired dtype.
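The workaround can be sketched in isolation (a minimal illustration; `make_grid` is a hypothetical name, not the actual modeling code):

```python
import torch

# Minimal sketch of the workaround (`make_grid` is an illustrative name,
# not the real modeling function). Passing dtype=torch.float16 directly to
# torch.arange would export as a fp16 Range node, which ONNX does not
# support; creating the range in the default dtype and casting afterwards
# exports as Range(int64) followed by a Cast node instead.

def make_grid(width, height, device=None, dtype=torch.float16):
    grid_w = torch.arange(int(width), device=device).to(dtype)
    grid_h = torch.arange(int(height), device=device).to(dtype)
    return grid_w, grid_h
```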

Comment on lines +1728 to +1729
torch.arange(end=height, device=device).to(dtype),
torch.arange(end=width, device=device).to(dtype),
Member Author

Same as above regarding the Range op.

@qubvel (Member Author) commented Feb 27, 2025

run-slow: rt_detr, rt_detr_v2

This comment contains run-slow, running the specified jobs:

models: ['models/rt_detr', 'models/rt_detr_v2']
quantizations: [] ...

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +1292 to +1319
for _ in range(self.num_fpn_stages):
    lateral_conv = RTDetrConvNormLayer(
        config,
        in_channels=self.encoder_hidden_dim,
        out_channels=self.encoder_hidden_dim,
        kernel_size=1,
        stride=1,
        activation=activation,
    )
    fpn_block = RTDetrCSPRepLayer(config)
    self.lateral_convs.append(lateral_conv)
    self.fpn_blocks.append(fpn_block)

# bottom-up PAN
self.downsample_convs = nn.ModuleList()
self.pan_blocks = nn.ModuleList()
for _ in range(self.num_pan_stages):
    downsample_conv = RTDetrConvNormLayer(
        config,
        in_channels=self.encoder_hidden_dim,
        out_channels=self.encoder_hidden_dim,
        kernel_size=3,
        stride=2,
        activation=activation,
    )
    pan_block = RTDetrCSPRepLayer(config)
    self.downsample_convs.append(downsample_conv)
    self.pan_blocks.append(pan_block)
Member Author

Just refactoring

Comment on lines +1425 to +1453
for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
    backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1]
    top_fpn_feature_map = fpn_feature_maps[-1]
    # apply lateral block
    top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
    fpn_feature_maps[-1] = top_fpn_feature_map
    # apply fpn block
    top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
    fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
    new_fpn_feature_map = fpn_block(fused_feature_map)
    fpn_feature_maps.append(new_fpn_feature_map)

fpn_feature_maps = fpn_feature_maps[::-1]

# bottom-up PAN
pan_feature_maps = [fpn_feature_maps[0]]
for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
    top_pan_feature_map = pan_feature_maps[-1]
    fpn_feature_map = fpn_feature_maps[idx + 1]
    downsampled_feature_map = downsample_conv(top_pan_feature_map)
    fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
    new_pan_feature_map = pan_block(fused_feature_map)
    pan_feature_maps.append(new_pan_feature_map)

if not return_dict:
    return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
    last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions
)
Member Author

Just refactoring, no changes

@qubvel qubvel marked this pull request as ready for review February 27, 2025 19:10
@qubvel (Member Author) commented Feb 27, 2025

run-slow: rt_detr, rt_detr_v2

This comment contains run-slow, running the specified jobs:

models: ['models/rt_detr', 'models/rt_detr_v2']
quantizations: [] ...

@qubvel (Member Author) commented Feb 27, 2025

run-slow: rt_detr, rt_detr_v2

This comment contains run-slow, running the specified jobs:

models: ['models/rt_detr', 'models/rt_detr_v2']
quantizations: [] ...

@qubvel (Member Author) commented Feb 28, 2025

cc @xenova if you have bandwidth

@xenova (Contributor) left a comment

Thanks! Just one comment about using the trace-safe torch_int helper function for typecasts.

I also left additional comments here to ensure trace-compatibility: huggingface/optimum#2201 (review)

Edit: Just tested the exports and indeed, there are some issues when running with shapes different than export size (e.g., export=320x320, runtime=480x320). Addressing huggingface/optimum#2201 (review) should fix this.

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'/model/encoder/Reshape' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:39 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) size != 0 && (input_shape_size % size) == 0 was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,256,15,10}, requested shape:{-1,256,100}
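The shape-specialization failure can be reproduced with a toy module (an illustration under assumed names, not the RT-DETR code): tracing bakes `int(...)`-cast sizes into the graph as constants, while arithmetic on `tensor.shape` values stays symbolic and generalizes to other input sizes.

```python
import torch
import torch.nn as nn

# Toy repro of the shape-specialization issue (illustrative modules, not
# RT-DETR itself). `Frozen` casts sizes to Python ints, so tracing records
# 4*4=16 as a graph constant; `Dynamic` keeps the size ops in the graph.

class Frozen(nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        return x.reshape(b, c, int(h) * int(w))  # frozen to a constant when traced

class Dynamic(nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        return x.reshape(b, c, h * w)  # size ops stay symbolic when traced

example = torch.randn(1, 2, 4, 4)
traced = torch.jit.trace(Dynamic(), example)
out = traced(torch.randn(1, 2, 6, 4))  # a different spatial size still works
```

Running the traced `Frozen` module on a 6x4 input fails with the same kind of reshape size mismatch as the onnxruntime error above.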

Comment on lines 989 to 990
grid_w = torch.arange(int(width), device=device).to(dtype)
grid_h = torch.arange(int(height), device=device).to(dtype)
Contributor

Using Python type casts (int(...) or float(...)) causes the tracer to lose information, so could you instead use the torch_int utility function (see here)? It only behaves differently when tracing.

Warning logs:

/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:989: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  grid_w = torch.arange(int(width), device=device).to(dtype)
/usr/local/lib/python3.11/dist-packages/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py:990: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  grid_h = torch.arange(int(height), device=device).to(dtype)
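For reference, the trace-safe cast can be sketched roughly like this (an approximation of the helper's behavior, not the exact source of `transformers.utils.torch_int`):

```python
import torch

# Rough sketch of a trace-safe int cast in the spirit of
# transformers.utils.torch_int (approximation, not the exact source).
# Under torch.jit.trace, int(tensor) freezes the value into the graph as
# a constant; casting to an int64 tensor keeps it as a graph node.

def torch_int(x):
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)  # stays symbolic while tracing
    return int(x)  # plain Python int in eager mode
```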

Comment on lines 1325 to 1326
grid_w = torch.arange(int(width), device=device).to(dtype)
grid_h = torch.arange(int(height), device=device).to(dtype)
Contributor

Same as the other comment.
