⚡️ Speed up method MSDeformAttn.forward by 12% in PR #1250 (feature/inference-v1-models)
Here’s an optimized rewrite of your code for **runtime** improvements, focusing on reducing redundant computations, minimizing temporary allocations, removing unnecessary variable creation, and leveraging efficient PyTorch vectorized operations.
Key targets:
- Remove unnecessary object creations and intermediate allocations.
- Avoid repeated view/reshape/copy.
- Use in-place modifications where safe.
- Minimize expensive `.stack`, `.split`, `.flatten`, and inner-loop operations within `ms_deform_attn_core_pytorch`.
- Batch spatial manipulations where possible.
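For reference, the loop these targets apply to looks roughly like the sketch below. This is a simplified reimplementation of the standard pure-PyTorch multi-scale deformable attention path, not the PR's exact code: it walks each feature level with a running index instead of `split()`, and combines the per-level `grid_sample` outputs with a single `torch.cat`.

```python
import torch
import torch.nn.functional as F

def ms_deform_attn_core_pytorch(value, value_spatial_shapes,
                                sampling_locations, attention_weights):
    # value: (N, Len_v, n_heads, head_dim)
    # sampling_locations: (N, Len_q, n_heads, n_levels, n_points, 2), in [0, 1]
    # attention_weights: (N, Len_q, n_heads, n_levels, n_points), summing to 1
    N, _, n_heads, head_dim = value.shape
    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects [-1, 1]
    sampled = []
    idx = 0  # running index replaces torch.split for better memory locality
    for lvl, (H, W) in enumerate(value_spatial_shapes):
        # (N, H*W, n_heads, head_dim) -> (N*n_heads, head_dim, H, W)
        value_l = (value[:, idx:idx + H * W]
                   .permute(0, 2, 3, 1)
                   .reshape(N * n_heads, head_dim, H, W))
        idx += H * W
        # (N, Len_q, n_heads, n_points, 2) -> (N*n_heads, Len_q, n_points, 2)
        grid_l = (sampling_grids[:, :, :, lvl]
                  .permute(0, 2, 1, 3, 4)
                  .reshape(N * n_heads, Len_q, n_points, 2))
        # -> (N*n_heads, head_dim, Len_q, n_points)
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # one cat over (levels * points) instead of stack(..., -2).flatten(-2)
    attn = (attention_weights.permute(0, 2, 1, 3, 4)
            .reshape(N * n_heads, 1, Len_q, n_levels * n_points))
    out = (torch.cat(sampled, dim=-1) * attn).sum(-1)
    return out.view(N, n_heads * head_dim, Len_q).transpose(1, 2).contiguous()
```

With uniform attention weights and a constant value map, the output reproduces the constant, which makes this loop easy to sanity-check against an optimized variant.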
Below is your optimized version. (All comments are preserved unless relevant logic is changed.)
### Notes on optimizations made
- **`ms_deform_attn_core_pytorch`**:
- Replaces `split()` with a running index into the flattened value tensor, fusing the split/view steps for better memory locality.
- Precomputes grid indices in batch, using `permute` and `view` for efficient layout.
- Replaces `stack(..., -2).flatten(-2)` with a single `torch.cat` over the list of per-level sampled outputs.
- **`forward`**:
- Avoids repeated view/copy where possible.
- Uses in-place `masked_fill_` on value tensor when possible.
- Minor: Efficient shape assertion.
- Minor: Converts spatial shapes with tensor ops when they are passed as a Python list or NumPy array.
- **General**:
- No changes to function signatures, external interface, or return values.
- Preserves all logic and all *original* comments.
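The `forward`-side points above (in-place masking, cheap shape assertion, tensor-based shape conversion) can be illustrated with a small sketch; the function and parameter names here are illustrative, not the PR's exact signature:

```python
import torch

def prepare_value(value, input_padding_mask, spatial_shapes):
    # value: (N, Len_v, d_model); input_padding_mask: (N, Len_v) bool, True = padded
    # spatial_shapes may arrive as a Python list or NumPy array; convert once
    shapes = torch.as_tensor(spatial_shapes, dtype=torch.long)
    # cheap shape assertion: the flattened levels must cover the value sequence
    assert int((shapes[:, 0] * shapes[:, 1]).sum()) == value.shape[1]
    if input_padding_mask is not None:
        # masked_fill_ zeroes padded positions in place, avoiding a full copy
        value = value.masked_fill_(input_padding_mask[..., None], 0.0)
    return value, shapes
```

The in-place `masked_fill_` is safe here because autograd does not need the pre-mask values of `value` to compute the backward pass of the fill.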
This should run markedly faster under the eager PyTorch interpreter and reduce transient memory allocations.
If you are using the CUDA-optimized version (for prod/deploy), these changes won't break your CPU reference path but will make debugging and CPU-based validation faster.
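One common way to keep both paths live is a small dispatch wrapper. The sketch below is hypothetical (the injected `cuda_fn`/`cpu_fn` parameters stand in for the compiled extension and the reference implementation, which a real module would import directly):

```python
import torch

def ms_deform_attn_dispatch(value, shapes, locs, weights,
                            cuda_fn=None, cpu_fn=None):
    # Route to the compiled CUDA kernel when the input lives on GPU and the
    # extension is available; otherwise fall back to the CPU reference path.
    if value.is_cuda and cuda_fn is not None:
        return cuda_fn(value, shapes, locs, weights)
    return cpu_fn(value, shapes, locs, weights)
```

Keeping the CPU reference behind the same call site is what makes CPU-based validation of the optimized path straightforward.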