
optimize the performance of FlashBert Path for HPU #575


Merged 1 commit into huggingface:main on Apr 16, 2025

Conversation

@kaixuanliu (Contributor) commented on Apr 9, 2025:

I used the WhereIsAI/UAE-Large-V1 model for benchmarking; below is the throughput (seq/s) comparison:

| Batch size | Before (seq/s) | After (seq/s) |
|------------|----------------|---------------|
| 1          | 199.65         | 219.61        |
| 2          | 239.61         | 284.78        |
| 4          | 321.37         | 373.13        |
| 8          | 506.40         | 549.61        |
| 16         | 759.07         | 822.17        |
| 32         | 1028.31        | 1285.57       |
| 64         | 1130.67        | 1708.73       |
| 128        | OOM            | 1030.06       |
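
For reference, a minimal throughput sketch along these lines could look as follows. It assumes a text-embeddings-inference server is already running locally on port 8080 and exposes the `/embed` route; the URL, iteration count, and sample text are illustrative, and the numbers in the table above come from the author's own setup, not from this script.

```python
import time

import requests

URL = "http://localhost:8080/embed"  # assumed local TEI endpoint


def measure_throughput(batch_size: int, iters: int = 50) -> float:
    """Return sequences per second for a fixed batch size."""
    batch = ["A short benchmark sentence."] * batch_size
    # Warm-up request so graph compilation / caching does not skew the timing.
    requests.post(URL, json={"inputs": batch}, timeout=300).raise_for_status()
    start = time.perf_counter()
    for _ in range(iters):
        requests.post(URL, json={"inputs": batch}, timeout=300).raise_for_status()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


for bs in (1, 2, 4, 8, 16, 32, 64, 128):
    print(f"bs={bs}: {measure_throughput(bs):.2f} seq/s")
```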

@kaixuanliu (Contributor, Author) commented:

@regisss @Narsil Please help review, thanks!

@Narsil (Collaborator) left a comment:

LGTM

```diff
@@ -323,19 +323,21 @@ def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
     def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
         if isinstance(batch, PaddedBatch):
             input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-            max_input_lens = input_lens.max().item()
+            max_input_lens = 0  # This value will not be used
```
@Narsil (Collaborator) commented on this line:

Suggested change
max_input_lens = 0 # This value will not be used

NIT

@kaixuanliu (Contributor, Author) replied:

Hi, sorry, there may be a misunderstanding. The comment "This value will not be used" means the variable can take any value, but we need to keep the assignment here because we still need to pass it to L352.

@kaixuanliu (Contributor, Author):

@Narsil, can you help double check?

@Narsil (Collaborator):

I guess there are cases where the model's forward pass does need the right value for this, right? Otherwise, why not remove it there?

@kaixuanliu (Contributor, Author):

Well, this is a common file shared by the CPU/XPU and HPU devices. On CPU/XPU we do need this variable with its exact meaning, while on HPU we do not have a real varlen_attention API, so we pass attn_mask to replace its functionality. Here we just need to set an arbitrary value for max_input_lens. The line cannot be deleted, because we still need to pass the variable on to L352.
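
To illustrate the point (a sketch of my own, not the code in this file): on HPU the padded path attends with an explicit attention mask, so the max-length argument is never read, while the CPU/XPU variable-length path genuinely needs the exact cumulative sequence lengths and maximum length. The helper name `attention_dispatch` and the `use_hpu` flag below are hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_dispatch(q, k, v, *, attn_mask=None, cu_seqlens=None, max_s=0, use_hpu=False):
    """Hypothetical dispatch helper; not the actual TEI implementation."""
    if use_hpu:
        # HPU path: no real varlen attention kernel is available, so padded
        # attention with an explicit mask is used instead. `max_s` is never
        # read on this branch, which is why the caller can pass any
        # placeholder value (e.g. 0).
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # CPU/XPU path: a varlen-style kernel would be called here with the exact
    # cumulative sequence lengths (`cu_seqlens`) and the true `max_s`; both
    # values must be correct on this branch.
    raise NotImplementedError("varlen attention kernel call elided in this sketch")


# Padded HPU-style call: max_s is a dummy, the mask carries the real lengths.
q = k = v = torch.randn(2, 8, 16, 64)                 # (batch, heads, seq, head_dim)
mask = torch.ones(2, 1, 16, 16, dtype=torch.bool)     # all positions attended
out = attention_dispatch(q, k, v, attn_mask=mask, max_s=0, use_hpu=True)
```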

@Narsil (Collaborator):

Got it 👍

@kaixuanliu (Contributor, Author) commented:

@regisss @Narsil Hi, can you help merge this PR?

@regisss merged commit 5a791e5 into huggingface:main on Apr 16, 2025
2 of 13 checks passed