diff --git a/problem_1.py b/problem_1.py index 683be1c..310a80b 100644 --- a/problem_1.py +++ b/problem_1.py @@ -56,18 +56,28 @@ def forward(ctx, Q, K, V, is_causal=False): S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale # --- STUDENT IMPLEMENTATION REQUIRED HERE --- + S_ij = S_ij.to(torch.float32) + V_tile = V_tile.to(torch.float32) # 1. Apply causal masking if is_causal is True. - # + if is_causal: + q_idx = torch.arange(q_start, q_end, device=Q.device)[:, None] + k_idx = torch.arange(k_start, k_end, device=Q.device)[None, :] + mask = k_idx > q_idx # (q_len, k_len) + S_ij = S_ij.masked_fill(mask, torch.finfo(torch.float32).min) # 2. Compute the new running maximum - # + m_ij = torch.max(S_ij, dim=-1).values #(128,) + m_new = torch.maximum(m_i, m_ij) #m_i is (128,) -> m_new is (128,) # 3. Rescale the previous accumulators (o_i, l_i) - # + scale_factor = torch.exp(m_i - m_new) #(128,) + o_i = o_i * scale_factor.unsqueeze(-1) # o_i is (128,16) -> (128,16) + l_i = l_i * scale_factor # l_i is (128,) -> (128,) # 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new). - # + P_tilde_ij = torch.exp(S_ij - m_new.unsqueeze(-1)) #(128,128) # 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i - # + o_i = o_i + (P_tilde_ij @ V_tile) #V_tile is (128,16) + l_i = l_i + P_tilde_ij.sum(dim=-1) #P_tilde_ij.sum(dim=-1) -> (128,) sum over cols (row-wise) # 6. Update the running max for the next iteration - + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- # After iterating through all key tiles, normalize the output diff --git a/problem_2.py b/problem_2.py index 011f3eb..fdc8c13 100644 --- a/problem_2.py +++ b/problem_2.py @@ -16,47 +16,47 @@ def weighted_row_sum_kernel( """ # 1. Get the row index for the current program instance. # Hint: Use tl.program_id(axis=0). - row_idx = ... + row_idx = tl.program_id(axis=0) # 2. Create a pointer to the start of the current row in the input tensor X. # Hint: The offset depends on the row index and the number of columns (N_COLS). - row_start_ptr = ... + row_start_ptr = X_ptr + row_idx * N_COLS # 3. Create a pointer for the output vector Y. - output_ptr = ... + output_ptr = Y_ptr + row_idx # 4. Initialize an accumulator for the sum of the products for a block. # This should be a block-sized tensor of zeros. # Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32. - accumulator = ... + accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) # 5. Iterate over the columns of the row in blocks of BLOCK_SIZE. # Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE). - for col_block_start in range(0, ...): + for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)): # - Calculate the offsets for the current block of columns. # Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE). - col_offsets = ... + col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # - Create a mask to prevent out-of-bounds memory access for the last block. # Hint: Compare col_offsets with N_COLS. - mask = ... + mask = col_offsets < N_COLS # - Load a block of data from X and W safely using the mask. # Hint: Use tl.load with the appropriate pointers, offsets, and mask. # Use `other=0.0` to handle out-of-bounds elements. - x_chunk = tl.load(...) - w_chunk = tl.load(...) 
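+ # Note on the masked loads implemented just below: with mask=col_offsets < N_COLS and
+ # other=0.0, lanes past the end of the row read zeros, so they add nothing to the
+ # x_chunk * w_chunk products and tl.sum over the block accumulator still equals the
+ # exact row dot product.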
- + x_chunk = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0) #if mask=False, load 0.0 + w_chunk = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + # - Compute the element-wise product and add it to the accumulator. - accumulator += ... - + accumulator += x_chunk * w_chunk + # 6. Reduce the block-sized accumulator to a single scalar value after the loop. # Hint: Use tl.sum(). - final_sum = ... + final_sum = tl.sum(accumulator, axis=0) # 7. Store the final accumulated sum to the output tensor Y. # Hint: Use tl.store(). - ... + tl.store(output_ptr, final_sum) #store tensor of data into mem locs def by output_ptr # --- END OF STUDENT IMPLEMENTATION --- @@ -94,4 +94,5 @@ def torch_weighted_row_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """ Reference implementation using pure PyTorch. """ - return (x * w).sum(dim=1) \ No newline at end of file + y = (x * w).sum(dim=1) + return y.to(x.dtype) \ No newline at end of file diff --git a/problem_3.py b/problem_3.py index 43f134a..60d7a7b 100644 --- a/problem_3.py +++ b/problem_3.py @@ -63,13 +63,23 @@ def _flash_attention_forward_kernel( # --- STUDENT IMPLEMENTATION REQUIRED HERE --- # Implement the online softmax update logic. + s_ij = s_ij.to(tl.float32) + v_block = v_block.to(tl.float32) # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,) + m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,) # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,) + acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM) + l_i = l_i * scale_factor #(BLOCK_M,) # 3. Compute the attention probabilities for the current tile (`p_ij`). + p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N) # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM) # 5. Update the denominator `l_i`. + l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,) # 6. Update the running maximum `m_i` for the next iteration. - pass + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- diff --git a/problem_4.py b/problem_4.py index 4faf01b..c2a9c7b 100644 --- a/problem_4.py +++ b/problem_4.py @@ -52,9 +52,40 @@ def _flash_attention_forward_causal_kernel( # Implement the logic for the off-diagonal blocks. # This is very similar to the non-causal version from Problem 3. # 1. Load the K and V blocks for the current iteration. + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + # 2. Compute the attention scores (S_ij). + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + + s_ij = s_ij.to(tl.float32) + v_block = v_block.to(tl.float32) + # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc). - pass + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,) + m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,) + acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM) + l_i = l_i * scale_factor #(BLOCK_M,) + # 3. 
Compute the attention probabilities for the current tile (`p_ij`). + p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM) + # 5. Update the denominator `l_i`. + l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,) + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- @@ -64,7 +95,44 @@ def _flash_attention_forward_causal_kernel( for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): # --- STUDENT IMPLEMENTATION REQUIRED HERE --- # Implement the logic for the diagonal blocks, apply the causal mask to S_ij. - pass + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,) + k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + + s_ij = s_ij.to(tl.float32) + v_block = v_block.to(tl.float32) + + # Build mask + q_idx = q_offsets + k_idx = k_offsets + causal = q_idx[:, None] >= k_idx[None, :] #Lower triangle true + valid = (q_idx[:, None] < SEQ_LEN) & (k_idx[None, :] < SEQ_LEN) + mask = causal & valid + + # Apply mask BEFORE tile max so future tokens don't affect m_i + neg_inf = -float("inf") + s_ij = tl.where(mask, s_ij, neg_inf) + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + scale_factor = tl.exp2(m_i - m_new) + + p_ij = tl.exp2(s_ij - m_new[:, None]) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- diff --git a/problem_5.py b/problem_5.py index aad8fe1..ef20cc2 100644 --- a/problem_5.py +++ b/problem_5.py @@ -34,9 +34,9 @@ def _flash_attention_forward_gqa_kernel( # --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 1) --- # Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx). # 1. Calculate how many query heads are in each group. + group_size = N_Q_HEADS // N_KV_HEADS # 2. Use integer division to find the correct kv_head_idx. - - kv_head_idx = 0 # Placeholder: Replace with your calculation + kv_head_idx = q_head_idx // group_size # --- END OF STUDENT IMPLEMENTATION --- @@ -59,7 +59,39 @@ def _flash_attention_forward_gqa_kernel( # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. # 2. Reuse your working implementation for the online softmax update # from your solution to Problem 4. - pass + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). 
+ s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32)) + s_ij *= qk_scale + + v_block = v_block.to(tl.float32) + + # 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc). + # 1. Find the new running maximum (`m_new`). + m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,) + m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,) + # 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`). + scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,) + acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM) + l_i = l_i * scale_factor #(BLOCK_M,) + # 3. Compute the attention probabilities for the current tile (`p_ij`). + p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N) + # 4. Update the accumulator `acc` using `p_ij` and `v_block`. + acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM) + # 5. Update the denominator `l_i`. + l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,) + # 6. Update the running maximum `m_i` for the next iteration. + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- # --- Phase 2: Diagonal Blocks --- @@ -69,7 +101,40 @@ def _flash_attention_forward_gqa_kernel( # 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`. # 2. Reuse your working implementation for the masked online softmax # update from your solution to Problem 4. - pass + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32)) + s_ij *= qk_scale + + v_block = v_block.to(tl.float32) + + # build mask + causal = q_offsets[:, None] >= k_offsets[None, :] #Lower triangle true + valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + mask = causal & valid + + # Apply mask BEFORE tile max so future tokens don't affect m_i + s_ij = tl.where(mask, s_ij, -float("inf")) + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + scale_factor = tl.exp2(m_i - m_new) + + p_ij = tl.exp2(s_ij - m_new[:, None]) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- # 4. Normalize and write the final output block. diff --git a/problem_6.py b/problem_6.py index f097706..17cad80 100644 --- a/problem_6.py +++ b/problem_6.py @@ -35,8 +35,8 @@ def _flash_attention_forward_swa_kernel( # This problem combines GQA and SWA. First, implement the GQA logic. # 1. Calculate the number of query heads per group. # 2. Determine the correct kv_head_idx for the current q_head_idx. 
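+ # Worked example (illustrative numbers): with N_Q_HEADS = 8 and N_KV_HEADS = 2,
+ # group_size = 8 // 2 = 4, so query heads 0-3 read KV head 0 and query heads 4-7
+ # read KV head 1 via kv_head_idx = q_head_idx // group_size.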
- - kv_head_idx = 0 # Placeholder: Replace with your GQA calculation + group_size = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // group_size # --- END OF GQA IMPLEMENTATION --- @@ -49,7 +49,7 @@ def _flash_attention_forward_swa_kernel( q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) - q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) + q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0).to(tl.float32) qk_scale = softmax_scale * 1.44269504 @@ -59,20 +59,119 @@ def _flash_attention_forward_swa_kernel( # 1. Calculate the starting position of the attention window (window_start). # 2. Modify the range of the Phase 1 loop to start from your window_start. - window_start = 0 # Placeholder: Replace with your SWA calculation + q_start = q_block_idx * BLOCK_M + win_left = q_start - (WINDOW_SIZE - 1) + window_start = tl.maximum(0, win_left) + # Previously, I did + # window_start = tl.maximum(0, q_block_idx * BLOCK_M - WINDOW_SIZE) + # which is wrong + + diag_start = q_block_idx * BLOCK_M # --- Phase 1: Off-Diagonal Blocks (within the window) --- for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): # STUDENT IMPLEMENTATION REQUIRED (Part 3: SWA Logic) # Hint: You might need to apply the per-element sliding window mask to s_ij. # - A score is invalid if `(query_offset - key_offset) >= WINDOW_SIZE`. - pass + # CAUSAL! So only past keys. Therefore, no need to check for negative indices. + # Load K + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). + k_block = k_block.to(tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + v_block = v_block.to(tl.float32) + + # Sliding window mask + dist = q_offsets[:, None] - k_offsets[None, :] #(BLOCK_M, BLOCK_N) + window_mask = (dist >= 0) & (dist < WINDOW_SIZE) + + # Validity mask - not required + # valid_mask = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + + # Prevent overlap with diagonal tile: + pre_diag_mask = k_offsets[None, :] < diag_start + + # Combine masks + mask = window_mask & pre_diag_mask + + s_ij = tl.where(mask, s_ij, -float('inf')) + + # Row has anything valid in this tile? 
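+ # Why the row_has guard used below matters: if every key position in a tile is masked
+ # out for some row, m_ij is -inf there; with m_i also still at -inf, an unguarded update
+ # would evaluate exp2(-inf - (-inf)) = NaN and poison acc and l_i. Gating on row_has
+ # leaves the accumulators untouched for such rows.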
+ row_has = tl.max(mask, axis=1) > 0 + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + # Only update rows that have something valid + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + # Calculate scale factor only for rows that have something valid + scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + + # Probabilities only for rows that have something valid + p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new # --- Phase 2: Diagonal Blocks --- diag_start = q_block_idx * BLOCK_M for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): # STUDENT IMPLEMENTATION REQUIRED - pass + # Load K + k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). + k_block = k_block.to(tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + v_block = v_block.to(tl.float32) + + # Sliding window mask + dist = q_offsets[:, None] - k_offsets[None, :] #(BLOCK_M, BLOCK_N) + # window_mask = (dist >= 0) & (dist < WINDOW_SIZE) + + # Combine masks + causal = q_offsets[:, None] >= k_offsets[None, :] #Lower triangle true + # valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + mask = causal + + # Apply mask BEFORE tile max so future tokens don't affect m_i + s_ij = tl.where(mask, s_ij, -float("inf")) + + # Row has anything valid in this tile? + row_has = tl.max(mask, axis=1) > 0 + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + # Only update rows that have something valid + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + # Calculate scale factor only for rows that have something valid + scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + + # Probabilities only for rows that have something valid + p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new # --- END OF SWA IMPLEMENTATION --- diff --git a/problem_7.py b/problem_7.py index af1558e..f5f5624 100644 --- a/problem_7.py +++ b/problem_7.py @@ -57,7 +57,163 @@ def _flash_attention_forward_swa_kernel( # 1. Phase 0: Sink blocks that are before the sliding window # 2. Phase 1: Off-Diagonal Blocks (within the window) # 3. Phase 2: Diagonal Blocks - pass + + # Cast Q once + q_block = q_block.to(tl.float32) + q_start = q_block_idx * BLOCK_M + win_left = q_start - (WINDOW_SIZE - 1) + window_start = tl.maximum(0, win_left) + + diag_start = q_block_idx * BLOCK_M + + # Phase 0: Attention sink only + for start_n in range(0, SINK_SIZE, BLOCK_N): # Every query block visits these sink columns, and SINK_SIZE can be > 1, so this tile is not necessarily lower triangular.
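+ # Attention-sink rationale: the first SINK_SIZE tokens stay attendable for every query,
+ # even after the sliding window has moved past them, so they get their own phase ahead
+ # of the windowed off-diagonal tiles.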
+ #Load K + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). + k_block = k_block.to(tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + v_block = v_block.to(tl.float32) + + # Masks + sink_cols = (k_offsets[None, :] < SINK_SIZE) + causal = (q_offsets[:, None] >= k_offsets[None, :]) + #valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + mask = sink_cols & causal #Causal is needed, unlike Phase 1, because we use the whole matrix + + s_ij = tl.where(mask, s_ij, -float('inf')) + + # Row has anything valid in this tile? + row_has = tl.max(mask, axis=1) > 0 + + # Online softmax update + m_ij = tl.max(s_ij, axis=1) + # Only update rows that have something valid + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + # Calculate scale factor only for rows that have something valid + scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + + # Probabilities only for rows that have something valid + p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new + + # Phase 1: Off-Diagonal Blocks (within the window), excl sinks + for start_n in range(window_start, q_block_idx * BLOCK_M, BLOCK_N): + # Load K + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). + k_block = k_block.to(tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + v_block = v_block.to(tl.float32) + + # EXCLUDE sinks (already handled in Phase 0) + non_sink = k_offsets[None, :] >= SINK_SIZE + + # Sliding window mask + dist = q_offsets[:, None] - k_offsets[None, :] #(BLOCK_M, BLOCK_N) + window_mask = (dist >= 0) & (dist < WINDOW_SIZE) + + # Validity mask + # valid_mask = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + + # Prevent overlap with diagonal tile: + pre_diag_mask = k_offsets[None, :] < diag_start + + # Combine masks + mask = window_mask & pre_diag_mask & non_sink #valid not needed + s_ij = tl.where(mask, s_ij, -float('inf')) + + # Row has anything valid in this tile? 
+ row_has = tl.max(mask, axis=1) > 0 + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + # Only update rows that have something valid + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + # Calculate scale factor only for rows that have something valid + scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + + # Probabilities only for rows that have something valid + p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new + + # Phase 2: Diagonal Blocks + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + # Load K + k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # 2. Compute the attention scores (S_ij). + k_block = k_block.to(tl.float32) + s_ij = tl.dot(q_block, k_block) + s_ij *= qk_scale + v_block = v_block.to(tl.float32) + + # NON sink mask + non_sink = k_offsets[None, :] >= SINK_SIZE + + # Sliding window mask + # dist = q_offsets[:, None] - k_offsets[None, :] #(BLOCK_M, BLOCK_N) + # window_mask = (dist >= 0) & (dist < WINDOW_SIZE) + + # Combine masks + causal = q_offsets[:, None] >= k_offsets[None, :] #Lower triangle true + # valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + mask = causal & non_sink #window and valid not needed for diag + + # Apply mask BEFORE tile max so future tokens don't affect m_i + s_ij = tl.where(mask, s_ij, -float("inf")) + + # Row has anything valid in this tile? + row_has = tl.max(mask, axis=1) > 0 + + # online softmax update + m_ij = tl.max(s_ij, axis=1) + # Only update rows that have something valid + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + # Calculate scale factor only for rows that have something valid + scale_factor = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + + # Probabilities only for rows that have something valid + p_ij = tl.where(row_has[:, None], tl.exp2(s_ij - m_new[:, None]), 0.0) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new # --- END OF STUDENT IMPLEMENTATION --- # 4. Normalize and write the final output block. diff --git a/problem_8.py b/problem_8.py index 8e63578..f1353b6 100644 --- a/problem_8.py +++ b/problem_8.py @@ -5,6 +5,347 @@ import math from typing import Optional +""" +Valid + Causal Masking is applied to both phases and both forward and backward kernels. 
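+These kernels work in base 2: qk_scale folds log2(e) ~= 1.44269504 into the softmax scale so tl.exp2 can replace exp throughout.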
+""" + +@triton.jit +def _flash_attention_forward_gqa_kernel( + # Pointers to Tensors + Q_ptr, K_ptr, V_ptr, O_ptr, M_ptr, # <-- add M_ptr + # Stride information for tensors + q_stride_b, q_stride_h, q_stride_s, + k_stride_b, k_stride_h, k_stride_s, + v_stride_b, v_stride_h, v_stride_s, + m_stride_b, m_stride_h, m_stride_s, # <-- add M strides + # Kernel parameters + softmax_scale, + SEQ_LEN, + N_Q_HEADS, + N_KV_HEADS, + # Constexpr tile sizes + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel template for the forward pass of causal FlashAttention with GQA. + """ + q_block_idx = tl.program_id(axis=0) + batch_head_idx = tl.program_id(axis=1) + + batch_idx = batch_head_idx // N_Q_HEADS + q_head_idx = batch_head_idx % N_Q_HEADS + + # 1. GQA + group_size = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // group_size + + # 2. Initialize accumulators in SRAM. + m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # 3. Load the block of queries (Q_i). + q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M)) + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0) + + qk_scale = softmax_scale * 1.44269504 + + # --- Phase 1: Off-Diagonal Blocks --- + for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N): + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + # Compute the attention scores (S_ij). 
+ s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32)) + s_ij *= qk_scale + + v_block = v_block.to(tl.float32) + + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + + scale_factor = tl.exp2(m_i - m_new) + acc = acc * scale_factor[:, None] + l_i = l_i * scale_factor + p_ij = tl.exp2(s_ij - m_new[:, None]) + acc = acc + tl.dot(p_ij, v_block) + l_i = l_i + tl.sum(p_ij, axis=1) + m_i = m_new + + # --- Phase 2: Diagonal Blocks --- + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + # Load K_j + k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,) + k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0) + + # Load V_j + v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0) + + s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32)) + s_ij *= qk_scale + + v_block = v_block.to(tl.float32) + + # Build mask + # Lower triangle true + causal = q_offsets[:, None] >= k_offsets[None, :] + valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN) + mask = causal & valid + + # Apply mask BEFORE tile max so future tokens don't affect m_i + s_ij = tl.where(mask, s_ij, -float("inf")) + + # Online softmax update + m_ij = tl.max(s_ij, axis=1) + m_new = tl.maximum(m_i, m_ij) + scale_factor = tl.exp2(m_i - m_new) + + p_ij = tl.exp2(s_ij - m_new[:, None]) + + acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block) + l_i = l_i * scale_factor + tl.sum(p_ij, axis=1) + m_i = m_new + + # 4. Normalize and write the final output block. 
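+ # M below stores each row's log2-sum-exp, m_i + log2(l_i + 1e-6); the backward kernel
+ # reloads it to recompute the probabilities directly as P = exp2(S - M).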
+ l_i_safe = l_i[:, None] + 1e-6 + acc = acc / l_i_safe + + o_ptrs = O_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + tl.store(o_ptrs, acc.to(O_ptr.dtype.element_ty), mask=q_offsets[:, None] < SEQ_LEN) + + lse_log2 = m_i + tl.log2(l_i + 1e-6) + m_ptrs = M_ptr + batch_idx * m_stride_b + q_head_idx * m_stride_h + q_offsets * m_stride_s + tl.store(m_ptrs, lse_log2, mask=q_offsets < SEQ_LEN) + +@triton.jit +def _flash_attention_backward_gqa_kernel( + # Pointers + Q_ptr, K_ptr, V_ptr, O_ptr, dO_ptr, M_ptr, + dQ_ptr, dK_ptr, dV_ptr, + # Strides + q_stride_b, q_stride_h, q_stride_s, + k_stride_b, k_stride_h, k_stride_s, + v_stride_b, v_stride_h, v_stride_s, + o_stride_b, o_stride_h, o_stride_s, + M_stride_b, M_stride_h, M_stride_s, + do_stride_b, do_stride_h, do_stride_s, + dq_stride_b, dq_stride_h, dq_stride_s, + dk_stride_b, dk_stride_h, dk_stride_s, + dv_stride_b, dv_stride_h, dv_stride_s, + # Params + softmax_scale, SEQ_LEN, N_Q_HEADS, N_KV_HEADS, + # Constexpr + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Program ids + q_block_idx = tl.program_id(axis=0) + batch_head_id = tl.program_id(axis=1) + + # Map to batch + q-head + b_idx = batch_head_id // N_Q_HEADS + q_head_idx = batch_head_id % N_Q_HEADS + + # GQA mapping + group_size = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // group_size + + # Row offsets for this block + q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) + row_mask = q_offsets < SEQ_LEN + col = tl.arange(0, HEAD_DIM) + + # Load ptrs + q_ptrs = Q_ptr + b_idx * q_stride_b + q_head_idx * q_stride_h + (q_offsets[:, None] * q_stride_s + col[None, :]) + do_ptrs = dO_ptr + b_idx * do_stride_b + q_head_idx * do_stride_h + (q_offsets[:, None] * do_stride_s + col[None, :]) + o_ptrs = O_ptr + b_idx * o_stride_b + q_head_idx * o_stride_h + (q_offsets[:, None] * o_stride_s + col[None, :]) + M_ptrs = M_ptr + b_idx * M_stride_b + q_head_idx * M_stride_h + q_offsets * M_stride_s + lse_log2_row = tl.load(M_ptrs, mask=row_mask, other=-float('inf')) + + q_block = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + do_block = tl.load(do_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + o_block = tl.load(o_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + + qk_scale = softmax_scale * 1.44269504 + + # Pass 0: delta = sum(dO * O) per row + delta = tl.sum(do_block * o_block, axis=1) + + # Pass 1: recompute m_i and l_i (online) over causal prefix + m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + # Off-diagonal tiles (strictly before this block) + for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + k_ptrs = K_ptr + b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_blk = tl.load(k_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + + # Scores + s = tl.dot(q_block, k_blk) * qk_scale + + causal = q_offsets[:, None] >= k_offsets[None, :] + valid = row_mask[:, None] & kv_valid[None, :] + s = tl.where(causal & valid, s, -float("inf")) + + m_ij = tl.max(s, axis=1) + m_new = tl.maximum(m_i, m_ij) + alpha = tl.exp2(m_i - m_new) + + p = tl.exp2(s - m_new[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_new + + # Diagonal block range + diag_start = q_block_idx * BLOCK_M + for start_n in range(diag_start, 
(q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + k_ptrs = K_ptr + b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_blk = tl.load(k_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + + s = tl.dot(q_block, k_blk) * qk_scale + + causal = q_offsets[:, None] >= k_offsets[None, :] + #valid = row_mask[:, None] & kv_valid[None, :] + s = tl.where(causal, s, -float("inf")) + + m_ij = tl.max(s, axis=1) + m_new = tl.maximum(m_i, m_ij) + alpha = tl.exp2(m_i - m_new) + + p = tl.exp2(s - m_new[:, None]) + l_i = l_i * alpha + tl.sum(p, axis=1) + m_i = m_new + + # Pass 2: compute grads + dQ_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # off-diagonal tiles + for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + # K for scores: (HEAD_DIM, BLOCK_N) + k_cols_ptrs = K_ptr + b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + + # K for dQ: (BLOCK_N, HEAD_DIM) + k_rows_ptrs = K_ptr + b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[:, None] * k_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + k_rows = tl.load(k_rows_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + # V: (BLOCK_N, HEAD_DIM) + v_ptrs = V_ptr + b_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_blk = tl.load(v_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + # Scores (BLOCK_M, BLOCK_N) + s = tl.dot(q_block, k_cols) * qk_scale + causal = q_offsets[:, None] >= k_offsets[None, :] + valid = row_mask[:, None] & kv_valid[None, :] + mask = causal & valid + s = tl.where(mask, s, -float("inf")) + + p = tl.exp2(s - lse_log2_row[:, None]) + p = tl.where(mask, p, 0.0) + + dV_tile = tl.dot(p.T, do_block) # (BLOCK_N, HEAD_DIM) + dp = tl.dot(do_block, v_blk.T) # (BLOCK_M, BLOCK_N) + # dS = dp - p * delta[:, None] # (BLOCK_M, BLOCK_N) - doesn't work + dS = (dp - delta[:, None]) * p + + dQ_acc += tl.dot(dS, k_rows) * softmax_scale + dK_tile = tl.dot(dS.T, q_block) * softmax_scale + + dk_ptrs = dK_ptr + b_idx * dk_stride_b + kv_head_idx * dk_stride_h + \ + (k_offsets[:, None] * dk_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + dv_ptrs = dV_ptr + b_idx * dv_stride_b + kv_head_idx * dv_stride_h + \ + (k_offsets[:, None] * dv_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + + # Doing this means casting the atomic values to bf16, it fails test cases + # Instead, create fp32 accumulators in HBM + # tl.atomic_add(dk_ptrs, dK_tile.to(dk_ptrs.dtype.element_ty), mask=kv_valid[:, None]) + # tl.atomic_add(dv_ptrs, dV_tile.to(dv_ptrs.dtype.element_ty), mask=kv_valid[:, None]) + tl.atomic_add(dk_ptrs, dK_tile, mask=kv_valid[:, None]) + tl.atomic_add(dv_ptrs, dV_tile, mask=kv_valid[:, None]) + + # Diagonal tiles + for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + # K for scores: (HEAD_DIM, BLOCK_N) + k_cols_ptrs = K_ptr + b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None]) + k_cols = tl.load(k_cols_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + + # K for dQ: (BLOCK_N, HEAD_DIM) + k_rows_ptrs = K_ptr + 
b_idx * k_stride_b + kv_head_idx * k_stride_h + \ + (k_offsets[:, None] * k_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + k_rows = tl.load(k_rows_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + # V: (BLOCK_N, HEAD_DIM) + v_ptrs = V_ptr + b_idx * v_stride_b + kv_head_idx * v_stride_h + \ + (k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + v_blk = tl.load(v_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + s = tl.dot(q_block, k_cols) * qk_scale + causal = q_offsets[:, None] >= k_offsets[None, :] + valid = row_mask[:, None] & kv_valid[None, :] + mask = causal & valid + s = tl.where(mask, s, -float("inf")) + + p = tl.exp2(s - lse_log2_row[:, None]) + p = tl.where(mask, p, 0.0) + + dV_tile = tl.dot(p.T, do_block) + dp = tl.dot(do_block, v_blk.T) + # dS = dp - p * delta[:, None] doesn't work + dS = (dp - delta[:, None]) * p + + dQ_acc += tl.dot(dS, k_rows) * softmax_scale + dK_tile = tl.dot(dS.T, q_block) * softmax_scale + + dk_ptrs = dK_ptr + b_idx * dk_stride_b + kv_head_idx * dk_stride_h + \ + (k_offsets[:, None] * dk_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + dv_ptrs = dV_ptr + b_idx * dv_stride_b + kv_head_idx * dv_stride_h + \ + (k_offsets[:, None] * dv_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + + tl.atomic_add(dk_ptrs, dK_tile, mask=kv_valid[:, None]) + tl.atomic_add(dv_ptrs, dV_tile, mask=kv_valid[:, None]) + + # Store dQ (unique writer) + dq_ptrs = dQ_ptr + b_idx * dq_stride_b + q_head_idx * dq_stride_h + \ + (q_offsets[:, None] * dq_stride_s + tl.arange(0, HEAD_DIM)[None, :]) + tl.store(dq_ptrs, dQ_acc.to(dQ_ptr.dtype.element_ty), mask=row_mask[:, None]) + + class FlashAttention2Function(torch.autograd.Function): """ Triton implementation of FlashAttention-2, supports causal attention and GQA. @@ -27,6 +368,20 @@ def forward(ctx, q, k, v, is_causal=True, softmax_scale: Optional[float] = None) grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads) # TODO: Add your forward kernel here + _flash_attention_forward_gqa_kernel[grid]( + q, k, v, o, M, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + M.stride(0), M.stride(1), M.stride(2), + softmax_scale, + seq_len, + n_heads, + n_kv_heads, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) ctx.save_for_backward(q, k, v, o, M) ctx.softmax_scale = softmax_scale @@ -43,7 +398,11 @@ def backward(ctx, do): dq = torch.empty_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) - + # DO NOT CHANGE THE DTYPES IN THE TEMPLATE + # dq = torch.empty_like(q, dtype=torch.float32) + # dk = torch.zeros_like(k, dtype=torch.float32) + # dv = torch.zeros_like(v, dtype=torch.float32) + # [OPTIONAL BONUS] STUDENT IMPLEMENTATION REQUIRED # Implement the Triton backward kernel for GQA from scratch. # You should: @@ -51,7 +410,39 @@ def backward(ctx, do): # 2. Recompute attention probabilities P = softmax(QK^T) # 3. Use delta + dO to accumulate gradients for dq, dk, dv # 4. 
Respect GQA mapping and causal mask - + + # Instead, specify the dtype on the accumulation tensors + dk_acc = torch.zeros_like(k, dtype=torch.float32) + dv_acc = torch.zeros_like(v, dtype=torch.float32) + + BLOCK_M, BLOCK_N = 128, 64 + grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_heads) + + _flash_attention_backward_gqa_kernel[grid]( + # Pointers + q, k, v, o, do, M, + dq, dk_acc, dv_acc, + # Strides (q, k, v, o, M, dO, dQ, dK, dV) + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + M.stride(0), M.stride(1), M.stride(2), + do.stride(0), do.stride(1), do.stride(2), + dq.stride(0), dq.stride(1), dq.stride(2), + dk.stride(0), dk.stride(1), dk.stride(2), + dv.stride(0), dv.stride(1), dv.stride(2), + # Params + ctx.softmax_scale, seq_len, n_heads, n_kv_heads, + # Constexpr + HEAD_DIM=head_dim, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + ) + + # Write them back to HBM + # copy_ is in place + dk.copy_(dk_acc) + dv.copy_(dv_acc) + return dq, dk.to(k.dtype), dv.to(v.dtype), None, None diff --git a/problem_9.py b/problem_9.py index 126b416..62f4ecf 100644 --- a/problem_9.py +++ b/problem_9.py @@ -4,6 +4,25 @@ import math from typing import Optional +""" +Why does non_sink & in_window also work for Phase 2, +even though this implementation uses non_sink & causal? +Because the window is usually wide enough to cover every causal +position in the diagonal block. + +Naming convention (differs from problem 7) +Kc is [D,N] - block of key vectors arranged as columns +Vt is [N,D] - block of value vectors, laid out for the P @ V matmul + +BATCH_SIZE is passed to the backward kernel but not used; +it could be used for bounds checking + +Phase 0: sink_cols & causal +Phase 1: non_sink & window_mask & pre_diag_mask +Phase 2: non_sink & causal +""" + + @triton.jit def _flash_attention_forward_swa_kernel( # Pointers to Tensors @@ -26,7 +45,150 @@ def _flash_attention_forward_swa_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - pass + # Program ids + q_block_idx = tl.program_id(axis=0) + batch_head_id = tl.program_id(axis=1) + batch_idx = batch_head_id // N_Q_HEADS + q_head_idx = batch_head_id % N_Q_HEADS + + # GQA mapping + group_size = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // group_size + + # Accumulators + m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # Load Q + q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) + row_mask = q_offsets < SEQ_LEN + cols = tl.arange(0, HEAD_DIM) + + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * q_stride_h + \ + (q_offsets[:, None] * q_stride_s + cols[None, :]) + q_block = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + + qk_scale = softmax_scale * 1.44269504 + + # Sliding window + q_start = q_block_idx * BLOCK_M + win_left = q_start - (WINDOW_SIZE - 1) + window_start = tl.maximum(0, win_left) + diag_start = q_block_idx * BLOCK_M + + # Phase 0: Sink Tiles + for start_n in range(0, SINK_SIZE, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[None, :]*k_stride_s + cols[:, None]) + V_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None]*v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Vt = tl.load(V_ptrs, 
mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(q_block, Kc) * qk_scale + causal = q_offsets[:, None] >= k_offsets[None, :] + sink_cols = k_offsets[None, :] < SINK_SIZE + #valid = row_mask[:, None] & kv_valid[None, :] + mask = sink_cols & causal + + S = tl.where(mask, S, -float('inf')) + + row_has = tl.max(mask, axis=1) > 0 + m_ij = tl.max(S, axis=1) + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + alpha = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + P = tl.where(row_has[:, None], tl.exp2(S - m_new[:, None]), 0.0) + + acc = acc * alpha[:, None] + tl.dot(P, Vt) + l_i = l_i * alpha + tl.sum(P, axis=1) + m_i = m_new + + # Phase 1: Off-Diagonal Tiles + for start_n in range(window_start, diag_start, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[None, :]*k_stride_s + cols[:, None]) + V_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None]*v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Vt = tl.load(V_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(q_block, Kc) * qk_scale + + dist = q_offsets[:, None] - k_offsets[None, :] + in_window = (dist >= 0) & (dist < WINDOW_SIZE) + pre_diag = k_offsets[None, :] < diag_start + non_sink = k_offsets[None, :] >= SINK_SIZE + # valid = row_mask[:,None] & kv_valid[None,:] + mask = pre_diag & non_sink & in_window + + S = tl.where(mask, S, -float('inf')) + + row_has = tl.max(mask, axis=1) > 0 + m_ij = tl.max(S, axis=1) + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + alpha = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + P = tl.where(row_has[:, None], tl.exp2(S - m_new[:, None]), 0.0) + + acc = acc*alpha[:, None] + tl.dot(P, Vt) + l_i = l_i*alpha + tl.sum(P, axis=1) + m_i = m_new + + # Phase 2: Diagonal Tiles + for start_n in range(diag_start, (q_block_idx+1)*BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[None, :]*k_stride_s + cols[:, None]) + V_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None]*v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Vt = tl.load(V_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(q_block, Kc) * qk_scale + + # dist = q_offsets[:,None] - k_offsets[None,:] + # in_window = (dist >= 0) & (dist < WINDOW_SIZE) + # valid = row_mask[:,None] & kv_valid[None,:] + non_sink = k_offsets[None, :] >= SINK_SIZE + causal = q_offsets[:, None] >= k_offsets[None, :] + mask = non_sink & causal + + S = tl.where(mask, S, -float('inf')) + + row_has = tl.max(mask, axis=1) > 0 + m_ij = tl.max(S, axis=1) + m_new = tl.where(row_has, tl.maximum(m_i, m_ij), m_i) + alpha = tl.where(row_has, tl.exp2(m_i - m_new), 1.0) + P = tl.where(row_has[:, None], tl.exp2(S - m_new[:, None]), 0.0) + + acc = acc * alpha[:, None] + tl.dot(P, Vt) + l_i = l_i * alpha + tl.sum(P, axis=1) + m_i = m_new + + # Normalize + l_i_safe = tl.where(l_i == 0, 1.0, l_i) + O = acc / l_i_safe[:, None] + o_ptrs = O_ptr + batch_idx*o_stride_b + q_head_idx * \ + o_stride_h + (q_offsets[:, None] * o_stride_s + cols[None, :]) + tl.store(o_ptrs, O.to(O_ptr.dtype.element_ty), mask=row_mask[:, None]) + + # Store per-row log2sumexp in M + lse_log2 = m_i + tl.log2(l_i + 1e-6) + m_ptrs 
= M_ptr + batch_idx * m_stride_b + \ + q_head_idx * m_stride_h + q_offsets*m_stride_s + tl.store(m_ptrs, lse_log2, mask=row_mask) + @triton.jit def _flash_attention_backward_swa_kernel( @@ -56,7 +218,173 @@ def _flash_attention_backward_swa_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - pass + # 1) Ids + q_block_idx = tl.program_id(axis=0) + batch_head_idx = tl.program_id(axis=1) + batch_idx = batch_head_idx // N_Q_HEADS + q_head_idx = batch_head_idx % N_Q_HEADS + + # 2) GQA mapping + group_size = N_Q_HEADS // N_KV_HEADS + kv_head_idx = q_head_idx // group_size + + # 3) Indices + q_offsets = q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M) + row_mask = q_offsets < SEQ_LEN + cols = tl.arange(0, HEAD_DIM) + + # 4) load blocks + q_ptrs = Q_ptr + batch_idx * q_stride_b + q_head_idx * \ + q_stride_h + (q_offsets[:, None] * q_stride_s + cols[None, :]) + do_ptrs = dO_ptr + batch_idx * do_stride_b + q_head_idx * \ + do_stride_h + (q_offsets[:, None] * do_stride_s + cols[None, :]) + m_ptrs = M_ptr + batch_idx * m_stride_b + \ + q_head_idx * m_stride_h + q_offsets * m_stride_s + d_ptrs = D_ptr + batch_idx * d_stride_b + \ + q_head_idx * d_stride_h + q_offsets * d_stride_s + + Q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + dO = tl.load(do_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32) + LSE = tl.load(m_ptrs, mask=row_mask, other=-float('inf')).to(tl.float32) + delta = tl.load(d_ptrs, mask=row_mask, other=0.0).to(tl.float32) + + qk_scale = softmax_scale * 1.44269504 + dQ_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # Sliding window + q_start = q_block_idx * BLOCK_M + win_left = q_start - (WINDOW_SIZE - 1) + window_start = tl.maximum(0, win_left) + diag_start = q_block_idx * BLOCK_M + + # Phase 0 + for start_n in range(0, SINK_SIZE, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[None, :]*k_stride_s + cols[:, None]) + Kr_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[:, None]*k_stride_s + cols[None, :]) + Vr_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None]*v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Kr = tl.load(Kr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + Vr = tl.load(Vr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(Q, Kc) * qk_scale + causal = q_offsets[:, None] >= k_offsets[None, :] + sinkcol = k_offsets[None, :] < SINK_SIZE + # valid = row_mask[:,None] & kv_valid[None,:] + mask = sinkcol & causal + + S = tl.where(mask, S, -float('inf')) + P = tl.where(mask, tl.exp2(S - LSE[:, None]), 0.0) + + dV = tl.dot(P.T, dO) + dp = tl.dot(dO, Vr.T) + dS = (dp - delta[:, None]) * P + + dQ_acc += tl.dot(dS, Kr) * softmax_scale + dK = tl.dot(dS.T, Q) * softmax_scale + + dk_ptrs = dK_ptr + batch_idx * dk_stride_b + kv_head_idx * \ + dk_stride_h + (k_offsets[:, None]*dk_stride_s + cols[None, :]) + dv_ptrs = dV_ptr + batch_idx * dv_stride_b + kv_head_idx * \ + dv_stride_h + (k_offsets[:, None]*dv_stride_s + cols[None, :]) + tl.atomic_add(dk_ptrs, dK, mask=kv_valid[:, None]) + tl.atomic_add(dv_ptrs, dV, mask=kv_valid[:, None]) + + # Phase 1: Off-Diagonal Tiles + for start_n in range(window_start, diag_start, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + 
k_stride_h + (k_offsets[None, :] * k_stride_s + cols[:, None]) + Kr_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[:, None] * k_stride_s + cols[None, :]) + Vr_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None] * v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Kr = tl.load(Kr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + Vr = tl.load(Vr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(Q, Kc) * qk_scale + + dist = q_offsets[:, None] - k_offsets[None, :] + in_window = (dist >= 0) & (dist < WINDOW_SIZE) + pre_diag = k_offsets[None, :] < diag_start + non_sink = k_offsets[None, :] >= SINK_SIZE + # valid = row_mask[:,None] & kv_valid[None,:] + mask = pre_diag & non_sink & in_window + + S = tl.where(mask, S, -float('inf')) + P = tl.where(mask, tl.exp2(S - LSE[:, None]), 0.0) + + dV = tl.dot(P.T, dO) + dp = tl.dot(dO, Vr.T) + dS = (dp - delta[:, None]) * P + + dQ_acc += tl.dot(dS, Kr) * softmax_scale + dK = tl.dot(dS.T, Q) * softmax_scale + + dk_ptrs = dK_ptr + batch_idx*dk_stride_b + kv_head_idx * \ + dk_stride_h + (k_offsets[:, None] * dk_stride_s + cols[None, :]) + dv_ptrs = dV_ptr + batch_idx*dv_stride_b + kv_head_idx * \ + dv_stride_h + (k_offsets[:, None] * dv_stride_s + cols[None, :]) + tl.atomic_add(dk_ptrs, dK, mask=kv_valid[:, None]) + tl.atomic_add(dv_ptrs, dV, mask=kv_valid[:, None]) + + # Phase 2: Diagonal Tiles + for start_n in range(diag_start, (q_block_idx+1)*BLOCK_M, BLOCK_N): + k_offsets = start_n + tl.arange(0, BLOCK_N) + kv_valid = k_offsets < SEQ_LEN + + Kc_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[None, :] * k_stride_s + cols[:, None]) + Kr_ptrs = K_ptr + batch_idx*k_stride_b + kv_head_idx * \ + k_stride_h + (k_offsets[:, None] * k_stride_s + cols[None, :]) + Vr_ptrs = V_ptr + batch_idx*v_stride_b + kv_head_idx * \ + v_stride_h + (k_offsets[:, None] * v_stride_s + cols[None, :]) + + Kc = tl.load(Kc_ptrs, mask=kv_valid[None, :], other=0.0).to(tl.float32) + Kr = tl.load(Kr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + Vr = tl.load(Vr_ptrs, mask=kv_valid[:, None], other=0.0).to(tl.float32) + + S = tl.dot(Q, Kc) * qk_scale + + # dist = q_offsets[:,None] - k_offsets[None,:] + # in_window = (dist >= 0) & (dist < WINDOW_SIZE) + # valid = row_mask[:,None] & kv_valid[None,:] + non_sink = k_offsets[None, :] >= SINK_SIZE + causal = q_offsets[:, None] >= k_offsets[None, :] + mask = non_sink & causal + + S = tl.where(mask, S, -float('inf')) + P = tl.where(mask, tl.exp2(S - LSE[:, None]), 0.0) + + dV = tl.dot(P.T, dO) + dp = tl.dot(dO, Vr.T) + dS = (dp - delta[:, None]) * P + + dQ_acc += tl.dot(dS, Kr) * softmax_scale + dK = tl.dot(dS.T, Q) * softmax_scale + + dk_ptrs = dK_ptr + batch_idx * dk_stride_b + kv_head_idx * \ + dk_stride_h + (k_offsets[:, None] * dk_stride_s + cols[None, :]) + dv_ptrs = dV_ptr + batch_idx * dv_stride_b + kv_head_idx * \ + dv_stride_h + (k_offsets[:, None] * dv_stride_s + cols[None, :]) + tl.atomic_add(dk_ptrs, dK, mask=kv_valid[:, None]) + tl.atomic_add(dv_ptrs, dV, mask=kv_valid[:, None]) + + # 5) store dQ + dq_ptrs = dQ_ptr + batch_idx * dq_stride_b + q_head_idx * \ + dq_stride_h + (q_offsets[:, None] * dq_stride_s + cols[None, :]) + tl.store(dq_ptrs, dQ_acc.to(dQ_ptr.dtype.element_ty), mask=row_mask[:, None]) + class FlashSWDAWithSink(torch.autograd.Function): @staticmethod @@ -75,8 +403,8 @@ def forward(ctx, q, k, v, window_size, sink_size, 
is_causal=True, softmax_scale= assert n_q_heads % n_kv_heads == 0, "Number of query heads must be divisible by number of K/V heads" o = torch.empty_like(q) - M = torch.empty((batch, n_q_heads, seq_len), device=q.device, dtype=torch.float32) - + M = torch.empty((batch, n_q_heads, seq_len), + device=q.device, dtype=torch.float32) BLOCK_M, BLOCK_N = 128, 64 grid = (math.ceil(seq_len / BLOCK_M), batch * n_q_heads) @@ -118,10 +446,45 @@ def backward(ctx, do): dq = torch.empty_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) - + # TODO: Add your backward kernel here + dk_acc = torch.zeros_like(k, dtype=torch.float32) + dv_acc = torch.zeros_like(v, dtype=torch.float32) + + D = (do.float() * o.float()).sum(dim=-1) + + BLOCK_M, BLOCK_N = 128, 64 + grid = (triton.cdiv(seq_len, BLOCK_M), batch * n_q_heads) + + _flash_attention_backward_swa_kernel[grid]( + q, k, v, do, M, D, + dq, dk_acc, dv_acc, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + do.stride(0), do.stride(1), do.stride(2), + M.stride(0), M.stride(1), M.stride(2), + D.stride(0), D.stride(1), D.stride(2), + dq.stride(0), dq.stride(1), dq.stride(2), + dk_acc.stride(0), dk_acc.stride(1), dk_acc.stride(2), + dv_acc.stride(0), dv_acc.stride(1), dv_acc.stride(2), + softmax_scale, + batch, + n_q_heads, + n_kv_heads, + seq_len, + WINDOW_SIZE=window_size, + SINK_SIZE=sink_size, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + dk.copy_(dk_acc) + dv.copy_(dv_acc) return dq, dk.to(k.dtype), dv.to(v.dtype), None, None, None, None - + + def flash_swda_with_sink(q, k, v, window_size: int, sink_size: int = 0, is_causal: bool = True, scale: Optional[float] = None): - return FlashSWDAWithSink.apply(q, k, v, window_size, sink_size, is_causal, scale) \ No newline at end of file + return FlashSWDAWithSink.apply(q, k, v, window_size, sink_size, is_causal, scale)
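A minimal end-to-end check of the problem_9 entry point (a sketch under assumptions, not part of the assignment files): it assumes the (batch, heads, seq_len, head_dim) layout implied by the kernels' stride arithmetic, a CUDA device with Triton available, and window_size >= BLOCK_M = 128 so the diagonal-phase causal mask coincides with the window mask. The reference builds the sink/window/causal mask explicitly in O(S^2).

import math
import torch
from problem_9 import flash_swda_with_sink  # assumed import path

def reference_swda_with_sink(q, k, v, window_size, sink_size, scale=None):
    # Naive reference: repeat KV heads for GQA, build the full mask, softmax in fp32.
    b, hq, s, d = q.shape
    hkv = k.shape[1]
    k = k.repeat_interleave(hq // hkv, dim=1)
    v = v.repeat_interleave(hq // hkv, dim=1)
    scale = scale if scale is not None else 1.0 / math.sqrt(d)
    scores = (q.float() @ k.float().transpose(-1, -2)) * scale
    idx = torch.arange(s, device=q.device)
    dist = idx[:, None] - idx[None, :]          # q_idx - k_idx
    causal = dist >= 0
    allowed = causal & ((dist < window_size) | (idx[None, :] < sink_size))
    scores = scores.masked_fill(~allowed, float("-inf"))
    return (torch.softmax(scores, dim=-1) @ v.float()).to(q.dtype)

if __name__ == "__main__":
    torch.manual_seed(0)
    opts = dict(device="cuda", dtype=torch.bfloat16)
    q = torch.randn(2, 8, 256, 64, **opts)
    k = torch.randn(2, 2, 256, 64, **opts)
    v = torch.randn(2, 2, 256, 64, **opts)
    out = flash_swda_with_sink(q, k, v, window_size=128, sink_size=4)
    ref = reference_swda_with_sink(q, k, v, window_size=128, sink_size=4)
    print("max abs diff:", (out.float() - ref.float()).abs().max().item())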