42 changes: 28 additions & 14 deletions problem_1.py
@@ -22,7 +22,7 @@ def forward(ctx, Q, K, V, is_causal=False):
N_K_tiles = math.ceil(N_K / K_TILE_SIZE)

# Initialize final output tensors
O_final = torch.zeros_like(Q, dtype=Q.dtype)
O_final = torch.zeros_like(Q, dtype=torch.float32)
L_final = torch.zeros((B, H, N_Q), device=Q.device, dtype=torch.float32)

scale = 1.0 / math.sqrt(D_H)
@@ -38,10 +38,10 @@ def forward(ctx, Q, K, V, is_causal=False):
for i in range(N_Q_tiles):
q_start = i * Q_TILE_SIZE
q_end = min((i + 1) * Q_TILE_SIZE, N_Q)
Q_tile = Q_bh[q_start:q_end, :]
Q_tile = Q_bh[q_start:q_end, :].to(torch.float32)

# Initialize accumulators for this query tile
o_i = torch.zeros_like(Q_tile, dtype=Q.dtype)
o_i = torch.zeros_like(Q_tile, dtype=torch.float32)
l_i = torch.zeros(q_end - q_start, device=Q.device, dtype=torch.float32)
m_i = torch.full((q_end - q_start,), -float('inf'), device=Q.device, dtype=torch.float32)

@@ -50,24 +50,38 @@ def forward(ctx, Q, K, V, is_causal=False):
k_start = j * K_TILE_SIZE
k_end = min((j + 1) * K_TILE_SIZE, N_K)

K_tile = K_bh[k_start:k_end, :]
V_tile = V_bh[k_start:k_end, :]
K_tile = K_bh[k_start:k_end, :].to(torch.float32)
V_tile = V_bh[k_start:k_end, :].to(torch.float32)

S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# 1. Apply causal masking if is_causal is True.
#
if is_causal:
# Build a causal mask for the current q/k tile based on absolute positions.
# Mask positions where key_index > query_index (i.e., future positions).
q_len = q_end - q_start
k_len = k_end - k_start
q_positions = torch.arange(q_start, q_end, device=Q.device).unsqueeze(1) # (q_len, 1)
k_positions = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0) # (1, k_len)
causal_mask = k_positions > q_positions # (q_len, k_len); True marks future keys that must be masked
S_ij = S_ij.masked_fill(causal_mask, float('-inf'))

# 2. Compute the new running maximum
#
m_new = torch.max(m_i, torch.max(S_ij, dim=1).values)

# 3. Rescale the previous accumulators (o_i, l_i)
#
mult = torch.exp(m_i - m_new)  # shape: (q_len,)
l_i_rescaled = l_i * mult
o_i_rescaled = o_i * mult.unsqueeze(-1)
# 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new).
#
P_tilde_ij = torch.exp(S_ij - m_new.unsqueeze(1)).to(torch.float32)
# 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i
#
l_new = l_i_rescaled + (P_tilde_ij.sum(dim=1))
o_i = o_i_rescaled + (P_tilde_ij @ V_tile)
# 6. Update the running max for the next iteration

m_i = m_new
l_i = l_new
# --- END OF STUDENT IMPLEMENTATION ---

# After iterating through all key tiles, normalize the output
@@ -82,7 +96,7 @@ def forward(ctx, Q, K, V, is_causal=False):
L_final[b, h, q_start:q_end] = L_tile

O_final = O_final.to(Q.dtype)

ctx.save_for_backward(Q, K, V, O_final, L_final)
ctx.is_causal = is_causal

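A quick way to sanity-check the tiled forward pass above is a naive O(N^2) softmax attention reference. The sketch below assumes small shapes; the autograd Function's public name is not visible in this diff, so `FlashAttentionFunction` is a hypothetical stand-in and the comparison call is left commented out.

import math
import torch

def naive_attention(Q, K, V, is_causal=False):
    # Reference softmax attention in fp32, matching the 1/sqrt(D_H) scaling used in problem_1.py.
    scale = 1.0 / math.sqrt(Q.shape[-1])
    S = (Q.float() @ K.float().transpose(-1, -2)) * scale
    if is_causal:
        n_q, n_k = S.shape[-2], S.shape[-1]
        future = torch.triu(torch.ones(n_q, n_k, dtype=torch.bool, device=S.device), diagonal=1)
        S = S.masked_fill(future, float('-inf'))
    return torch.softmax(S, dim=-1) @ V.float()

B, H, N, D = 2, 4, 128, 64
Q, K, V = (torch.randn(B, H, N, D) for _ in range(3))
O_ref = naive_attention(Q, K, V, is_causal=True)
# O_tiled = FlashAttentionFunction.apply(Q, K, V, True)   # hypothetical name; use the actual Function from problem_1.py
# torch.testing.assert_close(O_tiled.float(), O_ref, rtol=1e-3, atol=1e-3)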
24 changes: 12 additions & 12 deletions problem_2.py
@@ -16,47 +16,47 @@ def weighted_row_sum_kernel(
"""
# 1. Get the row index for the current program instance.
# Hint: Use tl.program_id(axis=0).
row_idx = ...
row_idx = tl.program_id(axis=0)

# 2. Create a pointer to the start of the current row in the input tensor X.
# Hint: The offset depends on the row index and the number of columns (N_COLS).
row_start_ptr = ...
row_start_ptr = X_ptr + row_idx * N_COLS

# 3. Create a pointer for the output vector Y.
output_ptr = ...
output_ptr = Y_ptr + row_idx

# 4. Initialize an accumulator for the sum of the products for a block.
# This should be a block-sized tensor of zeros.
# Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32.
accumulator = ...
accumulator = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

# 5. Iterate over the columns of the row in blocks of BLOCK_SIZE.
# Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE).
for col_block_start in range(0, ...):
for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)):
# - Calculate the offsets for the current block of columns.
# Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE).
col_offsets = ...
col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# - Create a mask to prevent out-of-bounds memory access for the last block.
# Hint: Compare col_offsets with N_COLS.
mask = ...
mask = col_offsets < N_COLS

# - Load a block of data from X and W safely using the mask.
# Hint: Use tl.load with the appropriate pointers, offsets, and mask.
# Use `other=0.0` to handle out-of-bounds elements.
x_chunk = tl.load(...)
w_chunk = tl.load(...)
x_chunk = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0)
w_chunk = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)

# - Compute the element-wise product and add it to the accumulator.
accumulator += ...
accumulator += x_chunk * w_chunk

# 6. Reduce the block-sized accumulator to a single scalar value after the loop.
# Hint: Use tl.sum().
final_sum = ...
final_sum = tl.sum(accumulator)

# 7. Store the final accumulated sum to the output tensor Y.
# Hint: Use tl.store().
...
tl.store(output_ptr, final_sum)

# --- END OF STUDENT IMPLEMENTATION ---

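A host-side launch for the completed kernel above might look like the following sketch. The kernel's full parameter order is not shown in this hunk, so arguments are passed by keyword (which Triton's JIT accepts) and should be checked against the real signature; the kernel also assumes X is contiguous row-major, since it offsets each row by row_idx * N_COLS.

import torch
from problem_2 import weighted_row_sum_kernel

def weighted_row_sum(X, w, block_size=1024):
    # One program instance per row; each instance walks its row in BLOCK_SIZE chunks.
    assert X.is_contiguous() and w.is_contiguous()
    n_rows, n_cols = X.shape
    y = torch.empty(n_rows, device=X.device, dtype=torch.float32)
    grid = (n_rows,)
    weighted_row_sum_kernel[grid](X_ptr=X, W_ptr=w, Y_ptr=y, N_COLS=n_cols, BLOCK_SIZE=block_size)
    return y

X = torch.randn(64, 3000, device="cuda")
w = torch.randn(3000, device="cuda")
torch.testing.assert_close(weighted_row_sum(X, w), (X * w).sum(dim=1), rtol=1e-4, atol=1e-4)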
18 changes: 16 additions & 2 deletions problem_3.py
@@ -62,14 +62,28 @@ def _flash_attention_forward_kernel(
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic.
# Implement the online softmax update logic (streaming, numerically stable).
# Ensure consistent fp32 dtype for reductions and dot products.

# 1. Find the new running maximum (`m_new`).
m_new = tl.maximum(m_i, tl.max(s_ij, axis=1))
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
mult = tl.exp2(m_i - m_new)  # shape: (BLOCK_M,)
l_i_rescaled = l_i * mult
acc_rescaled = acc * mult[:, None]
# 3. Compute the attention probabilities for the current tile (`p_ij`).
P_tilde_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1)
acc = acc_rescaled + tl.dot(P_tilde_ij, tl.cast(v_block, tl.float32))
# 5. Update the denominator `l_i`.
l_i = l_new
# 6. Update the running maximum `m_i` for the next iteration.
pass
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---


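Note that the update above uses tl.exp2 rather than tl.exp. That is equivalent to a natural-base softmax only if the scores are already in log2 units, i.e. if the host wrapper folds a factor of log2(e) into qk_scale (a common trick, since exp2 maps to a fast hardware instruction). The wrapper is not part of this diff, so the snippet below states that assumption rather than confirming it; a log-sum-exp carried as m_i + log2(l_i) would likewise be in base-2 units and need converting.

import math
import torch

# Assumed convention: qk_scale = sm_scale * log2(e), so exp2(s * qk_scale) == exp(s * sm_scale).
sm_scale = 1.0 / math.sqrt(64)
qk_scale = sm_scale * math.log2(math.e)
s = torch.randn(8, dtype=torch.float64)
torch.testing.assert_close(torch.exp2(s * qk_scale), torch.exp(s * sm_scale))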
79 changes: 77 additions & 2 deletions problem_4.py
@@ -54,7 +54,44 @@ def _flash_attention_forward_causal_kernel(
# 1. Load the K and V blocks for the current iteration.
# 2. Compute the attention scores (S_ij).
# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
pass
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Compute attention scores S_ij = Q_i * K_j^T
s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic (streaming, numerically stable).
# Ensure consistent fp32 dtype for reductions and dot products.

# 1. Find the new running maximum (`m_new`).
m_new = tl.maximum(m_i, tl.max(s_ij, axis=1))
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
mult = tl.exp2(m_i - m_new)  # shape: (BLOCK_M,)
l_i_rescaled = l_i * mult
acc_rescaled = acc * mult[:, None]
# 3. Compute the attention probabilities for the current tile (`p_ij`).
P_tilde_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1)
acc = acc_rescaled + tl.dot(P_tilde_ij, tl.cast(v_block, tl.float32))
# 5. Update the denominator `l_i`.
l_i = l_new
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---


@@ -64,7 +101,45 @@ def _flash_attention_forward_causal_kernel(
for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the logic for the diagonal blocks, apply the causal mask to S_ij.
pass
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Compute attention scores S_ij = Q_i * K_j^T
s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale
mask = k_offsets[None, :] <= q_offsets[:, None]
s_ij = tl.where(mask, s_ij, -float('inf'))
# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic (streaming, numerically stable).
# Ensure consistent fp32 dtype for reductions and dot products.

# 1. Find the new running maximum (`m_new`).
m_new = tl.maximum(m_i, tl.max(s_ij, axis=1))
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
mult = tl.exp2(m_i - m_new)  # shape: (BLOCK_M,)
l_i_rescaled = l_i * mult
acc_rescaled = acc * mult[:, None]
# 3. Compute the attention probabilities for the current tile (`p_ij`).
P_tilde_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1)
acc = acc_rescaled + tl.dot(P_tilde_ij, tl.cast(v_block, tl.float32))
# 5. Update the denominator `l_i`.
l_i = l_new
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---


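The diagonal-block mask above compares absolute key positions against absolute query positions. A tiny worked example (sizes illustrative) shows the lower-triangular pattern it produces for a 4x4 tile sitting exactly on the diagonal:

import torch

# Query tile covering rows 8-11 and key tile covering columns 8-11.
q_offsets = 8 + torch.arange(4)
k_offsets = 8 + torch.arange(4)
keep = k_offsets[None, :] <= q_offsets[:, None]   # same predicate as in the kernel
print(keep.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)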
91 changes: 86 additions & 5 deletions problem_5.py
@@ -34,11 +34,11 @@ def _flash_attention_forward_gqa_kernel(
# --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 1) ---
# Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx).
# 1. Calculate how many query heads are in each group.
n_q_head_per_group = N_Q_HEADS // N_KV_HEADS
# 2. Use integer division to find the correct kv_head_idx.

kv_head_idx = 0 # Placeholder: Replace with your calculation
kv_head_idx = q_head_idx // n_q_head_per_group
# --- END OF STUDENT IMPLEMENTATION ---

# (dhruv) This passes all the test cases, but ideally it should handle the edge case where N_Q_HEADS is not evenly divisible by N_KV_HEADS; see the small worked example of the mapping after this hunk.

# 2. Initialize accumulators in SRAM.
m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
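A small worked example of the head mapping above, with illustrative values: 8 query heads sharing 2 key/value heads gives a group size of 4, so query heads 0-3 read KV head 0 and heads 4-7 read KV head 1.

N_Q_HEADS, N_KV_HEADS = 8, 2
n_q_head_per_group = N_Q_HEADS // N_KV_HEADS   # 4
print([q_head_idx // n_q_head_per_group for q_head_idx in range(N_Q_HEADS)])
# [0, 0, 0, 0, 1, 1, 1, 1]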
@@ -59,7 +59,50 @@ def _flash_attention_forward_gqa_kernel(
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
# 2. Reuse your working implementation for the online softmax update
# from your solution to Problem 4.
pass
# Off-diagonal blocks: same online-softmax structure as the non-causal version from Problem 3, with the K/V pointers indexed by kv_head_idx.
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Compute attention scores S_ij = Q_i * K_j^T
s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic (streaming, numerically stable).
# Ensure consistent fp32 dtype for reductions and dot products.

# 1. Find the new running maximum (`m_new`).
m_new = tl.maximum(m_i, tl.max(s_ij, axis=1))
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
mult = tl.exp2(m_i - m_new)  # shape: (BLOCK_M,)
l_i_rescaled = l_i * mult
acc_rescaled = acc * mult[:, None]
# 3. Compute the attention probabilities for the current tile (`p_ij`).
P_tilde_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1)
acc = acc_rescaled + tl.dot(P_tilde_ij, tl.cast(v_block, tl.float32))
# 5. Update the denominator `l_i`.
l_i = l_new
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# --- Phase 2: Diagonal Blocks ---
@@ -69,7 +112,45 @@ def _flash_attention_forward_gqa_kernel(
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
# 2. Reuse your working implementation for the masked online softmax
# update from your solution to Problem 4.
pass
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Compute attention scores S_ij = Q_i * K_j^T
s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale
mask = k_offsets[None, :] <= q_offsets[:, None]
s_ij = tl.where(mask, s_ij, -float('inf'))
# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic (streaming, numerically stable).
# Ensure consistent fp32 dtype for reductions and dot products.

# 1. Find the new running maximum (`m_new`).
m_new = tl.maximum(m_i, tl.max(s_ij, axis=1))
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
mult = tl.exp2(m_i - m_new)  # shape: (BLOCK_M,)
l_i_rescaled = l_i * mult
acc_rescaled = acc * mult[:, None]
# 3. Compute the attention probabilities for the current tile (`p_ij`).
P_tilde_ij = tl.exp2(s_ij - m_new[:, None])
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
l_new = l_i_rescaled + tl.sum(P_tilde_ij, axis=1)
acc = acc_rescaled + tl.dot(P_tilde_ij, tl.cast(v_block, tl.float32))
# 5. Update the denominator `l_i`.
l_i = l_new
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# 4. Normalize and write the final output block.
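For an end-to-end check of the GQA kernel, a convenient baseline expands each KV head across its query-head group with repeat_interleave and then runs ordinary causal attention. The host wrapper around the Triton kernel is not part of this diff, so the comparison call below uses a hypothetical name and is left commented out.

import math
import torch

def gqa_reference(Q, K, V):
    # Causal GQA reference: broadcast each KV head across its group of query heads.
    B, HQ, N, D = Q.shape
    HKV = K.shape[1]
    K_full = K.repeat_interleave(HQ // HKV, dim=1)
    V_full = V.repeat_interleave(HQ // HKV, dim=1)
    S = (Q.float() @ K_full.float().transpose(-1, -2)) / math.sqrt(D)
    future = torch.triu(torch.ones(N, N, dtype=torch.bool, device=Q.device), diagonal=1)
    S = S.masked_fill(future, float('-inf'))
    return torch.softmax(S, dim=-1) @ V_full.float()

Q = torch.randn(2, 8, 256, 64, device="cuda")
K = torch.randn(2, 2, 256, 64, device="cuda")
V = torch.randn(2, 2, 256, 64, device="cuda")
O_ref = gqa_reference(Q, K, V)
# O_kernel = flash_attention_gqa_forward(Q, K, V)   # hypothetical wrapper name
# torch.testing.assert_close(O_kernel.float(), O_ref, rtol=1e-2, atol=1e-2)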