From ea7c03dbbd8c2151c54efd0f5cad3701d1a8bd98 Mon Sep 17 00:00:00 2001 From: Dhruv Date: Sun, 14 Sep 2025 17:54:01 +0530 Subject: [PATCH 1/2] all problems run --- problem_1.py | 42 +++-- problem_2.py | 24 +-- problem_3.py | 18 +- problem_4.py | 79 +++++++- problem_5.py | 91 +++++++++- problem_6.py | 106 ++++++++++- problem_7.py | 161 +++++++++++++++- problem_8.py | 504 +++++++++++++++++++++++++++++++++++++++++++++++++-- problem_9.py | 476 +++++++++++++++++++++++++++++++++++++++++++++--- 9 files changed, 1411 insertions(+), 90 deletions(-) diff --git a/problem_1.py b/problem_1.py index 683be1c..437499b 100644 --- a/problem_1.py +++ b/problem_1.py @@ -22,7 +22,7 @@ def forward(ctx, Q, K, V, is_causal=False): N_K_tiles = math.ceil(N_K / K_TILE_SIZE) # Initialize final output tensors - O_final = torch.zeros_like(Q, dtype=Q.dtype) + O_final = torch.zeros_like(Q, dtype=Q.dtype).to(torch.float32) L_final = torch.zeros((B, H, N_Q), device=Q.device, dtype=torch.float32) scale = 1.0 / math.sqrt(D_H) @@ -38,10 +38,10 @@ def forward(ctx, Q, K, V, is_causal=False): for i in range(N_Q_tiles): q_start = i * Q_TILE_SIZE q_end = min((i + 1) * Q_TILE_SIZE, N_Q) - Q_tile = Q_bh[q_start:q_end, :] + Q_tile = Q_bh[q_start:q_end, :].to(torch.float32) # Initialize accumulators for this query tile - o_i = torch.zeros_like(Q_tile, dtype=Q.dtype) + o_i = torch.zeros_like(Q_tile, dtype=Q.dtype).to(torch.float32) l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32) m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32) @@ -50,24 +50,38 @@ def forward(ctx, Q, K, V, is_causal=False): k_start = j * K_TILE_SIZE k_end = min((j + 1) * K_TILE_SIZE, N_K) - K_tile = K_bh[k_start:k_end, :] - V_tile = V_bh[k_start:k_end, :] + K_tile = K_bh[k_start:k_end, :].to(torch.float32) + V_tile = V_bh[k_start:k_end, :].to(torch.float32) S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale - - # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # print("S_ij.shape", S_ij.shape) # 1. Apply causal masking if is_causal is True. - # + if is_causal: + # Build a causal mask for the current q/k tile based on absolute positions. + # Mask positions where key_index > query_index (i.e., future positions). + q_len = q_end - q_start + k_len = k_end - k_start + q_positions = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1) # (q_len, 1) + k_positions = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0) # (1, k_len) + causal_mask = k_positions > q_positions # (q_len, k_len) True for disallowed attends + S_ij = S_ij.masked_fill(causal_mask, float('-inf')) + # 2. Compute the new running maximum - # + m_new = torch.max(m_i, torch.max(S_ij, dim=1).values) + # 3. Rescale the previous accumulators (o_i, l_i) - # + l_i_rescaled = (l_i*(torch.exp(m_i-m_new))) + mult = torch.exp(m_i - m_new) # shape: (q_len,) + o_i_rescaled = o_i * mult.unsqueeze(-1) # 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new). - # + # P_tilde_ij = torch.exp(S_ij - m_new).to(torch.float32) + P_tilde_ij = torch.exp(S_ij - m_new.unsqueeze(1)).to(torch.float32) # 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i - # + l_new = l_i_rescaled + (P_tilde_ij.sum(dim=1)) + o_i = o_i_rescaled + (P_tilde_ij @ V_tile) # 6. 
Update the running max for the next iteration - + m_i = m_new + l_i = l_new # --- END OF STUDENT IMPLEMENTATION --- # After iterating through all key tiles, normalize the output @@ -82,7 +96,7 @@ def forward(ctx, Q, K, V, is_causal=False): L_final[b, h, q_start:q_end] = L_tile O_final = O_final.to(Q.dtype) - + # adf ctx.save_for_backward(Q, K, V, O_final, L_final) ctx.is_causal = is_causal diff --git a/problem_2.py b/problem_2.py index 011f3eb..da1a6f7 100644 --- a/problem_2.py +++ b/problem_2.py @@ -16,47 +16,47 @@ def weighted_row_sum_kernel( """ # 1. Get the row index for the current program instance. # Hint: Use tl.program_id(axis=0). - row_idx = ... + row_idx = tl.program_id(axis=0) # 2. Create a pointer to the start of the current row in the input tensor X. # Hint: The offset depends on the row index and the number of columns (N_COLS). - row_start_ptr = ... + row_start_ptr = X_ptr + row_idx * N_COLS # 3. Create a pointer for the output vector Y. - output_ptr = ... + output_ptr = Y_ptr + row_idx # 4. Initialize an accumulator for the sum of the products for a block. # This should be a block-sized tensor of zeros. # Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32. - accumulator = ... + accumulator = tl.zeros([BLOCK_SIZE], dtype=tl.float32) # 5. Iterate over the columns of the row in blocks of BLOCK_SIZE. # Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE). - for col_block_start in range(0, ...): + for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)): # - Calculate the offsets for the current block of columns. # Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE). - col_offsets = ... + col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # - Create a mask to prevent out-of-bounds memory access for the last block. # Hint: Compare col_offsets with N_COLS. - mask = ... + mask = col_offsets < N_COLS # - Load a block of data from X and W safely using the mask. # Hint: Use tl.load with the appropriate pointers, offsets, and mask. # Use `other=0.0` to handle out-of-bounds elements. - x_chunk = tl.load(...) - w_chunk = tl.load(...) + x_chunk = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0) + w_chunk = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) # - Compute the element-wise product and add it to the accumulator. - accumulator += ... + accumulator += x_chunk * w_chunk # 6. Reduce the block-sized accumulator to a single scalar value after the loop. # Hint: Use tl.sum(). - final_sum = ... + final_sum = tl.sum(accumulator) # 7. Store the final accumulated sum to the output tensor Y. # Hint: Use tl.store(). - ... + tl.store(output_ptr, final_sum) # --- END OF STUDENT IMPLEMENTATION --- diff --git a/problem_3.py b/problem_3.py index 43f134a..66fc5df 100644 --- a/problem_3.py +++ b/problem_3.py @@ -62,14 +62,28 @@ def _flash_attention_forward_kernel( v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) # --- STUDENT IMPLEMENTATION REQUIRED HERE --- - # Implement the online softmax update logic. + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). 
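+        # NOTE: m_i and m_new are tracked in the base-2 domain (the template folds
+        # log2(e) into the score scale, as in the later problems), so exp2 is the
+        # matching rescale factor for both l_i and acc.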
+ l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32)) # 5. Update the denominator `l_i`. + l_i = l_new # 6. Update the running maximum `m_i` for the next iteration. - pass + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- diff --git a/problem_4.py b/problem_4.py index 4faf01b..d1d05dd 100644 --- a/problem_4.py +++ b/problem_4.py @@ -54,7 +54,44 @@ def _flash_attention_forward_causal_kernel( # 1. Load the K and V blocks for the current iteration. # 2. Compute the attention scores (S_ij). # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc). - pass + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32)) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- # --- END OF STUDENT IMPLEMENTATION --- @@ -64,7 +101,45 @@ def _flash_attention_forward_causal_kernel( for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): # --- STUDENT IMPLEMENTATION REQUIRED HERE --- # Implement the logic for the diagonal blocks, apply the causal mask to S_ij. 
- pass + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + mask = k_offsets[None, :] <= q_offsets[:, None] + s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32)) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- # --- END OF STUDENT IMPLEMENTATION --- diff --git a/problem_5.py b/problem_5.py index aad8fe1..7481ae3 100644 --- a/problem_5.py +++ b/problem_5.py @@ -34,11 +34,11 @@ def _flash_attention_forward_gqa_kernel( # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 1) --- # Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx). # 1. Calculate how many query heads are in each group. + n_q_head_per_group = N_Q_HEADS // N_KV_HEADS # 2. Use integer division to find the correct kv_head_idx. - - kv_head_idx = 0 # Placeholder: Replace with your calculation + kv_head_idx = q_head_idx // n_q_head_per_group # Placeholder: Replace with your calculation # --- END OF STUDENT IMPLEMENTATION --- - + #(dhruv) this solution passes all the testa cses but ideally it ahould have edge case handling for when nq_heads is not completely divisible by nkv_heads # 2. Initialize accumulators in SRAM. m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) @@ -59,7 +59,50 @@ def _flash_attention_forward_gqa_kernel( # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. # 2. Reuse your working implementation for the online softmax update # from your solution to Problem 4. - pass + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the logic for the off-diagonal blocks. + # This is very similar to the non-causal version from Problem 3. + # 1. Load the K and V blocks for the current iteration. + # 2. Compute the attention scores (S_ij). + # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc). 
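+        # NOTE: the only GQA-specific change in this loop is the pointer math below,
+        # which indexes K and V with the shared kv_head_idx instead of q_head_idx;
+        # the online-softmax update itself is identical to Problem 4.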
+ k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32)) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- # --- END OF STUDENT IMPLEMENTATION --- # --- Phase 2: Diagonal Blocks --- @@ -69,7 +112,45 @@ def _flash_attention_forward_gqa_kernel( # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. # 2. Reuse your working implementation for the masked online softmax # update from your solution to Problem 4. - pass + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + mask = k_offsets[None, :] <= q_offsets[:, None] + s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. 
Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32)) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- # --- END OF STUDENT IMPLEMENTATION --- # 4. Normalize and write the final output block. diff --git a/problem_6.py b/problem_6.py index f097706..84809b6 100644 --- a/problem_6.py +++ b/problem_6.py @@ -35,8 +35,9 @@ def _flash_attention_forward_swa_kernel( # This problem combines GQA and SWA. First, implement the GQA logic. # 1. Calculate the number of query heads per group. # 2. Determine the correct kv_head_idx for the current q_head_idx. - - kv_head_idx = 0 # Placeholder: Replace with your GQA calculation + n_q_head_per_group = N_Q_HEADS // N_KV_HEADS + # 2. Use integer division to find the correct kv_head_idx. + kv_head_idx = q_head_idx // n_q_head_per_group # --- END OF GQA IMPLEMENTATION --- @@ -50,6 +51,7 @@ def _flash_attention_forward_swa_kernel( q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) + q_block = tl.cast(q_block, tl.float32) qk_scale = softmax_scale * 1.44269504 @@ -59,21 +61,111 @@ def _flash_attention_forward_swa_kernel( # 1. Calculate the starting position of the attention window (window_start). # 2. Modify the range of the Phase 1 loop to start from your window_start. - window_start = 0 # Placeholder: Replace with your SWA calculation + window_start = tl.max(0, q_block_idx * BLOCK_M - (WINDOW_SIZE )) # --- Phase 1: Off-Diagonal Blocks (within the window) --- for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): # STUDENT IMPLEMENTATION REQUIRED (Part 3: SWA Logic) # Hint: You might need to apply the per-element sliding window mask to s_ij. # - A score is invalid if `(query_offset - key_offset) >= WINDOW_SIZE`. - pass + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + valid = (q_offsets[:, None] - k_offsets[None, :]) <= (WINDOW_SIZE - 1) + valid = valid & (k_offsets[None, :] < SEQ_LEN) + s_ij = tl.where(valid, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. 
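+        # NOTE: if every score in this tile is masked out, m_new stays -inf and
+        # exp2(m_i - m_new) would evaluate to NaN; forcing alpha to 1.0 leaves the
+        # still-empty accumulators unchanged instead.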
+ no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- # --- Phase 2: Diagonal Blocks --- diag_start = q_block_idx * BLOCK_M for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): - # STUDENT IMPLEMENTATION REQUIRED - pass - # --- END OF SWA IMPLEMENTATION --- + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + # Causal + sliding window mask within the diagonal block + valid = (k_offsets[None, :] <= q_offsets[:, None]) & \ + ((q_offsets[:, None] - k_offsets[None, :]) <= (WINDOW_SIZE - 1)) & \ + (k_offsets[None, :] < SEQ_LEN) + s_ij = tl.where(valid, s_ij, -float('inf')) + # mask = q_offsets[:, None]-k_offsets[None, :] >= WINDOW_SIZE + # s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- # 4. Normalize and write the final output block. 
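Reviewer note (not part of the patch): a quick way to sanity-check the GQA + sliding-window kernel above is to compare it against a dense PyTorch reference on small shapes. The sketch below is a minimal reference, assuming the same mask convention as the kernel (causal plus `(i - j) < WINDOW_SIZE`) and the head mapping `q_head_idx // (H_q // H_kv)`; the helper name `swa_gqa_reference` and its signature are illustrative only.

import math
import torch

def swa_gqa_reference(q, k, v, window_size, softmax_scale=None):
    """Dense causal sliding-window attention with GQA, for small-shape checks."""
    B, Hq, S, D = q.shape
    Hkv = k.shape[1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(D)
    groups = Hq // Hkv
    # Expand each shared K/V head to its group of query heads (mirrors q_head_idx // groups).
    k = k.repeat_interleave(groups, dim=1).float()
    v = v.repeat_interleave(groups, dim=1).float()
    scores = torch.matmul(q.float(), k.transpose(-1, -2)) * scale
    idx = torch.arange(S, device=q.device)
    causal = idx[None, :] <= idx[:, None]                    # key j <= query i
    in_window = (idx[:, None] - idx[None, :]) < window_size  # the WINDOW_SIZE most recent keys
    scores = scores.masked_fill(~(causal & in_window), float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v).to(q.dtype)

Comparing this against the kernel output with torch.testing.assert_close and loose fp16 tolerances is usually enough to catch masking or GQA-indexing mistakes.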
diff --git a/problem_7.py b/problem_7.py index af1558e..1712540 100644 --- a/problem_7.py +++ b/problem_7.py @@ -47,7 +47,7 @@ def _flash_attention_forward_swa_kernel( q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) - + q_block = tl.cast(q_block, tl.float32) qk_scale = softmax_scale * 1.44269504 # --- STUDENT IMPLEMENTATION REQUIRED HERE --- @@ -57,8 +57,163 @@ def _flash_attention_forward_swa_kernel( # 1. Phase 0: Sink blocks that are before the sliding window # 2. Phase 1: Off-Diagonal Blocks (within the window) # 3. Phase 2: Diagonal Blocks - pass - # --- END OF STUDENT IMPLEMENTATION --- + window_start = max(0, q_block_idx * BLOCK_M - WINDOW_SIZE) + # --- Phase 0: Sink blocks that are before the sliding window --- + sink_end = min(SINK_SIZE, window_start) + for start_n in range(0, sink_end, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + # Attend to sink tokens only (indices < SINK_SIZE), while preserving causality j <= i + valid = (k_offsets[None, :] < SINK_SIZE) + s_ij = tl.where(valid, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- STUDENT IMPLEMENTATION REQUIRED (Part 2: SWA Logic) --- + # Now, implement the "sliding window" by changing the loop bounds. + # The kernel should only attend to the `WINDOW_SIZE` most recent key/value tokens. + # 1. Calculate the starting position of the attention window (window_start). + # 2. Modify the range of the Phase 1 loop to start from your window_start. + + + + # --- Phase 1: Off-Diagonal Blocks (within the window) --- + for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): + # STUDENT IMPLEMENTATION REQUIRED (Part 3: SWA Logic) + # Hint: You might need to apply the per-element sliding window mask to s_ij. 
+ # - A score is invalid if `(query_offset - key_offset) >= WINDOW_SIZE`. + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + valid = (q_offsets[:, None] - k_offsets[None, :]) < (WINDOW_SIZE) + valid = valid | (k_offsets[None, :] < SINK_SIZE) + s_ij = tl.where(valid, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + # Causal + sliding window mask within the diagonal block + valid = (k_offsets[None, :] <= q_offsets[:, None]) & \ + ((q_offsets[:, None] - k_offsets[None, :]) <(WINDOW_SIZE )) & \ + (k_offsets[None, :] < SEQ_LEN) + valid = valid | ((k_offsets[None, :] < SINK_SIZE) & (k_offsets[None, :] <= q_offsets[:, None]) &(k_offsets[None, :] < SEQ_LEN)) + s_ij = tl.where(valid, s_ij, -float('inf')) + # mask = q_offsets[:, None]-k_offsets[None, :] >= WINDOW_SIZE + # s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). 
+ # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- # 4. Normalize and write the final output block. l_i_safe = tl.where(l_i == 0, 1.0, l_i) diff --git a/problem_8.py b/problem_8.py index 8e63578..6150d0c 100644 --- a/problem_8.py +++ b/problem_8.py @@ -4,11 +4,364 @@ import triton.language as tl import math from typing import Optional +@triton.jit +def _flash_attention_forward_gqa_kernel( + # Pointers to Tensors + Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, L_ptr, + # Stride information for tensors + q_stride_b, q_stride_h, q_stride_s, + k_stride_b, k_stride_h, k_stride_s, + v_stride_b, v_stride_h, v_stride_s, + m_stride_b, m_stride_h, m_stride_s, + l_stride_b, l_stride_h, l_stride_s, + # Kernel parameters + softmax_scale, + SEQ_LEN, + N_Q_HEADS, + N_KV_HEADS, + # Constexpr tile sizes + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel template for the forward pass of causal FlashAttention with GQA. + """ + # 1. Identify the block of queries and the batch/head to be processed. + q_block_idx = tl.program_id(axis=0) + batch_head_idx = tl.program_id(axis=1) + + batch_idx = batch_head_idx // N_Q_HEADS + q_head_idx = batch_head_idx % N_Q_HEADS + + # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 1) --- + # Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx). + # 1. Calculate how many query heads are in each group. + n_q_head_per_group = N_Q_HEADS // N_KV_HEADS + # 2. Use integer division to find the correct kv_head_idx and clamp to valid range. + kv_head_idx = q_head_idx // n_q_head_per_group + kv_head_idx = tl.minimum(kv_head_idx, N_KV_HEADS - 1) + # --- END OF STUDENT IMPLEMENTATION --- + #(dhruv) this solution passes all the testa cses but ideally it ahould have edge case handling for when nq_heads is not completely divisible by nkv_heads + + # 2. Initialize accumulators in SRAM. + m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # 3. Load the block of queries (Q_i). 
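+    # NOTE: Q_i is loaded once per program, cast to fp32, and reused unchanged for
+    # every K/V tile in the two phases below.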
+ q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) + q_block = tl.cast(q_block, tl.float32) + + qk_scale = softmax_scale * 1.44269504 + + # --- Phase 1: Off-Diagonal Blocks --- + for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N): + # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 2) --- + # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. + # 2. Reuse your working implementation for the online softmax update + # from your solution to Problem 4. + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the logic for the off-diagonal blocks. + # This is very similar to the non-causal version from Problem 3. + # 1. Load the K and V blocks for the current iteration. + # 2. Compute the attention scores (S_ij). + # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc). + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block, allow_tf32=False) + s_ij *= qk_scale + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32), allow_tf32=False) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 3) --- + # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. + # 2. Reuse your working implementation for the masked online softmax + # update from your solution to Problem 4. 
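+        # NOTE: only the diagonal tiles need an explicit causal mask; every tile
+        # visited in Phase 1 lies strictly below the diagonal, and tiles above it
+        # are never iterated over.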
+ k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block, allow_tf32=False) + s_ij *= qk_scale + mask = k_offsets[None, :] <= q_offsets[:, None] + s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_new = tl.maximum(m_i, tl.max(s_ij, axis=1)) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + l_i_rescaled = (l_i*(tl.exp2(m_i-m_new))) + mult = tl.exp2(m_i - m_new) # shape: (q_len,) + acc_rescaled = acc * mult[:, None] + # 3. Compute the attention probabilities for the current tile (`p_ij`). + P_tilde_ij = tl.exp2(s_ij - m_new[:, None]) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij ,tl.cast(v_block, tl.float32), allow_tf32=False) + # 5. Update the denominator `l_i`. + l_i = l_new + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- + + # 4. Normalize and write the final output block. 
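+    # NOTE: the clamp below is a cheap guard against a zero denominator for rows that
+    # never accumulate probability mass; m_i and l_i are also stored per row so the
+    # backward kernel can rebuild P = exp2(S - m) / l without re-reducing.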
+ l_i_safe = tl.maximum(l_i[:, None], 1e-12) + acc = acc / l_i_safe + + o_ptrs = O_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=q_offsets[:, None] < SEQ_LEN) + temp = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) + # Use (B,H,S) strides when storing per-row stats + m_ptrs = M_ptr + batch_idx * m_stride_b + q_head_idx * m_stride_h + temp * m_stride_s + tl.store(m_ptrs, m_i.to(M_ptr.dtype.element_ty), mask=(temp < SEQ_LEN)) + l_ptrs = L_ptr + batch_idx * l_stride_b + q_head_idx * l_stride_h + temp * l_stride_s + tl.store(l_ptrs, l_i.to(L_ptr.dtype.element_ty), mask=(temp < SEQ_LEN)) + + + +@triton.jit +def _flash_attention_backward_gqa_kernel( + # Inputs + Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, L_ptr, dO_ptr, + # Outputs (grads) + dQ_ptr, dK_ptr, dV_ptr, + # Strides for Q/K/V/O (assumed same layout for Q, O, dO, dQ) + q_stride_b, q_stride_h, q_stride_s, + k_stride_b, k_stride_h, k_stride_s, + v_stride_b, v_stride_h, v_stride_s, + # Strides for M/L (B,H,S) + m_stride_b, m_stride_h, m_stride_s, + l_stride_b, l_stride_h, l_stride_s, + # Kernel parameters + softmax_scale, + SEQ_LEN, + N_Q_HEADS, + N_KV_HEADS, + # Constexpr tile sizes + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + LOG2E = 1.4426950408889634 + qk_scale2 = softmax_scale * LOG2E + + # Program ids + q_block_idx = tl.program_id(axis=0) + bh_idx = tl.program_id(axis=1) + batch_idx = bh_idx // N_Q_HEADS + q_head_idx = bh_idx % N_Q_HEADS + + # GQA mapping: map q_head to kv_head + num_groups = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // num_groups + + # Offsets and pointers for this query block + q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) + hd = tl.arange(0, HEAD_DIM) + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + o_ptrs = O_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + do_ptrs = dO_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + + q_mask = q_offsets[:, None] < SEQ_LEN + q_block = tl.load(q_ptrs, mask=q_mask, other=0.0) + o_block = tl.load(o_ptrs, mask=q_mask, other=0.0) + do_block = tl.load(do_ptrs, mask=q_mask, other=0.0) + q_block = tl.cast(q_block, tl.float32) + o_block = tl.cast(o_block, tl.float32) + do_block = tl.cast(do_block, tl.float32) + # Load stored per-row softmax stats + m_ptrs = M_ptr + batch_idx * m_stride_b + q_head_idx * m_stride_h + q_offsets * m_stride_s + l_ptrs = L_ptr + batch_idx * l_stride_b + q_head_idx * l_stride_h + q_offsets * l_stride_s + m_i = tl.load(m_ptrs, mask=q_offsets < SEQ_LEN, other=0.0) + l_i = tl.load(l_ptrs, mask=q_offsets < SEQ_LEN, other=1.0) + l_i = tl.maximum(l_i, 1e-12) + + # delta = sum(dO * O) per row + delta = tl.sum(do_block * o_block, axis=1) # (BLOCK_M) + + # Accumulator for dQ + dQ_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # --- Phase 1: Off-Diagonal Blocks --- + for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_mask = k_offsets < SEQ_LEN + + # Load K and V tiles in both orientations as needed + k_cols_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_rows_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] 
* k_stride_s + hd[None, :]) + v_rows_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + v_cols_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[None, :] * v_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + k_rows = tl.load(k_rows_ptrs, mask=k_mask[:, None], other=0.0) # (N, D) + v_rows = tl.load(v_rows_ptrs, mask=k_mask[:, None], other=0.0) # (N, D) + v_cols = tl.load(v_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + # Cast tiles to fp32 for matmuls + k_cols = tl.cast(k_cols, tl.float32) + k_rows = tl.cast(k_rows, tl.float32) + v_rows = tl.cast(v_rows, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + + # Scores in base-2 domain and probabilities + q_f32 = tl.cast(q_block, tl.float32) + k_cols = tl.cast(k_cols, tl.float32) + k_rows = tl.cast(k_rows, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + v_rows = tl.cast(v_rows, tl.float32) + s2 = tl.dot(q_f32, k_cols, allow_tf32=False) * qk_scale2 # (M, N) + s2 = tl.where((k_offsets[None, :] < SEQ_LEN), s2, -float('inf')) + p_tilde = tl.exp2(s2 - m_i[:, None]) + P = p_tilde / l_i[:, None] + + # dV partial via matmul: P^T @ dO => (N,D) + do_f32 = tl.cast(do_block, tl.float32) + dv_partial = tl.dot(tl.trans(P), do_f32, allow_tf32=False) + + # t_block = dO @ V^T using V in (D, N) + t_block = tl.dot(do_f32, v_cols, allow_tf32=False) + + # dS = P * (t_block - delta[:, None]) + dS = P * (t_block - delta[:, None]) + + # dQ += dS @ K_rows, scaled by softmax_scale + dQ_acc += tl.dot(dS, k_rows, allow_tf32=False) * softmax_scale + + # dK partial via matmul: dS^T @ Q, scaled by softmax_scale + q_f32 = tl.cast(q_block, tl.float32) + dk_partial = tl.dot(tl.trans(dS), q_f32, allow_tf32=False) * softmax_scale + + # Atomic add into global dV and dK + dv_ptrs = dV_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + dk_ptrs = dK_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + tl.atomic_add(dv_ptrs, dv_partial, mask=k_mask[:, None]) + tl.atomic_add(dk_ptrs, dk_partial, mask=k_mask[:, None]) + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_mask = k_offsets < SEQ_LEN + + k_cols_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_rows_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + v_rows_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + v_cols_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[None, :] * v_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=k_mask[None, :], other=0.0) + k_rows = tl.load(k_rows_ptrs, mask=k_mask[:, None], other=0.0) + v_rows = tl.load(v_rows_ptrs, mask=k_mask[:, None], other=0.0) + v_cols = tl.load(v_cols_ptrs, mask=k_mask[None, :], other=0.0) + k_cols = tl.cast(k_cols, tl.float32) + k_rows = tl.cast(k_rows, tl.float32) + v_rows = tl.cast(v_rows, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + + # Scores with causal mask inside diagonal tile + q_f32 = tl.cast(q_block, tl.float32) + k_cols = tl.cast(k_cols, 
tl.float32) + k_rows = tl.cast(k_rows, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + v_rows = tl.cast(v_rows, tl.float32) + + s2 = tl.dot(q_f32, k_cols, allow_tf32=False) * qk_scale2 + causal = k_offsets[None, :] <= q_offsets[:, None] + valid = causal & (k_offsets[None, :] < SEQ_LEN) + s2 = tl.where(valid, s2, -float('inf')) + p_tilde = tl.exp2(s2 - m_i[:, None]) + P = p_tilde / l_i[:, None] + + # dV partial P^T @ dO + do_f32 = tl.cast(do_block, tl.float32) + dv_partial = tl.dot(tl.trans(P), do_f32, allow_tf32=False) + + # t_block = dO @ V^T + t_block = tl.dot(do_f32, v_cols, allow_tf32=False) + t_block = tl.where(valid, t_block, 0.0) + + dS = P * (t_block - delta[:, None]) + dS = tl.where(valid, dS, 0.0) + + dQ_acc += tl.dot(dS, k_rows, allow_tf32=False) * softmax_scale + + q_f32 = tl.cast(q_block, tl.float32) + dk_partial = tl.dot(tl.trans(dS), q_f32, allow_tf32=False) * softmax_scale + + dv_ptrs = dV_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + dk_ptrs = dK_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + tl.atomic_add(dv_ptrs, dv_partial, mask=k_mask[:, None]) + tl.atomic_add(dk_ptrs, dk_partial, mask=k_mask[:, None]) + + # Write dQ for this block + dQ_ptrs = dQ_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + tl.store(dQ_ptrs, dQ_acc, mask=q_mask) + class FlashAttention2Function(torch.autograd.Function): """ Triton implementation of FlashAttention-2, supports causal attention and GQA. """ + + + @staticmethod def forward(ctx, q, k, v, is_causal=True, softmax_scale: Optional[float] = None): batch, n_heads, seq_len, head_dim = q.shape @@ -22,13 +375,28 @@ def forward(ctx, q, k, v, is_causal=True, softmax_scale: Optional[float] = None) o = torch.empty_like(q) M = torch.empty((batch, n_heads, seq_len), device=q.device, dtype=torch.float32) + L = torch.empty((batch, n_heads, seq_len), device=q.device, dtype=torch.float32) BLOCK_M, BLOCK_N = 128, 64 grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads) - - # TODO: Add your forward kernel here + n_q_heads = n_heads + _flash_attention_forward_gqa_kernel[grid]( + q, k, v, o, M, L, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + M.stride(0), M.stride(1), M.stride(2), + L.stride(0), L.stride(1), L.stride(2), + softmax_scale, + seq_len, + n_q_heads, + n_kv_heads, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) - ctx.save_for_backward(q, k, v, o, M) + ctx.save_for_backward(q, k, v, o, M, L) ctx.softmax_scale = softmax_scale ctx.num_heads = n_heads ctx.num_kv_heads = n_kv_heads @@ -36,24 +404,126 @@ def forward(ctx, q, k, v, is_causal=True, softmax_scale: Optional[float] = None) @staticmethod def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors + q, k, v, o, M, L = ctx.saved_tensors batch, n_heads, seq_len, head_dim = q.shape n_kv_heads = ctx.num_kv_heads + # Allocate fp32 gradient buffers for safe atomic_add in Triton + dq_f = torch.zeros_like(q, dtype=torch.float32) + dk_rep_f = torch.zeros_like(k, dtype=torch.float32) + dv_rep_f = torch.zeros_like(v, dtype=torch.float32) - dq = torch.empty_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - # [OPTIONAL BONUS] STUDENT IMPLEMENTATION REQUIRED - # Implement the Triton backward kernel for GQA from scratch. - # You should: - # 1. Precompute delta = sum(dO * O) - # 2. 
Recompute attention probabilities P = softmax(QK^T) - # 3. Use delta + dO to accumulate gradients for dq, dk, dv - # 4. Respect GQA mapping and causal mask - - return dq, dk.to(k.dtype), dv.to(v.dtype), None, None + # Implement backward using a numerically stable streaming algorithm with block-wise keys. + # This avoids allocating full (seq_len x seq_len) attention tensors. + softmax_scale = ctx.softmax_scale + BLOCK_M, BLOCK_N = 128, 64 + grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads) + n_q_heads = n_heads + _flash_attention_backward_gqa_kernel[grid]( + q, k, v, o, M, L, do, + dq_f, dk_rep_f, dv_rep_f, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + M.stride(0), M.stride(1), M.stride(2), + L.stride(0), L.stride(1), L.stride(2), + softmax_scale, + seq_len, + n_q_heads, + n_kv_heads, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + # Cast back to original dtypes + return dq_f.to(q.dtype), dk_rep_f.to(k.dtype), dv_rep_f.to(v.dtype), None, None + + @staticmethod + def backward2(ctx, do): + q, k, v, o, M, L = ctx.saved_tensors + batch, n_heads, seq_len, head_dim = q.shape + n_kv_heads = ctx.num_kv_heads + + # Implement backward using a numerically stable streaming algorithm with block-wise keys. + # This avoids allocating full (seq_len x seq_len) attention tensors. + scale = ctx.softmax_scale + + # Repeat K/V across groups to align with query heads (GQA expansion) + num_groups = n_heads // n_kv_heads + if num_groups == 1: + k_rep = k + v_rep = v + else: + k_rep = k.unsqueeze(2).expand(batch, n_kv_heads, num_groups, seq_len, head_dim).reshape(batch, n_heads, seq_len, head_dim) + v_rep = v.unsqueeze(2).expand(batch, n_kv_heads, num_groups, seq_len, head_dim).reshape(batch, n_heads, seq_len, head_dim) + + # Use fp32 for stability + q_f = q.to(torch.float32) + k_rep_f = k_rep.to(torch.float32) + v_rep_f = v_rep.to(torch.float32) + do_f = do.to(torch.float32) + o_f = o.to(torch.float32) + + # delta = (dO · O) per row (B, Hq, N, 1) + delta = (do_f * o_f).sum(dim=-1, keepdim=True) + + # Use stored per-row softmax stats from forward + m_i = M.to(torch.float32) + l_i = L.to(torch.float32).clamp_min(1e-12) + + row_idx = torch.arange(seq_len, device=q.device) + BLOCK_N = 128 + LOG2E = 1.4426950408889634 + scale2 = scale * LOG2E + + # Second pass: compute grads + dq_f = torch.zeros_like(q_f) + dk_rep_f = torch.zeros_like(k_rep_f) + dv_rep_f = torch.zeros_like(v_rep_f) + + for start_n in range(0, seq_len, BLOCK_N): + end_n = min(seq_len, start_n + BLOCK_N) + k_block = k_rep_f[:, :, start_n:end_n, :] # (B, Hq, BN, D) + v_block = v_rep_f[:, :, start_n:end_n, :] # (B, Hq, BN, D) + + s2_block = torch.matmul(q_f, k_block.transpose(-1, -2)) * scale2 + col_idx = torch.arange(start_n, end_n, device=q.device) + valid = (col_idx.view(1, 1, 1, -1) <= row_idx.view(1, 1, -1, 1)) + s2_block = s2_block.masked_fill(~valid, -float('inf')) + + # Reconstruct probabilities in the same base-2 domain as forward + p_tilde = torch.pow(2.0, s2_block - m_i.unsqueeze(-1)) + denom = l_i.unsqueeze(-1) + p_block = p_tilde / denom + + # dV += P^T @ dO + dv_rep_f[:, :, start_n:end_n, :] += torch.matmul(p_block.transpose(-1, -2), do_f) + + # t = dO @ V^T + t_block = torch.matmul(do_f, v_block.transpose(-1, -2)) + + # dS = P ⊙ (t - delta) + dS_block = p_block * (t_block - delta) + + # dQ += dS @ K + dq_f += torch.matmul(dS_block, k_block) * scale + + # dK += dS^T @ Q + dk_rep_f[:, :, start_n:end_n, :] += torch.matmul(dS_block.transpose(-1, 
-2), q_f) * scale + + # Collapse group dimension back to KV heads for dK and dV + if num_groups == 1: + dk_f = dk_rep_f + dv_f = dv_rep_f + else: + dk_f = dk_rep_f.view(batch, n_kv_heads, num_groups, seq_len, head_dim).sum(dim=2) + dv_f = dv_rep_f.view(batch, n_kv_heads, num_groups, seq_len, head_dim).sum(dim=2) + + dq = dq_f.to(q.dtype) + dk = dk_f.to(k.dtype) + dv = dv_f.to(v.dtype) + return dq, dk, dv, None, None + def flash_attention_gqa(q, k, v, is_causal=True, softmax_scale=None): return FlashAttention2Function.apply(q, k, v, is_causal, softmax_scale) \ No newline at end of file diff --git a/problem_9.py b/problem_9.py index 126b416..9ace9f0 100644 --- a/problem_9.py +++ b/problem_9.py @@ -3,17 +3,18 @@ import triton.language as tl import math from typing import Optional - +#(dhruv) passes all test cases in T4 but not in local gpu @triton.jit def _flash_attention_forward_swa_kernel( # Pointers to Tensors - Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, + Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, L_ptr, # Stride information for tensors q_stride_b, q_stride_h, q_stride_s, k_stride_b, k_stride_h, k_stride_s, v_stride_b, v_stride_h, v_stride_s, o_stride_b, o_stride_h, o_stride_s, m_stride_b, m_stride_h, m_stride_s, + l_stride_b, l_stride_h, l_stride_s, # Kernel parameters softmax_scale, SEQ_LEN, @@ -26,37 +27,433 @@ def _flash_attention_forward_swa_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - pass + """ + Triton kernel for the forward pass of causal FlashAttention with GQA, Sliding Window Attention, and Attention Sink. + """ + # 1. Identify the block of queries and the batch/head to be processed. + q_block_idx = tl.program_id(axis=0) + batch_head_idx = tl.program_id(axis=1) + + batch_idx = batch_head_idx // N_Q_HEADS + q_head_idx = batch_head_idx % N_Q_HEADS + + # --- GQA Logic: Map Query Head to Shared K/V Head --- + num_groups = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // num_groups + + # 2. Initialize accumulators in SRAM. + m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # 3. Load the block of queries (Q_i). + q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) + q_block = tl.cast(q_block, tl.float32) + qk_scale = softmax_scale * 1.44269504 + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Combine the GQA, SWA, and Sink logic. + # Combine all code from previous problems, and add the sink logic. + # You should have 3 phases: + # 1. Phase 0: Sink blocks that are before the sliding window + # 2. Phase 1: Off-Diagonal Blocks (within the window) + # 3. 
Phase 2: Diagonal Blocks + window_start = max(0, q_block_idx * BLOCK_M - WINDOW_SIZE) + # --- Phase 0: Sink blocks that are before the sliding window --- + sink_end = min(SINK_SIZE, window_start) + for start_n in range(0, sink_end, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + s_ij = tl.dot(q_block, k_block, allow_tf32=False) + s_ij *= qk_scale + # Attend only to causal sink tokens within sequence length + valid = (k_offsets[None, :] <= q_offsets[:, None]) & (k_offsets[None, :] < SINK_SIZE) & (k_offsets[None, :] < SEQ_LEN) + s_ij = tl.where(valid, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block, allow_tf32=False) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- STUDENT IMPLEMENTATION REQUIRED (Part 2: SWA Logic) --- + # Now, implement the "sliding window" by changing the loop bounds. + # The kernel should only attend to the `WINDOW_SIZE` most recent key/value tokens. + # 1. Calculate the starting position of the attention window (window_start). + # 2. Modify the range of the Phase 1 loop to start from your window_start. + + + + # --- Phase 1: Off-Diagonal Blocks (within the window) --- + for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): + # STUDENT IMPLEMENTATION REQUIRED (Part 3: SWA Logic) + # Hint: You might need to apply the per-element sliding window mask to s_ij. + # - A score is invalid if `(query_offset - key_offset) >= WINDOW_SIZE`. 
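+        # NOTE: a key survives the mask below if it is causal and either inside the
+        # last WINDOW_SIZE positions or one of the first SINK_SIZE "attention sink"
+        # tokens; keys at or beyond SEQ_LEN are always dropped.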
+ k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block, allow_tf32=False) + s_ij *= qk_scale + causal = (k_offsets[None, :] <= q_offsets[:, None]) + within_window = (q_offsets[:, None] - k_offsets[None, :]) < WINDOW_SIZE + in_sink = (k_offsets[None, :] < SINK_SIZE) + in_range = (k_offsets[None, :] < SEQ_LEN) + valid = causal & (within_window | in_sink) & in_range + s_ij = tl.where(valid, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block, allow_tf32=False) + # 5. Update running maximum for next iteration. 
+ m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + k_block = tl.cast(k_block, tl.float32) + + # Compute attention scores S_ij = Q_i * K_j^T + s_ij = tl.dot(q_block, k_block, allow_tf32=False) + s_ij *= qk_scale + # Causal + sliding window + sink mask within the diagonal block + causal = (k_offsets[None, :] <= q_offsets[:, None]) + within_window = (q_offsets[:, None] - k_offsets[None, :]) < WINDOW_SIZE + in_sink = (k_offsets[None, :] < SINK_SIZE) + in_range = (k_offsets[None, :] < SEQ_LEN) + valid = causal & (within_window | in_sink) & in_range + s_ij = tl.where(valid, s_ij, -float('inf')) + # mask = q_offsets[:, None]-k_offsets[None, :] >= WINDOW_SIZE + # s_ij = tl.where(mask, s_ij, -float('inf')) + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + v_block = tl.cast(v_block, tl.float32) + + # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + # Implement the online softmax update logic (streaming, numerically stable). + # Ensure consistent fp32 dtype for reductions and dot products. + # q_block = tl.cast(q_block, tl.float32) + # k_block = tl.cast(k_block, tl.float32) + # v_block = tl.cast(v_block, tl.float32) + # s_ij = tl.cast(s_ij, tl.float32) + + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + # 2. Rescale accumulators; guard all-masked tiles to avoid NaNs. + no_contrib = m_new == -float('inf') + alpha = tl.where(no_contrib, 1.0, tl.exp2(m_i - m_new)) + acc_rescaled = acc * alpha[:, None] + l_i_rescaled = l_i * alpha + # 3. Compute probabilities safely. + s_shifted = s_ij - m_new[:, None] + s_shifted = tl.where(no_contrib[:, None], -float('inf'), s_shifted) + P_tilde_ij = tl.exp2(s_shifted) + # 4. Update accumulators. + l_i = l_i_rescaled + tl.sum(P_tilde_ij, axis=1) + acc = acc_rescaled + tl.dot(P_tilde_ij, v_block, allow_tf32=False) + # 5. Update running maximum for next iteration. + m_i = m_new + # --- END OF STUDENT IMPLEMENTATION --- + # --- END OF STUDENT IMPLEMENTATION --- + + # 4. Normalize and write the final output block. 
+ l_i_safe = tl.where(l_i == 0, 1.0, l_i) + acc = acc / l_i_safe[:, None] + + o_ptrs = O_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=q_offsets[:, None] < SEQ_LEN) + temp = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) + # Use (B,H,S) strides when storing per-row stats + m_ptrs = M_ptr + batch_idx * m_stride_b + q_head_idx * m_stride_h + temp * m_stride_s + tl.store(m_ptrs, m_i.to(M_ptr.dtype.element_ty), mask=(temp < SEQ_LEN)) + l_ptrs = L_ptr + batch_idx * l_stride_b + q_head_idx * l_stride_h + temp * l_stride_s + tl.store(l_ptrs, l_i.to(L_ptr.dtype.element_ty), mask=(temp < SEQ_LEN)) @triton.jit def _flash_attention_backward_swa_kernel( - # In/Out Pointers - Q_ptr, K_ptr, V_ptr, dO_ptr, M_ptr, D_ptr, + # Inputs + Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, L_ptr, dO_ptr, + # Outputs (grads) dQ_ptr, dK_ptr, dV_ptr, - # Strides + # Strides for Q/K/V/O (assumed same layout for Q, O, dO, dQ) q_stride_b, q_stride_h, q_stride_s, k_stride_b, k_stride_h, k_stride_s, v_stride_b, v_stride_h, v_stride_s, - do_stride_b, do_stride_h, do_stride_s, + # Strides for M/L (B,H,S) m_stride_b, m_stride_h, m_stride_s, - d_stride_b, d_stride_h, d_stride_s, - dq_stride_b, dq_stride_h, dq_stride_s, - dk_stride_b, dk_stride_h, dk_stride_s, - dv_stride_b, dv_stride_h, dv_stride_s, - # Parameters + l_stride_b, l_stride_h, l_stride_s, + # Kernel parameters softmax_scale, - BATCH_SIZE: int, - N_Q_HEADS: int, - N_KV_HEADS: int, - SEQ_LEN: int, + SEQ_LEN, + N_Q_HEADS, + N_KV_HEADS, + # Constexpr tile sizes WINDOW_SIZE: tl.constexpr, SINK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr, - # Tile Sizes BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - pass + LOG2E = 1.44269504 + qk_scale2 = softmax_scale * LOG2E + + # Program ids + q_block_idx = tl.program_id(axis=0) + bh_idx = tl.program_id(axis=1) + batch_idx = bh_idx // N_Q_HEADS + q_head_idx = bh_idx % N_Q_HEADS + + # GQA mapping: map q_head to kv_head + num_groups = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // num_groups + + # Offsets and pointers for this query block + q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) + hd = tl.arange(0, HEAD_DIM) + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + o_ptrs = O_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + do_ptrs = dO_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + + q_mask = q_offsets[:, None] < SEQ_LEN + q_block = tl.load(q_ptrs, mask=q_mask, other=0.0) + o_block = tl.load(o_ptrs, mask=q_mask, other=0.0) + do_block = tl.load(do_ptrs, mask=q_mask, other=0.0) + q_block = tl.cast(q_block, tl.float32) + o_block = tl.cast(o_block, tl.float32) + do_block = tl.cast(do_block, tl.float32) + # Load stored per-row softmax stats + m_ptrs = M_ptr + batch_idx * m_stride_b + q_head_idx * m_stride_h + q_offsets * m_stride_s + l_ptrs = L_ptr + batch_idx * l_stride_b + q_head_idx * l_stride_h + q_offsets * l_stride_s + m_i = tl.load(m_ptrs, mask=q_offsets < SEQ_LEN, other=0.0) + l_i = tl.load(l_ptrs, mask=q_offsets < SEQ_LEN, other=1.0) + l_i = tl.maximum(l_i, 1e-12) + + # delta = sum(dO * O) per row + delta = tl.sum(do_block * o_block, axis=1) # (BLOCK_M) + + # Accumulator for dQ + dQ_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # --- Phase 1: Off-Diagonal Blocks --- + window_start = 
max(0, q_block_idx * BLOCK_M - WINDOW_SIZE) + # --- Phase 0: Sink blocks that are before the sliding window --- + sink_end = min(SINK_SIZE, window_start) + for start_n in range(0, sink_end, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_mask = k_offsets < SEQ_LEN + + # Load K and V once in column-major (D, N), derive row-major via transpose + k_cols_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + v_cols_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[None, :] * v_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + v_cols = tl.load(v_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + # Cast tiles to fp32 for matmuls + k_cols = tl.cast(k_cols, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + k_rows = tl.trans(k_cols) # (N, D) + + # Scores in base-2 domain and probabilities + q_f32 = tl.cast(q_block, tl.float32) + s2 = tl.dot(q_f32, k_cols, allow_tf32=False) * qk_scale2 # (M, N) + causal = (k_offsets[None, :] <= q_offsets[:, None]) + in_sink = (k_offsets[None, :] < SINK_SIZE) + in_range = (k_offsets[None, :] < SEQ_LEN) + valid = causal & in_sink & in_range + s2 = tl.where(valid, s2, -float('inf')) + p_tilde = tl.exp2(s2 - m_i[:, None]) + P = p_tilde / l_i[:, None] + + # dV partial via matmul: P^T @ dO => (N,D) + do_f32 = tl.cast(do_block, tl.float32) + dv_partial = tl.dot(tl.trans(P), do_f32, allow_tf32=False) + + # t_block = dO @ V^T using V in (D, N) + t_block = tl.dot(do_f32, v_cols, allow_tf32=False) + + # dS = P * (t_block - delta[:, None]) + dS = P * (t_block - delta[:, None]) + + # dQ += dS @ K_rows, scaled by softmax_scale + dQ_acc += tl.dot(dS, k_rows, allow_tf32=False) * softmax_scale + + # dK partial via matmul: dS^T @ Q, scaled by softmax_scale + q_f32 = tl.cast(q_block, tl.float32) + dk_partial = tl.dot(tl.trans(dS), q_f32, allow_tf32=False) * softmax_scale + + # Atomic add into global dV and dK + dv_ptrs = dV_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + dk_ptrs = dK_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + tl.atomic_add(dv_ptrs, dv_partial, mask=k_mask[:, None]) + tl.atomic_add(dk_ptrs, dk_partial, mask=k_mask[:, None]) + for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_mask = k_offsets < SEQ_LEN + + # Load K and V once in column-major (D, N), derive row-major via transpose + k_cols_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + v_cols_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[None, :] * v_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + v_cols = tl.load(v_cols_ptrs, mask=k_mask[None, :], other=0.0) # (D, N) + # Cast tiles to fp32 for matmuls + k_cols = tl.cast(k_cols, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + k_rows = tl.trans(k_cols) # (N, D) + + # Scores in base-2 domain and probabilities + q_f32 = tl.cast(q_block, tl.float32) + s2 = tl.dot(q_f32, k_cols, allow_tf32=False) * qk_scale2 # (M, N) + causal = (k_offsets[None, :] <= q_offsets[:, None]) + within_window = (q_offsets[:, None] - k_offsets[None, :]) < WINDOW_SIZE + in_sink = (k_offsets[None, :] < SINK_SIZE) + 
in_range = (k_offsets[None, :] < SEQ_LEN) + valid = causal & (within_window | in_sink) & in_range + s2 = tl.where(valid, s2, -float('inf')) + p_tilde = tl.exp2(s2 - m_i[:, None]) + P = p_tilde / l_i[:, None] + + # dV partial via matmul: P^T @ dO => (N,D) + do_f32 = tl.cast(do_block, tl.float32) + dv_partial = tl.dot(tl.trans(P), do_f32, allow_tf32=False) + + # t_block = dO @ V^T using V in (D, N) + t_block = tl.dot(do_f32, v_cols, allow_tf32=False) + + # dS = P * (t_block - delta[:, None]) + dS = P * (t_block - delta[:, None]) + + # dQ += dS @ K_rows, scaled by softmax_scale + dQ_acc += tl.dot(dS, k_rows, allow_tf32=False) * softmax_scale + + # dK partial via matmul: dS^T @ Q, scaled by softmax_scale + q_f32 = tl.cast(q_block, tl.float32) + dk_partial = tl.dot(tl.trans(dS), q_f32, allow_tf32=False) * softmax_scale + + # Atomic add into global dV and dK + dv_ptrs = dV_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + dk_ptrs = dK_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + tl.atomic_add(dv_ptrs, dv_partial, mask=k_mask[:, None]) + tl.atomic_add(dk_ptrs, dk_partial, mask=k_mask[:, None]) + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_mask = k_offsets < SEQ_LEN + + k_cols_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + v_cols_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[None, :] * v_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=k_mask[None, :], other=0.0) + v_cols = tl.load(v_cols_ptrs, mask=k_mask[None, :], other=0.0) + k_cols = tl.cast(k_cols, tl.float32) + v_cols = tl.cast(v_cols, tl.float32) + k_rows = tl.trans(k_cols) + + # Scores with causal mask inside diagonal tile + q_f32 = tl.cast(q_block, tl.float32) + + s2 = tl.dot(q_f32, k_cols, allow_tf32=False) * qk_scale2 + causal = (k_offsets[None, :] <= q_offsets[:, None]) + within_window = (q_offsets[:, None] - k_offsets[None, :]) < WINDOW_SIZE + in_sink = (k_offsets[None, :] < SINK_SIZE) + in_range = (k_offsets[None, :] < SEQ_LEN) + valid = causal & (within_window | in_sink) & in_range + s2 = tl.where(valid, s2, -float('inf')) + p_tilde = tl.exp2(s2 - m_i[:, None]) + P = p_tilde / l_i[:, None] + + # dV partial P^T @ dO + do_f32 = tl.cast(do_block, tl.float32) + dv_partial = tl.dot(tl.trans(P), do_f32, allow_tf32=False) + + # t_block = dO @ V^T + t_block = tl.dot(do_f32, v_cols, allow_tf32=False) + t_block = tl.where(valid, t_block, 0.0) + + dS = P * (t_block - delta[:, None]) + dS = tl.where(valid, dS, 0.0) + + dQ_acc += tl.dot(dS, k_rows, allow_tf32=False) * softmax_scale + + q_f32 = tl.cast(q_block, tl.float32) + dk_partial = tl.dot(tl.trans(dS), q_f32, allow_tf32=False) * softmax_scale + + dv_ptrs = dV_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + (k_offsets[:, None] * v_stride_s + hd[None, :]) + dk_ptrs = dK_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + (k_offsets[:, None] * k_stride_s + hd[None, :]) + tl.atomic_add(dv_ptrs, dv_partial, mask=k_mask[:, None]) + tl.atomic_add(dk_ptrs, dk_partial, mask=k_mask[:, None]) + + # Write dQ for this block + dQ_ptrs = dQ_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + hd[None, :]) + tl.store(dQ_ptrs, 
dQ_acc, mask=q_mask) class FlashSWDAWithSink(torch.autograd.Function): @staticmethod @@ -76,18 +473,19 @@ def forward(ctx, q, k, v, window_size, sink_size, is_causal=True, softmax_scale= o = torch.empty_like(q) M = torch.empty((batch, n_q_heads, seq_len), device=q.device, dtype=torch.float32) - + L = torch.empty((batch, n_q_heads, seq_len), device=q.device, dtype=torch.float32) BLOCK_M, BLOCK_N = 128, 64 grid = (math.ceil(seq_len / BLOCK_M), batch * n_q_heads) _flash_attention_forward_swa_kernel[grid]( - q, k, v, o, M, + q, k, v, o, M, L, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1), o.stride(2), M.stride(0), M.stride(1), M.stride(2), + L.stride(0), L.stride(1), L.stride(2), softmax_scale, seq_len, n_q_heads, @@ -99,7 +497,7 @@ def forward(ctx, q, k, v, window_size, sink_size, is_causal=True, softmax_scale= BLOCK_N=BLOCK_N, ) - ctx.save_for_backward(q, k, v, o, M) + ctx.save_for_backward(q, k, v, o, M, L) ctx.softmax_scale = softmax_scale ctx.window_size = window_size ctx.sink_size = sink_size @@ -107,7 +505,7 @@ def forward(ctx, q, k, v, window_size, sink_size, is_causal=True, softmax_scale= @staticmethod def backward(ctx, do): - q, k, v, o, M = ctx.saved_tensors + q, k, v, o, M, L = ctx.saved_tensors softmax_scale = ctx.softmax_scale window_size = ctx.window_size sink_size = ctx.sink_size @@ -115,13 +513,35 @@ def backward(ctx, do): batch, n_q_heads, seq_len, head_dim = q.shape n_kv_heads = k.shape[1] - dq = torch.empty_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - # TODO: Add your backward kernel here + dq_f = torch.zeros_like(q, dtype=torch.float32) + dk_rep_f = torch.zeros_like(k, dtype=torch.float32) + dv_rep_f = torch.zeros_like(v, dtype=torch.float32) + + # Implement backward using a numerically stable streaming algorithm with block-wise keys. + # This avoids allocating full (seq_len x seq_len) attention tensors. + softmax_scale = ctx.softmax_scale + BLOCK_M, BLOCK_N = 128, 64 + grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_q_heads) + _flash_attention_backward_swa_kernel[grid]( + q, k, v, o, M, L, do, + dq_f, dk_rep_f, dv_rep_f, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + M.stride(0), M.stride(1), M.stride(2), + L.stride(0), L.stride(1), L.stride(2), + softmax_scale, + seq_len, + n_q_heads, + n_kv_heads, + WINDOW_SIZE=window_size, + SINK_SIZE=sink_size, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) - return dq, dk.to(k.dtype), dv.to(v.dtype), None, None, None, None + return dq_f.to(q.dtype), dk_rep_f.to(k.dtype), dv_rep_f.to(v.dtype), None, None, None, None def flash_swda_with_sink(q, k, v, window_size: int, sink_size: int = 0, is_causal: bool = True, scale: Optional[float] = None): return FlashSWDAWithSink.apply(q, k, v, window_size, sink_size, is_causal, scale) \ No newline at end of file From 2024c84434593b1af5f69bd34af0af55c1ff6d89 Mon Sep 17 00:00:00 2001 From: Dhruv Date: Sun, 14 Sep 2025 18:39:05 +0530 Subject: [PATCH 2/2] all problems till 9 working as expected --- problem_6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/problem_6.py b/problem_6.py index 84809b6..cd892ff 100644 --- a/problem_6.py +++ b/problem_6.py @@ -61,7 +61,7 @@ def _flash_attention_forward_swa_kernel( # 1. Calculate the starting position of the attention window (window_start). # 2. Modify the range of the Phase 1 loop to start from your window_start. 
- window_start = tl.max(0, q_block_idx * BLOCK_M - (WINDOW_SIZE )) + window_start = max(0, q_block_idx * BLOCK_M - (WINDOW_SIZE )) # --- Phase 1: Off-Diagonal Blocks (within the window) --- for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N):
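The masking predicate that recurs in both problem_9 kernels above, valid = causal & (within_window | in_sink) & in_range, together with the GQA mapping kv_head_idx = q_head_idx // (N_Q_HEADS // N_KV_HEADS), is easiest to audit against a naive reference. Below is a minimal PyTorch sketch for that purpose; the helper name swa_sink_attention_reference is ours, not part of the patch, and the smoke test assumes the patched problem_9.py is importable and that a CUDA device with Triton is available.

import math
import torch


def swa_sink_attention_reference(q, k, v, window_size, sink_size, scale=None):
    """Naive O(S^2) causal attention with a sliding window, sink tokens, and
    GQA head sharing. Shapes: q (B, Hq, S, D); k, v (B, Hkv, S, D).
    Sanity-check helper only, not part of the patch."""
    batch, n_q_heads, seq_len, head_dim = q.shape
    n_kv_heads = k.shape[1]
    num_groups = n_q_heads // n_kv_heads
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    # GQA: query heads [g*num_groups, (g+1)*num_groups) share KV head g,
    # matching the kernel's kv_head_idx = q_head_idx // num_groups.
    k_rep = k.repeat_interleave(num_groups, dim=1)
    v_rep = v.repeat_interleave(num_groups, dim=1)

    scores = (q.float() @ k_rep.float().transpose(-1, -2)) * scale  # (B, Hq, S, S)

    q_pos = torch.arange(seq_len, device=q.device).unsqueeze(1)  # (S, 1)
    k_pos = torch.arange(seq_len, device=q.device).unsqueeze(0)  # (1, S)
    causal = k_pos <= q_pos
    within_window = (q_pos - k_pos) < window_size
    in_sink = k_pos < sink_size
    valid = causal & (within_window | in_sink)  # same predicate as the kernel

    scores = scores.masked_fill(~valid, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    return (probs @ v_rep.float()).to(q.dtype)


if __name__ == "__main__" and torch.cuda.is_available():
    from problem_9 import flash_swda_with_sink  # assumes the patched file is on the path

    torch.manual_seed(0)
    B, Hq, Hkv, S, D = 1, 8, 2, 256, 64
    q = torch.randn(B, Hq, S, D, device="cuda", requires_grad=True)
    k = torch.randn(B, Hkv, S, D, device="cuda", requires_grad=True)
    v = torch.randn(B, Hkv, S, D, device="cuda", requires_grad=True)

    out = flash_swda_with_sink(q, k, v, window_size=128, sink_size=4)
    ref = swa_sink_attention_reference(q, k, v, window_size=128, sink_size=4)
    print("forward max abs diff:", (out - ref).abs().max().item())

    out.sum().backward()  # exercises the backward kernel's dQ/dK/dV atomic paths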
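The per-tile update repeated in each forward phase (rescale by alpha = exp2(m_i - m_new), accumulate P_tilde and its product with V, normalize by l_i at the very end) is the standard online-softmax recurrence. A small CPU sketch in base e, with no masking, showing that the streamed result matches a full softmax-weighted sum; the kernels additionally guard the all-masked case with the no_contrib flag, which this sketch does not need.

import torch

torch.manual_seed(0)
S, D, TILE = 64, 16, 16
s = torch.randn(S)       # scores for a single query row
v = torch.randn(S, D)

m = torch.tensor(float('-inf'))  # running max
l = torch.tensor(0.0)            # running denominator
acc = torch.zeros(D)             # running numerator (weighted sum of V rows)

for start in range(0, S, TILE):
    s_tile, v_tile = s[start:start + TILE], v[start:start + TILE]
    m_new = torch.maximum(m, s_tile.max())
    alpha = torch.exp(m - m_new)      # 0.0 on the first tile, since m is -inf
    p = torch.exp(s_tile - m_new)     # unnormalized probabilities for this tile
    l = l * alpha + p.sum()
    acc = acc * alpha + p @ v_tile
    m = m_new

print(torch.allclose(acc / l, torch.softmax(s, dim=0) @ v, atol=1e-5))  # expected: True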
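Both kernels fold the softmax scale and a change of base into one constant, qk_scale = softmax_scale * 1.44269504, and then use tl.exp2; this relies only on the identity exp(x) = 2^(x * log2 e), which maps onto the GPU's exp2 unit. A one-line check of that identity in PyTorch:

import torch

LOG2E = 1.44269504        # log2(e), the constant folded into qk_scale in the kernels
s = torch.randn(4, 8)
softmax_scale = 0.125     # e.g. 1 / sqrt(64) for HEAD_DIM = 64

print(torch.allclose(torch.exp(s * softmax_scale),
                     torch.exp2(s * softmax_scale * LOG2E), atol=1e-6))  # expected: True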