diff --git a/problem_1.py b/problem_1.py
index 683be1c..ab7dace 100644
--- a/problem_1.py
+++ b/problem_1.py
@@ -11,8 +11,8 @@ class FlashAttention2Function(torch.autograd.Function):
     @staticmethod
     def forward(ctx, Q, K, V, is_causal=False):
         # Get dimensions from input tensors following the (B, H, N, D) convention
-        B, H, N_Q, D_H = Q.shape
-        _, _, N_K, _ = K.shape
+        B, H, N_Q, D_H = Q.shape  # N_Q is the number of query tokens, D_H is the per-head dimension
+        _, _, N_K, _ = K.shape  # N_K is the number of key tokens
 
         # Define tile sizes
         Q_TILE_SIZE = 128
@@ -41,9 +41,9 @@ def forward(ctx, Q, K, V, is_causal=False):
             Q_tile = Q_bh[q_start:q_end, :]
 
             # Initialize accumulators for this query tile
-            o_i = torch.zeros_like(Q_tile, dtype=Q.dtype)
-            l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32)
-            m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32)
+            o_i = torch.zeros_like(Q_tile, dtype=Q.dtype)  # running weighted output - the actual attention output
+            l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32)  # running sum of exponentials (for normalization)
+            m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32)  # running max for the stability trick: subtract the max so exp(x) cannot blow up
 
             # Inner loop over key/value tiles
             for j in range(N_K_tiles):
@@ -53,26 +53,32 @@ def forward(ctx, Q, K, V, is_causal=False):
                 K_tile = K_bh[k_start:k_end, :]
                 V_tile = V_bh[k_start:k_end, :]
 
-                S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale
-
+                S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale
                 # --- STUDENT IMPLEMENTATION REQUIRED HERE ---
                 # 1. Apply causal masking if is_causal is True.
-                #
+                if is_causal:
+                    q_idx = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1)  # (Tq, 1)  [MAIN FIX] q_idx must be on the same device as S_ij, which is CUDA
+                    k_idx = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0)  # (1, Tk)  [MAIN FIX] k_idx must be on the same device as S_ij, which is CUDA
+                    S_ij = S_ij.masked_fill(k_idx > q_idx, -float('inf'))  # allow only k <= q  [MAIN FIX] the comparison was k_idx < q_idx earlier
                 # 2. Compute the new running maximum
-                #
+                m_ij = torch.max(S_ij, dim=-1).values.to(torch.float32)
+                m_new = torch.maximum(m_i, m_ij)
                 # 3. Rescale the previous accumulators (o_i, l_i)
-                #
+                scale_factor = torch.exp(m_i - m_new)
+                o_i = o_i * scale_factor.unsqueeze(-1).to(o_i.dtype)
+                l_i = l_i * scale_factor
                 # 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new).
-                #
+                P_tilde_ij = torch.exp(S_ij.to(torch.float32) - m_new.unsqueeze(-1))  # [MAIN FIX] unsqueeze m_new so it broadcasts over the key dimension
                 # 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i
-                #
+                l_i = l_i + torch.sum(P_tilde_ij, dim=-1)
+                o_i = o_i + (P_tilde_ij @ V_tile.to(torch.float32)).to(o_i.dtype)
                 # 6. Update the running max for the next iteration
-
+                m_i = m_new
                 # --- END OF STUDENT IMPLEMENTATION ---
 
             # After iterating through all key tiles, normalize the output
             # This part is provided for you. It handles the final division safely.
             l_i_reciprocal = torch.where(l_i > 0, 1.0 / l_i, 0)
             o_i_normalized = o_i * l_i_reciprocal.unsqueeze(-1)
 
             L_tile = m_i + torch.log(l_i)
diff --git a/problem_2.py b/problem_2.py
index 011f3eb..b2665f8 100644
--- a/problem_2.py
+++ b/problem_2.py
@@ -16,47 +16,47 @@ def weighted_row_sum_kernel(
     """
     # 1. Get the row index for the current program instance.
     #    Hint: Use tl.program_id(axis=0).
-    row_idx = ...
+    row_idx = tl.program_id(axis=0)
 
     # 2. Create a pointer to the start of the current row in the input tensor X.
     #    Hint: The offset depends on the row index and the number of columns (N_COLS).
-    row_start_ptr = ...
-
+    row_start_ptr = X_ptr + row_idx * N_COLS
+
     # 3. Create a pointer for the output vector Y.
-    output_ptr = ...
+    output_ptr = Y_ptr + row_idx
 
     # 4. Initialize an accumulator for the sum of the products for a block.
     #    This should be a block-sized tensor of zeros.
     #    Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32.
-    accumulator = ...
+    accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     # 5. Iterate over the columns of the row in blocks of BLOCK_SIZE.
     #    Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE).
-    for col_block_start in range(0, ...):
+    for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)):
         # - Calculate the offsets for the current block of columns.
         #   Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE).
-        col_offsets = ...
+        col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 
         # - Create a mask to prevent out-of-bounds memory access for the last block.
         #   Hint: Compare col_offsets with N_COLS.
-        mask = ...
+        mask = col_offsets < N_COLS
 
         # - Load a block of data from X and W safely using the mask.
         #   Hint: Use tl.load with the appropriate pointers, offsets, and mask.
         #   Use `other=0.0` to handle out-of-bounds elements.
-        x_chunk = tl.load(...)
-        w_chunk = tl.load(...)
+        x_chunk = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0)
+        w_chunk = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
 
         # - Compute the element-wise product and add it to the accumulator.
-        accumulator += ...
+        accumulator += x_chunk * w_chunk
 
     # 6. Reduce the block-sized accumulator to a single scalar value after the loop.
     #    Hint: Use tl.sum().
-    final_sum = ...
+    final_sum = tl.sum(accumulator, axis=0)
 
     # 7. Store the final accumulated sum to the output tensor Y.
     #    Hint: Use tl.store().
-    ...
+    tl.store(output_ptr, final_sum)
 
     # --- END OF STUDENT IMPLEMENTATION ---
diff --git a/problem_3.py b/problem_3.py
index 43f134a..1447cb5 100644
--- a/problem_3.py
+++ b/problem_3.py
@@ -39,8 +39,9 @@ def _flash_attention_forward_kernel(
     q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M))
     q_ptrs = Q_ptr + batch_idx * q_stride_b + head_idx * q_stride_h + \
              (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :])
-    q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)
-
+    # Load one (BLOCK_M, HEAD_DIM) tile of queries; the mask keeps rows beyond SEQ_LEN from being read (they load as 0.0).
+    q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)
+
     # PyTorch softmax is exp(x), Triton is exp2(x * log2(e)), log2(e) is approx 1.44269504
     qk_scale = softmax_scale * 1.44269504
 
@@ -61,15 +62,25 @@ def _flash_attention_forward_kernel(
                  (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
         v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
 
-        # --- STUDENT IMPLEMENTATION REQUIRED HERE ---
-        # Implement the online softmax update logic.
+        # --- STUDENT IMPLEMENTATION REQUIRED HERE ---
         # 1. Find the new running maximum (`m_new`).
+        m_ij = tl.max(s_ij, axis=1)
+        m_new = tl.maximum(m_i, m_ij)
         # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
+        # Use tl.exp2 for Triton's online softmax formulation.
+        scale_factor = tl.exp2(m_i - m_new)
+        acc *= scale_factor[:, None]
+        l_i *= scale_factor
         # 3. Compute the attention probabilities for the current tile (`p_ij`).
+        p_ij = tl.exp2(s_ij - m_new[:, None])
         # 4. Update the accumulator `acc` using `p_ij` and `v_block`.
+        acc += tl.dot(p_ij.to(v_block.type), v_block)
         # 5. Update the denominator `l_i`.
+        l_i += tl.sum(p_ij, axis=1)
         # 6. Update the running maximum `m_i` for the next iteration.
+        m_i = m_new
         pass
+
         # --- END OF STUDENT IMPLEMENTATION ---
diff --git a/problem_4.py b/problem_4.py
index 4faf01b..040be34 100644
--- a/problem_4.py
+++ b/problem_4.py
@@ -52,8 +52,28 @@ def _flash_attention_forward_causal_kernel(
         # Implement the logic for the off-diagonal blocks.
         # This is very similar to the non-causal version from Problem 3.
         # 1. Load the K and V blocks for the current iteration.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
+
         # 2. Compute the attention scores (S_ij).
+        S_ij = tl.dot(q_block, k_block)
+        # No causal mask is needed here: these off-diagonal blocks lie strictly below the diagonal, so every key index is <= every query index in this tile.
+        S_ij *= qk_scale
         # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
         # --- END OF STUDENT IMPLEMENTATION ---
@@ -64,6 +84,29 @@
     for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
         # --- STUDENT IMPLEMENTATION REQUIRED HERE ---
         # Implement the logic for the diagonal blocks, apply the causal mask to S_ij.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
+
+        # 2. Compute the attention scores (S_ij).
+        S_ij = tl.dot(q_block, k_block)
+        mask = k_offsets[None, :] <= q_offsets[:, None]  # causal mask: query i may only attend to keys j with j <= i
+        S_ij = tl.where(mask, S_ij, -float('inf'))
+        S_ij *= qk_scale
+        # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
         # --- END OF STUDENT IMPLEMENTATION ---
diff --git a/problem_5.py b/problem_5.py
index aad8fe1..f7e1aa2 100644
--- a/problem_5.py
+++ b/problem_5.py
@@ -35,8 +35,8 @@ def _flash_attention_forward_gqa_kernel(
     # Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx).
     # 1. Calculate how many query heads are in each group.
     # 2. Use integer division to find the correct kv_head_idx.
-
-    kv_head_idx = 0 # Placeholder: Replace with your calculation
+    group_size = N_Q_HEADS // N_KV_HEADS
+    kv_head_idx = q_head_idx // group_size
 
     # --- END OF STUDENT IMPLEMENTATION ---
 
@@ -57,8 +57,28 @@ def _flash_attention_forward_gqa_kernel(
     for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N):
         # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 2) ---
         # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
+
         # 2. Reuse your working implementation for the online softmax update
-        # from your solution to Problem 4.
+        S_ij = tl.dot(q_block, k_block)
+        # No causal mask is needed for these off-diagonal blocks: every key index is strictly less than every query index in this tile.
+        S_ij *= qk_scale
+        # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
         # --- END OF STUDENT IMPLEMENTATION ---
@@ -67,8 +87,28 @@ def _flash_attention_forward_gqa_kernel(
     for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
         # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 3) ---
         # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
         # 2. Reuse your working implementation for the masked online softmax
-        # update from your solution to Problem 4.
+        S_ij = tl.dot(q_block, k_block)
+        mask = k_offsets[None, :] <= q_offsets[:, None]  # causal mask: query i may only attend to keys j with j <= i
+        S_ij = tl.where(mask, S_ij, -float('inf'))
+        S_ij *= qk_scale
+        # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
         # --- END OF STUDENT IMPLEMENTATION ---
diff --git a/problem_6.py b/problem_6.py
index f097706..5db749f 100644
--- a/problem_6.py
+++ b/problem_6.py
@@ -35,8 +35,8 @@ def _flash_attention_forward_swa_kernel(
     # This problem combines GQA and SWA. First, implement the GQA logic.
     # 1. Calculate the number of query heads per group.
    # 2. Determine the correct kv_head_idx for the current q_head_idx.
-
-    kv_head_idx = 0 # Placeholder: Replace with your GQA calculation
+    group_size = N_Q_HEADS // N_KV_HEADS
+    kv_head_idx = q_head_idx // group_size
 
     # --- END OF GQA IMPLEMENTATION ---
 
@@ -45,8 +45,8 @@ def _flash_attention_forward_swa_kernel(
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
 
     # 3. Load query block
-    q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M))
+    q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M))  # scale the block index by BLOCK_M to get absolute token offsets
     q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \
              (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :])
     q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)
@@ -58,20 +58,74 @@
     # The kernel should only attend to the `WINDOW_SIZE` most recent key/value tokens.
     # 1. Calculate the starting position of the attention window (window_start).
     # 2. Modify the range of the Phase 1 loop to start from your window_start.
-
-    window_start = 0 # Placeholder: Replace with your SWA calculation
+    diag_start = q_block_idx * BLOCK_M
+    window_start = tl.maximum(0, diag_start - WINDOW_SIZE + 1)  # earliest key any query in this block may attend to
 
     # --- Phase 1: Off-Diagonal Blocks (within the window) ---
     for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N):
         # STUDENT IMPLEMENTATION REQUIRED (Part 3: SWA Logic)
         # Hint: You might need to apply the per-element sliding window mask to s_ij.
         # - A score is invalid if `(query_offset - key_offset) >= WINDOW_SIZE`.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
+
+        # 2. Reuse your working implementation for the online softmax update
+        S_ij = tl.dot(q_block, k_block)
+        # No separate causal mask is needed: the window condition delta >= 0 already enforces k <= q.
+        # Apply the sliding window mask:
+        # delta = i - j; a score is valid iff 0 <= delta < WINDOW_SIZE
+        delta = q_offsets[:, None] - k_offsets[None, :]
+        sw_mask = (delta >= 0) & (delta < WINDOW_SIZE)
+        S_ij = tl.where(sw_mask, S_ij, -10000.0)  # a large negative value (not -inf) so fully masked rows cannot produce NaNs
+
+        S_ij *= qk_scale
+        # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
     # --- Phase 2: Diagonal Blocks ---
     diag_start = q_block_idx * BLOCK_M
     for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
         # STUDENT IMPLEMENTATION REQUIRED
+        # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        v_offsets = start_n + tl.arange(0, BLOCK_N)
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=v_offsets[:, None] < SEQ_LEN, other=0.0)
+        # 2. Reuse your working implementation for the masked online softmax
+        S_ij = tl.dot(q_block, k_block)
+        # Apply the sliding window mask; delta >= 0 also enforces causality on the diagonal block.
+        delta = q_offsets[:, None] - k_offsets[None, :]
+        sw_mask = (delta >= 0) & (delta < WINDOW_SIZE)
+        S_ij = tl.where(sw_mask, S_ij, -10000.0)
+
+        S_ij *= qk_scale
+        # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
+        m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
+        scale_factor = tl.exp2(m_i - m_ij)
+        P_ij = tl.exp2(S_ij - m_ij[:, None])
+        l_i = l_i * scale_factor + tl.sum(P_ij, 1)
+        acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)
+
+        m_i = m_ij
         pass
 
     # --- END OF SWA IMPLEMENTATION ---
diff --git a/problem_7.py b/problem_7.py
index af1558e..afaee55 100644
--- a/problem_7.py
+++ b/problem_7.py
@@ -49,15 +49,161 @@ def _flash_attention_forward_swa_kernel(
     q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)
 
     qk_scale = softmax_scale * 1.44269504
+    q_block = q_block.to(tl.float32)
+    q_start = q_block_idx * BLOCK_M
+    win_left = q_start - (WINDOW_SIZE - 1)
+    window_start = tl.maximum(0, win_left)
+
+    diag_start = q_block_idx * BLOCK_M
+
+    # Phase 0: Attention sink only
+    for start_n in range(0, SINK_SIZE, BLOCK_N):
+        # Load K
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        # Load V
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
+
+        # 2. Compute the attention scores (S_ij).
+        k_block = k_block.to(tl.float32)
+        s_ij = tl.dot(q_block, k_block)
+        s_ij *= qk_scale
+        v_block = v_block.to(tl.float32)
+
+        # Masks
+        sink_cols = (k_offsets[None, :] < SINK_SIZE)
+        causal = (q_offsets[:, None] >= k_offsets[None, :])
+        valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN)
+        mask = sink_cols & causal & valid
+
+        s_ij = tl.where(mask, s_ij, -float('inf'))
+
+        # Row has anything valid in this tile?
+        row_has = tl.max(mask, axis=1) > 0
+
+        # Online softmax update
+        m_ij = tl.max(s_ij, axis=1)
+        # Only update rows that have something valid
+        m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i)
+        # Calculate scale factor only for rows that have something valid
+        scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0)
+
+        # Probabilities only for rows that have something valid
+        p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0)
+
+        acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block)
+        l_i = l_i * scale_factor + tl.sum(p_ij, axis=1)
+        m_i = m_new
+
+    # Phase 1: Off-Diagonal Blocks (within the window), excluding sinks
+    for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N):
+        # Load K
+        k_offsets = start_n + tl.arange(0, BLOCK_N)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        # Load V
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
+
+        # 2. Compute the attention scores (S_ij).
+        k_block = k_block.to(tl.float32)
+        s_ij = tl.dot(q_block, k_block)
+        s_ij *= qk_scale
+        v_block = v_block.to(tl.float32)
+
+        # EXCLUDE sinks (already handled in Phase 0)
+        non_sink = k_offsets[None, :] >= SINK_SIZE
+
+        # Sliding window mask
+        dist = q_offsets[:, None] - k_offsets[None, :]  # (BLOCK_M, BLOCK_N)
+        window_mask = (dist >= 0) & (dist < WINDOW_SIZE)
+
+        # Validity mask
+        valid_mask = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN)
+
+        # Prevent overlap with diagonal tile:
+        pre_diag_mask = k_offsets[None, :] < diag_start
+
+        # Combine masks
+        mask = window_mask & valid_mask & pre_diag_mask & non_sink
+        s_ij = tl.where(mask, s_ij, -float('inf'))
+
+        # Row has anything valid in this tile?
+        row_has = tl.max(mask, axis=1) > 0
+
+        # online softmax update
+        m_ij = tl.max(s_ij, axis=1)
+        # Only update rows that have something valid
+        m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i)
+        # Calculate scale factor only for rows that have something valid
+        scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0)
+
+        # Probabilities only for rows that have something valid
+        p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0)
+
+        acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block)
+        l_i = l_i * scale_factor + tl.sum(p_ij, axis=1)
+        m_i = m_new
+
+    # Phase 2: Diagonal Blocks
+    diag_start = q_block_idx * BLOCK_M
+    for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
+        # Load K
+        k_offsets = start_n + tl.arange(0, BLOCK_N)  # (BLOCK_N,)
+        k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
+                 (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
+        k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)
+
+        # Load V
+        v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
+                 (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
+        v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)
+
+        # 2. Compute the attention scores (S_ij).
+        k_block = k_block.to(tl.float32)
+        s_ij = tl.dot(q_block, k_block)
+        s_ij *= qk_scale
+        v_block = v_block.to(tl.float32)
+
+        # NON sink mask
+        non_sink = k_offsets[None, :] >= SINK_SIZE
+
+        # Sliding window mask
+        dist = q_offsets[:, None] - k_offsets[None, :]  # (BLOCK_M, BLOCK_N)
+        window_mask = (dist >= 0) & (dist < WINDOW_SIZE)
+
+        # Combine masks
+        causal = q_offsets[:, None] >= k_offsets[None, :]  # lower triangle true
+        valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN)
+        mask = causal & valid & window_mask & non_sink
+
+        # Apply mask BEFORE tile max so future tokens don't affect m_i
+        s_ij = tl.where(mask, s_ij, -float("inf"))
+
+        # Row has anything valid in this tile?
+        row_has = tl.max(mask, axis=1) > 0
+
+        # online softmax update
+        m_ij = tl.max(s_ij, axis=1)
+        # Only update rows that have something valid
+        m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i)
+        # Calculate scale factor only for rows that have something valid
+        scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0)
+
+        # Probabilities only for rows that have something valid
+        p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0)
-    # --- STUDENT IMPLEMENTATION REQUIRED HERE ---
-    # Combine the GQA, SWA, and Sink logic.
-    # Combine all code from previous problems, and add the sink logic.
-    # You should have 3 phases:
-    # 1. Phase 0: Sink blocks that are before the sliding window
-    # 2. Phase 1: Off-Diagonal Blocks (within the window)
-    # 3. Phase 2: Diagonal Blocks
-    pass
+        acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block)
+        l_i = l_i * scale_factor + tl.sum(p_ij, axis=1)
+        m_i = m_new
 
     # --- END OF STUDENT IMPLEMENTATION ---
 
     # 4. Normalize and write the final output block.
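
As a quick way to exercise problem_1 after applying this diff, here is a minimal sanity-check sketch. It is an assumption-laden sketch, not part of the patch: it assumes problem_1.py is importable from the working directory and that FlashAttention2Function.forward returns just the attention output O of shape (B, H, N, D); the shapes and seed below are illustrative only.

import torch
from problem_1 import FlashAttention2Function  # assumed import path

def naive_attention(Q, K, V, is_causal=False):
    # Reference O(N^2) attention in float32, for comparison only.
    scale = Q.shape[-1] ** -0.5
    S = (Q.float() @ K.float().transpose(-1, -2)) * scale
    if is_causal:
        n_q, n_k = S.shape[-2], S.shape[-1]
        q_idx = torch.arange(n_q, device=S.device).unsqueeze(1)
        k_idx = torch.arange(n_k, device=S.device).unsqueeze(0)
        S = S.masked_fill(k_idx > q_idx, -float('inf'))  # same k <= q rule as the tiled kernel
    return torch.softmax(S, dim=-1) @ V.float()

if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, H, N, D = 2, 4, 256, 64
    Q, K, V = (torch.randn(B, H, N, D, device=device) for _ in range(3))
    out = FlashAttention2Function.apply(Q, K, V, True)  # assumed to return O
    ref = naive_attention(Q, K, V, is_causal=True)
    print("max abs diff vs naive attention:", (out.float() - ref).abs().max().item())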