Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Attention Microsoft Contrib Operator #3816

Draft
wants to merge 25 commits into
base: develop
Choose a base branch
from

Conversation

@TedThemistokleous TedThemistokleous added roadmap Tasks to finish for a release onnxruntime PR changes interaction between MIGraphX and Onnxruntime Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase labels Feb 14, 2025
@TedThemistokleous TedThemistokleous self-assigned this Feb 14, 2025
@TedThemistokleous TedThemistokleous linked an issue Feb 20, 2025 that may be closed by this pull request
3 tasks
@causten causten added the high priority A PR with high priority for review and merging. label Feb 21, 2025
breaking this up to smaller pieces for optional args as I populate the proper vector inputs before tying things to the calculation and creation of multi head attention layers
need to finish with other input args and check infered and parsed attributes.
@migraphx-bot
Copy link
Collaborator

Test Batch Rate new
4ef0ba
Rate old
b19d13
Diff Compare
torchvision-resnet50 64 3,235.23 3,237.11 -0.06%
torchvision-resnet50_fp16 64 6,877.89 6,877.59 0.00%
torchvision-densenet121 32 2,436.24 2,437.56 -0.05%
torchvision-densenet121_fp16 32 4,191.26 4,202.03 -0.26%
torchvision-inceptionv3 32 1,613.81 1,614.89 -0.07%
torchvision-inceptionv3_fp16 32 2,678.50 2,677.87 0.02%
cadene-inceptionv4 16 750.74 750.76 -0.00%
cadene-resnext64x4 16 808.97 810.16 -0.15%
slim-mobilenet 64 6,659.44 6,665.81 -0.10%
slim-nasnetalarge 64 198.55 196.84 0.87%
slim-resnet50v2 64 3,427.56 3,431.69 -0.12%
bert-mrpc-onnx 8 1,142.31 1,141.15 0.10%
bert-mrpc-tf 1 488.22 482.99 1.08%
pytorch-examples-wlang-gru 1 475.95 486.63 -2.19%
pytorch-examples-wlang-lstm 1 453.43 443.83 2.16%
torchvision-resnet50_1 1 812.47 806.76 0.71%
cadene-dpn92_1 1 430.25 430.96 -0.17%
cadene-resnext101_1 1 391.05 392.23 -0.30%
onnx-taau-downsample 1 371.32 371.19 0.03%
dlrm-criteoterabyte 1 31.83 31.79 0.13%
dlrm-criteoterabyte_fp16 1 50.95 51.07 -0.25%
agentmodel 1 8,794.30 8,932.40 -1.55%
unet_fp16 2 57.96 58.29 -0.57%
resnet50v1_fp16 1 1,039.93 1,045.16 -0.50%
resnet50v1_int8 1 798.90 804.91 -0.75%
bert_base_cased_fp16 64 1,164.44 1,164.48 -0.00%
bert_large_uncased_fp16 32 361.61 361.61 0.00%
bert_large_fp16 1 200.69 200.06 0.31%
distilgpt2_fp16 16 2,217.81 2,212.16 0.26%
yolov5s 1 520.73 522.38 -0.32%
tinyllama 1 43.59 43.59 -0.00%
vicuna-fastchat 1 43.96 43.89 0.17%
whisper-tiny-encoder 1 412.14 412.16 -0.00%
whisper-tiny-decoder 1 411.51 408.89 0.64%
yolov10 1 nan nan nan%
llama2_7b 1 nan nan nan%
qwen1.5-7b 1 nan nan nan%
phi3-3.8b 1 nan nan nan%
mask-rcnn 1 nan nan nan%
llama3-8b 1 nan nan nan%
whisper-large-encoder 1 nan nan nan%
whisper-large-decoder 1 nan nan nan%
mistral-7b 1 nan nan nan%
FLUX.1-schnell 1 nan nan nan%

This build is not recommended to merge 🔴

@migraphx-bot
Copy link
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

❌llama2_7b: ERROR - check error outputusage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask 1 256 @attention_mask 1 256


❌qwen1.5-7b: ERROR - check error outputusage: accuracy_checker.py [-h] [--onnx ONNX] [--tf TF] [--provider PROVIDER]
[--batch BATCH] [--fill1] [--fill0] [--fp16]
[--argmax] [--verbose] [--tolerance TOLERANCE]
[--input-dim INPUT_DIM] [--target TARGET]
[--ort-run] [--ort-logging]
[--disable-offload-copy] [--disable-fast-math]
[--exhaustive_tune]
accuracy_checker.py: error: unrecognized arguments: input_ids attention_mask position_ids 1 256 @attention_mask 1 256 @position_ids 1 256


❌phi3-3.8b: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/phi3-3.8b/model.onnx


❌mask-rcnn: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/mask-rcnn/MaskRCNN-10.onnx


❌llama3-8b: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/llama3-8b/model.onnx


❌whisper-large-encoder: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/whisper-large/encoder_model.onnx


❌whisper-large-decoder: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/whisper-large/decoder_model.onnx


❌mistral-7b: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/mistral-7b/model.onnx


❌FLUX.1-schnell: ERROR - check error outputTraceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 340, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 205, in main
model = migraphx.parse_onnx(model_name, default_dim_value=batch)
RuntimeError: /src/AMDMIGraphX/src/onnx/onnx_parser.cpp:264: parse_from: PARSE_FROM: Failed reading onnx file: /new-saved-models/FLUX.1-schnell/text_encoder/model.onnx

split this up to clean up the handle_inputs call and seperate errors/state when we aquire attributes
…nput correctly.

Need to fill in parser piece but this checks and ensures we're working with the proper batch size for our calculations within the attention head.

Debug still needs to be removedb but this ensures we're seeing the proper amount of heads that are batched correctly
…r now.

add some sort of tracked state for padding modes of the mask_index for input linear layer masking prior to attention head splits.
Too much Cpp too little python
Tests more representative of customer workloads and models we see in the wild. Need to finish these to complete parseer tests.

Will add tests for other inputs and error cases later
Give an explanation to how things are parsed in as the input sizes of masks, inputs, weights as well as attributes can change how certain infered values in the parser can be calculated. This is due to how the spec specifices how inputs will be handled on parse.
clean up debug from input_linear_to_qkv and have input be put in via vector if instructions.
@TedThemistokleous TedThemistokleous force-pushed the add_attention_contrib_op branch from 32f91e4 to a7f8a3b Compare March 1, 2025 19:56
Ted Themistokleous added 2 commits March 4, 2025 23:03
Use the default query size for scale factor if scale attribute is not set. Flows through the result accordingly.
Leave this as a todo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority A PR with high priority for review and merging. Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase onnxruntime PR changes interaction between MIGraphX and Onnxruntime roadmap Tasks to finish for a release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Parser for Attention Contrib OP
3 participants