34 changes: 20 additions & 14 deletions problem_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class FlashAttention2Function(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, K, V, is_causal=False):
# Get dimensions from input tensors following the (B, H, N, D) convention
B, H, N_Q, D_H = Q.shape
_, _, N_K, _ = K.shape
B, H, N_Q, D_H = Q.shape # N_Q: number of query tokens, D_H: per-head dimension
_, _, N_K, _ = K.shape # N_K: number of key tokens

# Define tile sizes
Q_TILE_SIZE = 128
Expand Down Expand Up @@ -41,9 +41,9 @@ def forward(ctx, Q, K, V, is_causal=False):
Q_tile = Q_bh[q_start:q_end, :]

# Initialize accumulators for this query tile
o_i = torch.zeros_like(Q_tile, dtype=Q.dtype)
l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32)
m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32)
o_i = torch.zeros_like(Q_tile, dtype=Q.dtype) # running (unnormalized) weighted output for this query tile
l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32) # running sum of exponentials (softmax normalizer)
m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32) # running row max, subtracted before exp() for numerical stability

# Inner loop over key/value tiles
for j in range(N_K_tiles):
Expand All @@ -53,26 +53,32 @@ def forward(ctx, Q, K, V, is_causal=False):
K_tile = K_bh[k_start:k_end, :]
V_tile = V_bh[k_start:k_end, :]

S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale

S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale
# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# 1. Apply causal masking if is_causal is True.
#
if is_causal:
q_idx = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1) # (Tq, 1) [MAIN FIX] q_idx must be on the same device as S_ij (CUDA)
k_idx = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0) # (1, Tk) [MAIN FIX] k_idx must be on the same device as S_ij (CUDA)
S_ij = S_ij.masked_fill(k_idx > q_idx, -float('inf')) # allow only k <= q [MAIN FIX] the mask condition was k_idx < q_idx earlier
# 2. Compute the new running maximum
#
m_ij = torch.max(S_ij, dim = -1).values.to(torch.float32)
m_new = torch.maximum(m_i, m_ij)
# 3. Rescale the previous accumulators (o_i, l_i)
#
scale_factor = torch.exp(m_i - m_new)
o_i = o_i * scale_factor.unsqueeze(-1).to(o_i.dtype)
l_i = l_i * scale_factor
# 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new).
#
P_tilde_ij = torch.exp(S_ij.to(torch.float32) - m_new.unsqueeze(-1)) # [MAIN FIX] m_new must be unsqueezed so it broadcasts over the key dimension
# 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i
#
l_i = l_i + torch.sum(P_tilde_ij, dim=-1)
o_i = o_i + (P_tilde_ij @ V_tile.to(torch.float32)).to(o_i.dtype)
# 6. Update the running max for the next iteration

m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# After iterating through all key tiles, normalize the output
# This part is provided for you. It handles the final division safely.
l_i_reciprocal = torch.where(l_i > 0, 1.0 / l_i, 0)
l_i_reciprocal = torch.where(l_i > 0, 1.0 / l_i, 0)
o_i_normalized = o_i * l_i_reciprocal.unsqueeze(-1)

L_tile = m_i + torch.log(l_i)
Expand Down
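For reference, the loop above is the standard online-softmax recurrence. Below is a minimal, self-contained PyTorch sketch (illustrative sizes, independent of problem_1.py) that checks the tiled recurrence against a direct softmax on a single row of scores:

import torch

# Minimal check of the online-softmax recurrence used above, on a single row of
# scores (illustrative sizes; not part of problem_1.py).
torch.manual_seed(0)
s = torch.randn(1024)                   # one row of attention scores
v = torch.randn(1024, 64)               # matching value vectors

m = torch.tensor(-float('inf'))         # running max
l = torch.tensor(0.0)                   # running sum of exponentials
o = torch.zeros(64)                     # running unnormalized output
for start in range(0, s.numel(), 128):  # process the row in tiles of 128
    s_t, v_t = s[start:start + 128], v[start:start + 128]
    m_new = torch.maximum(m, s_t.max())
    alpha = torch.exp(m - m_new)        # rescale the old accumulators
    p = torch.exp(s_t - m_new)
    l = l * alpha + p.sum()
    o = o * alpha + p @ v_t
    m = m_new

o_ref = torch.softmax(s, dim=0) @ v     # direct (untiled) reference
assert torch.allclose(o / l, o_ref, atol=1e-4)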
26 changes: 13 additions & 13 deletions problem_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,47 @@ def weighted_row_sum_kernel(
"""
# 1. Get the row index for the current program instance.
# Hint: Use tl.program_id(axis=0).
row_idx = ...
row_idx = tl.program_id(axis = 0)

# 2. Create a pointer to the start of the current row in the input tensor X.
# Hint: The offset depends on the row index and the number of columns (N_COLS).
row_start_ptr = ...
row_start_ptr = X_ptr + row_idx * N_COLS

# 3. Create a pointer for the output vector Y.
output_ptr = ...
output_ptr = Y_ptr + row_idx

# 4. Initialize an accumulator for the sum of the products for a block.
# This should be a block-sized tensor of zeros.
# Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32.
accumulator = ...
accumulator = tl.zeros((BLOCK_SIZE,), dtype = tl.float32)

# 5. Iterate over the columns of the row in blocks of BLOCK_SIZE.
# Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE).
for col_block_start in range(0, ...):
for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)):
# - Calculate the offsets for the current block of columns.
# Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE).
col_offsets = ...
col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# - Create a mask to prevent out-of-bounds memory access for the last block.
# Hint: Compare col_offsets with N_COLS.
mask = ...
mask = col_offsets < N_COLS

# - Load a block of data from X and W safely using the mask.
# Hint: Use tl.load with the appropriate pointers, offsets, and mask.
# Use `other=0.0` to handle out-of-bounds elements.
x_chunk = tl.load(...)
w_chunk = tl.load(...)
x_chunk = tl.load(row_start_ptr + col_offsets, mask = mask, other = 0.0)
w_chunk = tl.load(W_ptr + col_offsets, mask = mask, other = 0.0)

# - Compute the element-wise product and add it to the accumulator.
accumulator += ...
accumulator += x_chunk * w_chunk

# 6. Reduce the block-sized accumulator to a single scalar value after the loop.
# Hint: Use tl.sum().
final_sum = ...
final_sum = tl.sum(accumulator, axis = 0)

# 7. Store the final accumulated sum to the output tensor Y.
# Hint: Use tl.store().
...
tl.store(output_ptr, final_sum)

# --- END OF STUDENT IMPLEMENTATION ---

Expand Down
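For intuition, the kernel above computes a per-row weighted sum, Y[i] = sum_j X[i, j] * W[j]. A minimal PyTorch sketch of the same math (illustrative shapes; this does not call the actual launcher in problem_2.py):

import torch

# Illustrative reference for what weighted_row_sum_kernel computes:
# for each row i, Y[i] = sum_j X[i, j] * W[j].
X = torch.randn(4, 1000)
W = torch.randn(1000)
y_ref = (X * W[None, :]).sum(dim=1)             # broadcast W across rows, then reduce
assert torch.allclose(y_ref, X @ W, atol=1e-4)  # equivalently a matrix-vector product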
19 changes: 15 additions & 4 deletions problem_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def _flash_attention_forward_kernel(
q_offsets = (q_block_idx * BLOCK_M + tl.arange(0, BLOCK_M))
q_ptrs = Q_ptr + batch_idx * q_stride_b + head_idx * q_stride_h + \
(q_offsets[:, None] * q_stride_s + tl.arange(0, HEAD_DIM)[None, :])
q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)

# Load this query tile: BLOCK_M rows of Q for the current (batch, head); the mask keeps the load in bounds when the last block runs past SEQ_LEN.
q_block = tl.load(q_ptrs, mask=q_offsets[:, None] < SEQ_LEN, other=0.0)

# PyTorch softmax is exp(x), Triton is exp2(x * log2(e)), log2(e) is approx 1.44269504
qk_scale = softmax_scale * 1.44269504

Expand All @@ -61,15 +62,25 @@ def _flash_attention_forward_kernel(
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic.
# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# 1. Find the new running maximum (`m_new`).
m_ij = tl.max(s_ij, axis=1)
m_new = tl.maximum(m_i, m_ij)
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
# Use tl.exp2 for Triton's online softmax formulation.
scale_factor = tl.exp2(m_i - m_new)
acc *= scale_factor[:, None]
l_i *= scale_factor
# 3. Compute the attention probabilities for the current tile (`p_ij`).
p_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
acc += tl.dot(p_ij.to(v_block.type), v_block)
# 5. Update the denominator `l_i`.
l_i += tl.sum(p_ij, axis=1)
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new

# --- END OF STUDENT IMPLEMENTATION ---


Expand Down
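A side note on the exp2 formulation used in this kernel: the qk_scale comment relies on the identity exp(x) = exp2(x * log2(e)). A tiny self-contained check (illustrative only):

import math
import torch

# exp(x) == exp2(x * log2(e)); folding log2(e) into the softmax scale lets the
# kernel call tl.exp2, which maps to a cheaper hardware instruction than exp.
x = torch.randn(8)
log2_e = math.log2(math.e)  # ~1.44269504
assert torch.allclose(torch.exp(x), torch.exp2(x * log2_e), atol=1e-6)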
43 changes: 43 additions & 0 deletions problem_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,28 @@ def _flash_attention_forward_causal_kernel(
# Implement the logic for the off-diagonal blocks.
# This is very similar to the non-causal version from Problem 3.
# 1. Load the K and V blocks for the current iteration.
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask = k_offsets[None, :] < SEQ_LEN , other = 0.0)

v_offsets = start_n + tl.arange(0, BLOCK_N)
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask = v_offsets[:, None] < SEQ_LEN, other = 0.0)

# 2. Compute the attention scores (S_ij).
S_ij = tl.dot(q_block, k_block)
# No causal mask is needed in this loop: it only visits key blocks that end at or before
# the start of the query block, so every key position here is visible to every query row.
# The block that overlaps the diagonal is handled by the masked loop below.
S_ij *= qk_scale
# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
scale_factor = tl.exp2(m_i - m_ij)
P_ij = tl.exp2(S_ij - m_ij[:, None])
l_i = l_i * scale_factor + tl.sum(P_ij, 1)
acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)

m_i = m_ij
# --- END OF STUDENT IMPLEMENTATION ---

Expand All @@ -64,6 +84,29 @@ def _flash_attention_forward_causal_kernel(
for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the logic for the diagonal blocks, apply the causal mask to S_ij.
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask = k_offsets[None, :] < SEQ_LEN , other = 0.0)

v_offsets = start_n + tl.arange(0, BLOCK_N)
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask = v_offsets[:, None] < SEQ_LEN, other = 0.0)

# 2. Compute the attention scores (S_ij).
S_ij = tl.dot(q_block, k_block)
mask = k_offsets[None, :] <= q_offsets[:, None] # causal mask: query at position q may only attend to keys at positions k <= q
S_ij = tl.where(mask, S_ij, -float('inf'))
S_ij *= qk_scale
# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
scale_factor = tl.exp2(m_i - m_ij) # scores are already in the log2 domain (qk_scale folds in log2(e)), so no extra 1.44269504 factor is needed
P_ij = tl.exp2(S_ij - m_ij[:, None])
l_i = l_i * scale_factor + tl.sum(P_ij, 1)
acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)

m_i = m_ij
# --- END OF STUDENT IMPLEMENTATION ---

Expand Down
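To illustrate why the kernel splits the key loop into an unmasked off-diagonal pass and a masked diagonal pass, here is a small sketch with assumed tile sizes (BLOCK_M = 4, query block starting at row 4), independent of problem_4.py:

import torch

# Illustrative only: for a query block covering rows [4, 8), key blocks that end at or
# before row 4 are fully visible (handled by the unmasked loop), while the key block
# overlapping the diagonal needs the per-element mask k <= q.
BLOCK_M, q_start = 4, 4
q_idx = torch.arange(q_start, q_start + BLOCK_M)       # query positions 4..7
early_keys = torch.arange(0, q_start)                  # keys 0..3: always visible
diag_keys = torch.arange(q_start, q_start + BLOCK_M)   # keys 4..7: need masking
assert (early_keys[None, :] <= q_idx[:, None]).all()
print((diag_keys[None, :] <= q_idx[:, None]).int())    # lower-triangular pattern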
48 changes: 44 additions & 4 deletions problem_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def _flash_attention_forward_gqa_kernel(
# Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx).
# 1. Calculate how many query heads are in each group.
# 2. Use integer division to find the correct kv_head_idx.

kv_head_idx = 0 # Placeholder: Replace with your calculation
group_size = N_Q_HEADS // N_KV_HEADS # number of query heads sharing each K/V head
kv_head_idx = q_head_idx // group_size # map this query head to its shared K/V head
# --- END OF STUDENT IMPLEMENTATION ---
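As a concrete example of the mapping above (illustrative numbers, not part of the kernel): with N_Q_HEADS = 8 and N_KV_HEADS = 2, group_size is 4, so query heads 0-3 read K/V head 0 and query heads 4-7 read K/V head 1.

# Illustrative check of the query-head -> K/V-head mapping (plain Python, not Triton):
N_Q_HEADS, N_KV_HEADS = 8, 2
group_size = N_Q_HEADS // N_KV_HEADS                # 4 query heads per K/V head
mapping = [q // group_size for q in range(N_Q_HEADS)]
print(mapping)                                      # [0, 0, 0, 0, 1, 1, 1, 1]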


Expand All @@ -57,8 +57,28 @@ def _flash_attention_forward_gqa_kernel(
for start_n in range(0, q_block_idx * BLOCK_M, BLOCK_N):
# --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 2) ---
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask = k_offsets[None, :] < SEQ_LEN , other = 0.0)

v_offsets = start_n + tl.arange(0, BLOCK_N)
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask = v_offsets[:, None] < SEQ_LEN, other = 0.0)

# 2. Reuse your working implementation for the online softmax update
# from your solution to Problem 4.
S_ij = tl.dot(q_block, k_block)
# As in Problem 4, no causal mask is needed in this off-diagonal loop: every key block
# visited here ends at or before the start of the query block.
S_ij *= qk_scale
# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
scale_factor = tl.exp2(m_i - m_ij)
P_ij = tl.exp2(S_ij - m_ij[:, None])
l_i = l_i * scale_factor + tl.sum(P_ij, 1)
acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)

m_i = m_ij
# --- END OF STUDENT IMPLEMENTATION ---

Expand All @@ -67,8 +87,28 @@ def _flash_attention_forward_gqa_kernel(
for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
# --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 3) ---
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask = k_offsets[None, :] < SEQ_LEN , other = 0.0)

v_offsets = start_n + tl.arange(0, BLOCK_N)
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(v_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask = v_offsets[:, None] < SEQ_LEN, other = 0.0)
# 2. Reuse your working implementation for the masked online softmax
# update from your solution to Problem 4.
S_ij = tl.dot(q_block, k_block)
mask = k_offsets[None, :] <= q_offsets[:, None] # causal mask: query at position q may only attend to keys at positions k <= q
S_ij = tl.where(mask, S_ij, -float('inf'))
S_ij *= qk_scale
# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
m_ij = tl.maximum(m_i, tl.max(S_ij, 1))
scale_factor = tl.exp2(m_i - m_ij) # scores are already in the log2 domain (qk_scale folds in log2(e)), so no extra 1.44269504 factor is needed
P_ij = tl.exp2(S_ij - m_ij[:, None])
l_i = l_i * scale_factor + tl.sum(P_ij, 1)
acc = acc * scale_factor[:, None] + tl.dot(P_ij.to(v_block.type), v_block)

m_i = m_ij
# --- END OF STUDENT IMPLEMENTATION ---

Expand Down
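Finally, a self-contained way to sanity-check GQA semantics against stock PyTorch (shapes below are assumptions for illustration; this does not call the kernels in this PR): grouped-query attention with shared K/V heads is equivalent to standard multi-head attention once each K/V head is repeated group_size times.

import torch
import torch.nn.functional as F

# Illustrative GQA reference: repeat each shared K/V head group_size times and
# run ordinary scaled-dot-product attention.
B, N_Q_HEADS, N_KV_HEADS, SEQ_LEN, HEAD_DIM = 2, 8, 2, 16, 32
group_size = N_Q_HEADS // N_KV_HEADS
q = torch.randn(B, N_Q_HEADS, SEQ_LEN, HEAD_DIM)
k = torch.randn(B, N_KV_HEADS, SEQ_LEN, HEAD_DIM)
v = torch.randn(B, N_KV_HEADS, SEQ_LEN, HEAD_DIM)
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(group_size, dim=1),
    v.repeat_interleave(group_size, dim=1),
    is_causal=True,
)
print(out.shape)  # torch.Size([2, 8, 16, 32])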