This paper describes a one-pass calculation of the max and the divisor (the softmax normalizer), so you can do softmax overall in two passes: https://arxiv.org/pdf/1805.02867
I found this paper from this blog post: https://peterchng.com/blog/2024/06/26/the-basic-idea-behind-flashattention/
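Here is a minimal sketch of the idea in Python, assuming finite float inputs. The function names are hypothetical. Pass 1 keeps a running max `m` and a running sum `d` of `exp(x - m)`. Whenever the max grows, the old sum gets rescaled by `exp(m_old - m_new)`. Pass 2 just normalizes.

```python
import math

def online_normalizer(xs):
    """Pass 1: running max m and running sum d = sum(exp(x - m)) in one sweep."""
    m = float("-inf")
    d = 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the sum accumulated so far to the new max, then add this term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, d

def softmax_two_pass(xs):
    m, d = online_normalizer(xs)              # pass 1: max and divisor together
    return [math.exp(x - m) / d for x in xs]  # pass 2: normalize
```

Subtracting the max keeps every `exp` argument at or below zero, so large inputs like `[1000.0, 1001.0]` don't overflow.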
Perhaps unsurprisingly, this algorithm does not need to be computed serially: its (max, divisor) states are mergeable, like Welford states. The merge formula is pretty obvious, I think. Edit: Dang, I scrolled further in the paper and they just tell you the formula.
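A minimal sketch of the merge as I'd write it, which should line up with the formula the paper gives (function names are hypothetical). Merging takes the larger max, then rescales each side's divisor to it: `m = max(m_a, m_b)`, `d = d_a * exp(m_a - m) + d_b * exp(m_b - m)`. The sketch assumes non-empty chunks of finite floats.

```python
import math
from functools import reduce

def chunk_state(xs):
    """(max, divisor) state for one non-empty chunk."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

def merge(a, b):
    """Combine two (max, divisor) states; associative, like Welford merges."""
    (m_a, d_a), (m_b, d_b) = a, b
    m = max(m_a, m_b)
    # Rescale each partial divisor to the shared max before adding.
    return m, d_a * math.exp(m_a - m) + d_b * math.exp(m_b - m)

# Chunks can be processed independently (e.g. in parallel) and reduced in any grouping.
xs = [0.5, -1.0, 3.0, 2.0, 7.0]
states = [chunk_state(xs[:2]), chunk_state(xs[2:3]), chunk_state(xs[3:])]
combined = reduce(merge, states)
```

Because `merge` is associative, the chunk states can be combined with a tree reduction rather than left to right. That is what makes the parallel version work.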