- 
                Notifications
    You must be signed in to change notification settings 
- Fork 67
Add FMHA PAXML test #830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add FMHA PAXML test #830
Conversation
        
          
                .github/container/test-pax.sh
              
                Outdated
          
        
      | fmha_regex="fmha[-bmm]?[-scale]?[-bias]?[-mask]?[-softmax]?[-dropout]?[-bmm]?[-backward]?*" | ||
| result=$(grep -irlnE "$fmha_regex" "${HLO_DIR}/"*.txt) | ||
|  | ||
| if [[ $SAVE_HLO -eq 0 ]]; then | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect to skip saving if $SAVE_HLO == 0, rather than delete it afterwards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for fmha the hlo is dumped in output folder by default to search for fmha instructions. The idea is to delete the hlo folder if user doesn't want to have hlo folder inside the output folder which is saved as artifact.
|  | ||
| # Set hlo dump folder after output folder is set. | ||
| HLO_DIR=${OUTPUT}/hlo | ||
| export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please explain logic here: is BASE_XLA_FLAGS is set, than you always skip setting HLO_DIR?
If so, maybe you can add a warning message, that xla dump is not set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dumping the hlo is enabled by default in BASE_XLA_FLAGS, and BASE_XLA_FLAGS are appended to XLA_FLAGS env var. if user wants to test fmha then BASE_XLA_FLAGS_FMHA is added and appended to XLA_FLAGS. The idea is to preserve the env var XLA_FLAGS before execution of this script.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, let me clarify my question:
line 150 literally means:
if [[ -z "$BASE_XLA_FLAGS"  ]]; then
      BASE_XLA_FLAGS = "--xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}"
fi
Meaning, that if BASE_XLA_FLAGS is already set (by any previous scripts, or globally in the system, etc), ${HLO_DIR} will not have any effect at all.
Is that expected behaviour?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And why do you export it? You use it only locally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mechanism was added as per the review comment of same PR for t5x: #442 (comment)
refer to the discussion for details of the implementation.
The implementation BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}" means update the BASE_XLA_FLAGS with previous definition if any and append xla dump hlo flags to the env vars. This also gives us the flexibility of "zero out" the env var in this script without modifying code in this script by just doing BASE_XLA_FLAGS=""
| ## Setting the env variables for FMHA | ||
| if [[ "$ENABLE_FMHA" -eq "1" ]]; then | ||
| echo "Setting XLA FMHA Flags"; | ||
| export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}" | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Save here as above
Disabled saving hlo by default as suggested by terry
Incorporated review comments, disabled saving hlo by default as suggested by terry.
merge enable-fmha and enable-fused-attn flags
Incorporated review comments
| LGTM after clarifying about extra XLA_FLAGS | 
| Can we merge this PR if it looks good? | 
| Changes look good to me. The presubmit CI did not complete, so I reran. I'll approve once CI completes and we verify that results are as expected | 
| echo " --multiprocess Enable the multiprocess GPU mode." | ||
| echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified." | ||
| echo " --save-hlo {0, 1} 1 to save the dumped hlo, 0 to remove the hlo dumped folder" | ||
| echo " --enable-fmha {0, 1} 1 to enable fmha testing, 0 to run test without fmha; default is 0" | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: default doesn't match below
Adding the JAX PAXML FMHA E2E system test to check for fmha lowering support. Following are the steps implemented in the test:
FMHA lowering flag is enabled, enabled the dumping of hlo to track fmha forward and backward instructions.
Enabled dumping of HLO by default. for Llama test, disabled the hlo dumoing as the size of the artifact is big.
Similar to PR: #442