Optimize roi_align on BMG #1698
Pull Request Overview
This PR aims to optimize the roi_align performance on BMG by reducing repeated LLC memory accesses and streamlining conditional execution. Key changes include refactoring boundaries and conditional checks in the upsample kernels, and enhancing workgroup-based caching and indexing in the roi_align implementation.
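The "streamlining conditional execution" part can be illustrated with a minimal host-side sketch. It is not the PR's kernel code: the `Sample` struct and both helper names are hypothetical, and the sketch assumes the coordinate has already been range-checked so that the float-to-int cast is safe. It shows how the boundary handling for one bilinear interpolation axis might be written with `std::min`/`std::max` instead of if/else while producing the same indices and fractional weight.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical result type: the two neighbouring indices and the
// fractional weight used by bilinear interpolation along one axis.
struct Sample {
  int low;
  int high;
  float frac;
};

// Branchy version: clamps with explicit if/else.
inline Sample clamp_branchy(float coord, int size) {
  assert(size > 0);
  if (coord <= 0.f) {
    coord = 0.f;
  }
  int low = static_cast<int>(coord);
  int high;
  if (low >= size - 1) {
    high = low = size - 1;
    coord = static_cast<float>(low);
  } else {
    high = low + 1;
  }
  return {low, high, coord - static_cast<float>(low)};
}

// min/max version: same result, no data-dependent branches in the source.
inline Sample clamp_minmax(float coord, int size) {
  assert(size > 0);
  coord = std::max(coord, 0.f);
  int low = std::min(static_cast<int>(coord), size - 1);
  int high = std::min(low + 1, size - 1);
  coord = std::min(coord, static_cast<float>(size - 1));
  return {low, high, coord - static_cast<float>(low)};
}
```

Whether the min/max form is actually faster depends on the compiler and the device. On a GPU the expected benefit is less divergence between neighbouring work items.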
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp | Refined boundary condition handling and restructured the can_optimize condition |
| src/ATen/native/xpu/sycl/RoiAlignKernels.cpp | Updated bilinear interpolation clamping and improved ROI workgroup indexing with shared memory caching |
Comments suppressed due to low confidence (1)
src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp:608
- Consider refactoring this compound conditional for 'can_optimize' to improve readability and maintainability, perhaps by extracting it into a helper function if it is reused.
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
For input [1, 2048, 50, 75] and rois [1000, 5], roi_align takes 4.7 ms on PVC but 75 ms on BMG. Each ROI has 2048 x output_h x output_w work items that all read the same values from LLC, which is very slow on BMG. After caching those values in shared local memory, PVC takes 4.0 ms and BMG takes 7.5 ms. I also replaced some if/else branching with min/max and fixed a code style issue.
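The effect of the caching can be sketched on the host with a load counter. This is only a simulation under stated assumptions, not the SYCL kernel from this PR: `Roi`, `CountingRois`, and both functions are hypothetical names, a plain local variable stands in for shared local memory, and the group size of 256 in the usage note is an arbitrary choice. The point is that when one work item per work-group fetches the ROI and the rest reuse the cached copy, the number of global reads drops from one per work item to one per work-group.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// One ROI row, following the [num_rois, 5] layout: batch index plus box corners.
struct Roi {
  float batch, x1, y1, x2, y2;
};

// Hypothetical wrapper that counts reads of the "global" ROI array.
struct CountingRois {
  std::vector<Roi> data;
  mutable std::size_t loads = 0;
  const Roi& load(std::size_t i) const {
    ++loads;
    return data[i];
  }
};

// Baseline: every simulated work item reads its ROI from global memory.
float sum_widths_uncached(const CountingRois& rois, std::size_t items_per_roi) {
  float acc = 0.f;
  for (std::size_t r = 0; r < rois.data.size(); ++r) {
    for (std::size_t item = 0; item < items_per_roi; ++item) {
      const Roi& roi = rois.load(r);
      acc += roi.x2 - roi.x1;
    }
  }
  return acc;
}

// Cached: one read per simulated work-group; `local` stands in for the
// shared-local-memory copy that the whole group would reuse.
float sum_widths_cached(const CountingRois& rois, std::size_t items_per_roi,
                        std::size_t group_size) {
  assert(group_size > 0);
  float acc = 0.f;
  for (std::size_t r = 0; r < rois.data.size(); ++r) {
    for (std::size_t start = 0; start < items_per_roi; start += group_size) {
      const Roi local = rois.load(r);
      const std::size_t end = std::min(start + group_size, items_per_roi);
      for (std::size_t item = start; item < end; ++item) {
        acc += local.x2 - local.x1;
      }
    }
  }
  return acc;
}
```

With 3 ROIs, 1000 work items per ROI, and a group size of 256, the baseline performs 3000 counted loads and the cached version performs 12. A real kernel would also need a work-group barrier between the load into local memory and its first use.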