22 changes: 16 additions & 6 deletions problem_1.py
@@ -56,18 +56,28 @@ def forward(ctx, Q, K, V, is_causal=False):
S_ij = (Q_tile @ K_tile.transpose(-1, -2)) * scale

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
S_ij = S_ij.to(torch.float32)
V_tile = V_tile.to(torch.float32)
# 1. Apply causal masking if is_causal is True.
#
if is_causal:
q_idx = torch.arange(q_start, q_end, device=Q.device)[:, None]
k_idx = torch.arange(k_start, k_end, device=Q.device)[None, :]
mask = k_idx > q_idx # (q_len, k_len)
S_ij = S_ij.masked_fill(mask, torch.finfo(torch.float32).min)
# 2. Compute the new running maximum
#
m_ij = torch.max(S_ij, dim=-1).values #(128,)
m_new = torch.maximum(m_i, m_ij) #m_i is (128,) -> m_new is (128,)
# 3. Rescale the previous accumulators (o_i, l_i)
#
scale_factor = torch.exp(m_i - m_new) #(128,)
o_i = o_i * scale_factor.unsqueeze(-1) # o_i is (128,16) -> (128,16)
l_i = l_i * scale_factor # l_i is (128,) -> (128,)
# 4. Compute the probabilities for the current tile, P_tilde_ij = exp(S_ij - m_new).
#
P_tilde_ij = torch.exp(S_ij - m_new.unsqueeze(-1)) #(128,128)
# 5. Accumulate the current tile's contribution to the accumulators to update l_i and o_i
#
o_i = o_i + (P_tilde_ij @ V_tile) #V_tile is (128,16)
l_i = l_i + P_tilde_ij.sum(dim=-1) # row-wise sum over columns -> (128,)
# 6. Update the running max for the next iteration

m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# After iterating through all key tiles, normalize the output
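For reference, the tile-wise update above can be sanity-checked against a full softmax in plain PyTorch. This is only an illustrative sketch: the shapes and tile size are made up, it is non-causal, and it streams only over key/value tiles, whereas the assignment additionally tiles the queries and handles the causal mask.

```python
import torch

# Illustrative shapes; the assignment's actual tile sizes may differ.
seq_len, head_dim, tile = 256, 16, 64
torch.manual_seed(0)
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)
scale = head_dim ** -0.5

# Reference: full softmax attention (non-causal).
s = (q @ k.T) * scale
ref = torch.softmax(s, dim=-1) @ v

# Online softmax: stream over key/value tiles, keeping running (m_i, l_i, o_i).
m_i = torch.full((seq_len,), float("-inf"))
l_i = torch.zeros(seq_len)
o_i = torch.zeros(seq_len, head_dim)
for start in range(0, seq_len, tile):
    s_ij = (q @ k[start:start + tile].T) * scale          # scores for this tile
    m_new = torch.maximum(m_i, s_ij.max(dim=-1).values)   # new running max
    alpha = torch.exp(m_i - m_new)                         # rescale old statistics
    p = torch.exp(s_ij - m_new[:, None])                   # unnormalized probabilities
    o_i = o_i * alpha[:, None] + p @ v[start:start + tile]
    l_i = l_i * alpha + p.sum(dim=-1)
    m_i = m_new
out = o_i / l_i[:, None]

print(torch.allclose(out, ref, atol=1e-5))  # expected: True
```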
31 changes: 16 additions & 15 deletions problem_2.py
@@ -16,47 +16,47 @@ def weighted_row_sum_kernel(
"""
# 1. Get the row index for the current program instance.
# Hint: Use tl.program_id(axis=0).
row_idx = ...
row_idx = tl.program_id(axis=0)

# 2. Create a pointer to the start of the current row in the input tensor X.
# Hint: The offset depends on the row index and the number of columns (N_COLS).
row_start_ptr = ...
row_start_ptr = X_ptr + row_idx * N_COLS

# 3. Create a pointer for the output vector Y.
output_ptr = ...
output_ptr = Y_ptr + row_idx

# 4. Initialize an accumulator for the sum of the products for a block.
# This should be a block-sized tensor of zeros.
# Hint: Use tl.zeros with shape (BLOCK_SIZE,) and dtype tl.float32.
accumulator = ...
accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

# 5. Iterate over the columns of the row in blocks of BLOCK_SIZE.
# Hint: Use a for loop with tl.cdiv(N_COLS, BLOCK_SIZE).
for col_block_start in range(0, ...):
for col_block_start in range(0, tl.cdiv(N_COLS, BLOCK_SIZE)):
# - Calculate the offsets for the current block of columns.
# Hint: Start from the block's beginning and add tl.arange(0, BLOCK_SIZE).
col_offsets = ...
col_offsets = col_block_start * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# - Create a mask to prevent out-of-bounds memory access for the last block.
# Hint: Compare col_offsets with N_COLS.
mask = ...
mask = col_offsets < N_COLS

# - Load a block of data from X and W safely using the mask.
# Hint: Use tl.load with the appropriate pointers, offsets, and mask.
# Use `other=0.0` to handle out-of-bounds elements.
x_chunk = tl.load(...)
w_chunk = tl.load(...)
x_chunk = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0) # where mask is False, load 0.0 instead
w_chunk = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)

# - Compute the element-wise product and add it to the accumulator.
accumulator += ...
accumulator += x_chunk * w_chunk

# 6. Reduce the block-sized accumulator to a single scalar value after the loop.
# Hint: Use tl.sum().
final_sum = ...
final_sum = tl.sum(accumulator, axis=0)

# 7. Store the final accumulated sum to the output tensor Y.
# Hint: Use tl.store().
...
tl.store(output_ptr, final_sum) # store the scalar result at the location pointed to by output_ptr

# --- END OF STUDENT IMPLEMENTATION ---

@@ -94,4 +94,5 @@ def torch_weighted_row_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
"""
Reference implementation using pure PyTorch.
"""
return (x * w).sum(dim=1)
y = (x * w).sum(dim=1)
return y.to(x.dtype)
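A hedged sketch of how the finished kernel might be launched, one program instance per row. The kernel's full signature is truncated in the diff above, so the argument order, the module import, and the `BLOCK_SIZE` value below are assumptions rather than the scaffold's actual interface.

```python
import torch
from problem_2 import weighted_row_sum_kernel  # module name assumed from the filename

def weighted_row_sum(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    y = torch.empty(n_rows, device=x.device, dtype=torch.float32)
    grid = (n_rows,)  # one program instance per row, matching tl.program_id(axis=0)
    weighted_row_sum_kernel[grid](
        x, w, y,              # X_ptr, W_ptr, Y_ptr -- order assumed
        N_COLS=n_cols,
        BLOCK_SIZE=1024,      # must be a power of two for tl.arange
    )
    return y

# Usage check against the pure-PyTorch reference above:
# x = torch.randn(32, 3000, device="cuda"); w = torch.randn(3000, device="cuda")
# assert torch.allclose(weighted_row_sum(x, w), (x * w).sum(dim=1), atol=1e-3)
```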
12 changes: 11 additions & 1 deletion problem_3.py
@@ -63,13 +63,23 @@ def _flash_attention_forward_kernel(

# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the online softmax update logic.
s_ij = s_ij.to(tl.float32)
v_block = v_block.to(tl.float32)
# 1. Find the new running maximum (`m_new`).
m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,)
m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,)
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,)
acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM)
l_i = l_i * scale_factor #(BLOCK_M,)
# 3. Compute the attention probabilities for the current tile (`p_ij`).
p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N)
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM)
# 5. Update the denominator `l_i`.
l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,)
# 6. Update the running maximum `m_i` for the next iteration.
pass
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---
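Note that the kernel uses `tl.exp2` rather than `tl.exp`, since the base-2 exponential maps to a cheaper hardware instruction. This only reproduces a base-e softmax if the scores carry an extra factor of log2(e), which FlashAttention implementations typically fold into `qk_scale`; that definition is outside this diff, so treat it as an assumption. The identity itself is easy to check:

```python
import math
import torch

# exp(x) == 2 ** (x * log2(e)), so scaling scores by log2(e) lets the kernel
# replace exp with the cheaper exp2 without changing the softmax result.
x = torch.randn(8)
print(torch.allclose(torch.exp(x), torch.exp2(x * math.log2(math.e))))  # True
```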


72 changes: 70 additions & 2 deletions problem_4.py
@@ -52,9 +52,40 @@ def _flash_attention_forward_causal_kernel(
# Implement the logic for the off-diagonal blocks.
# This is very similar to the non-causal version from Problem 3.
# 1. Load the K and V blocks for the current iteration.
# Load K_j
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# 2. Compute the attention scores (S_ij).
s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale

s_ij = s_ij.to(tl.float32)
v_block = v_block.to(tl.float32)

# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
pass
# 1. Find the new running maximum (`m_new`).
m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,)
m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,)
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,)
acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM)
l_i = l_i * scale_factor #(BLOCK_M,)
# 3. Compute the attention probabilities for the current tile (`p_ij`).
p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N)
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM)
# 5. Update the denominator `l_i`.
l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,)
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---


@@ -64,7 +95,44 @@ def _flash_attention_forward_causal_kernel(
for start_n in range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N):
# --- STUDENT IMPLEMENTATION REQUIRED HERE ---
# Implement the logic for the diagonal blocks, apply the causal mask to S_ij.
pass
# Load K_j
k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,)
k_ptrs = K_ptr + batch_idx * k_stride_b + head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

s_ij = tl.dot(q_block, k_block)
s_ij *= qk_scale

s_ij = s_ij.to(tl.float32)
v_block = v_block.to(tl.float32)

# Build mask
q_idx = q_offsets
k_idx = k_offsets
causal = q_idx[:, None] >= k_idx[None, :] #Lower triangle true
valid = (q_idx[:, None] < SEQ_LEN) & (k_idx[None, :] < SEQ_LEN)
mask = causal & valid

# Apply mask BEFORE tile max so future tokens don't affect m_i
neg_inf = -float("inf")
s_ij = tl.where(mask, s_ij, neg_inf)

# online softmax update
m_ij = tl.max(s_ij, axis=1)
m_new = tl.maximum(m_i, m_ij)
scale_factor = tl.exp2(m_i - m_new)

p_ij = tl.exp2(s_ij - m_new[:, None])

acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block)
l_i = l_i * scale_factor + tl.sum(p_ij, axis=1)
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---
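The split into two loops mirrors the usual FlashAttention causal schedule: key blocks that end at or before the query block's first row need no mask (Phase 1), while blocks that overlap the query range need the element-wise causal mask (Phase 2). A toy host-side sketch of that partition follows; the block sizes and `diag_start` value are chosen purely for illustration, since the Phase 1 loop bounds are not visible in this diff.

```python
# Illustrative block partition for one query block; values are made up.
BLOCK_M, BLOCK_N = 64, 32
q_block_idx = 2
q_start = q_block_idx * BLOCK_M        # first query row handled by this program
diag_start = q_start                   # assumed: diagonal phase begins at q_start

phase1 = list(range(0, diag_start, BLOCK_N))                            # fully visible, unmasked
phase2 = list(range(diag_start, (q_block_idx + 1) * BLOCK_M, BLOCK_N))  # overlaps the diagonal, masked
print(phase1)  # [0, 32, 64, 96]
print(phase2)  # [128, 160]
```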


73 changes: 69 additions & 4 deletions problem_5.py
@@ -34,9 +34,9 @@ def _flash_attention_forward_gqa_kernel(
# --- STUDENT IMPLEMENTATION REQUIRED HERE (Part 1) ---
# Your goal is to map the current query head (q_head_idx) to its corresponding shared key/value head (kv_head_idx).
# 1. Calculate how many query heads are in each group.
group_size = N_Q_HEADS // N_KV_HEADS
# 2. Use integer division to find the correct kv_head_idx.

kv_head_idx = 0 # Placeholder: Replace with your calculation
kv_head_idx = q_head_idx // group_size
# --- END OF STUDENT IMPLEMENTATION ---
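A quick host-side check of this mapping, with head counts picked only for illustration: with 8 query heads sharing 2 KV heads, the group size is 4, so query heads 0-3 read KV head 0 and heads 4-7 read KV head 1.

```python
# Toy check of the query-head -> KV-head mapping (head counts are illustrative).
N_Q_HEADS, N_KV_HEADS = 8, 2
group_size = N_Q_HEADS // N_KV_HEADS   # 4 query heads per KV head
print([q_head // group_size for q_head in range(N_Q_HEADS)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```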


@@ -59,7 +59,39 @@ def _flash_attention_forward_gqa_kernel(
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
# 2. Reuse your working implementation for the online softmax update
# from your solution to Problem 4.
pass
# Load K_j
k_offsets = start_n + tl.arange(0, BLOCK_N)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

# 2. Compute the attention scores (S_ij).
s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32))
s_ij *= qk_scale

v_block = v_block.to(tl.float32)

# 3. Update the online softmax statistics (m_i, l_i) and the accumulator (acc).
# 1. Find the new running maximum (`m_new`).
m_ij = tl.max(s_ij, axis=1) #(BLOCK_M,)
m_new = tl.maximum(m_i, m_ij) #(BLOCK_M,)
# 2. Rescale the existing accumulator (`acc`) and denominator (`l_i`).
scale_factor = tl.exp2(m_i - m_new) #(BLOCK_M,)
acc = acc * scale_factor[:, None] #(BLOCK_M, HEAD_DIM)
l_i = l_i * scale_factor #(BLOCK_M,)
# 3. Compute the attention probabilities for the current tile (`p_ij`).
p_ij = tl.exp2(s_ij - m_new[:, None]) #(BLOCK_M, BLOCK_N)
# 4. Update the accumulator `acc` using `p_ij` and `v_block`.
acc = acc + tl.dot(p_ij, v_block) #(BLOCK_M, HEAD_DIM)
# 5. Update the denominator `l_i`.
l_i = l_i + tl.sum(p_ij, axis=1) #(BLOCK_M,)
# 6. Update the running maximum `m_i` for the next iteration.
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# --- Phase 2: Diagonal Blocks ---
@@ -69,7 +101,40 @@ def _flash_attention_forward_gqa_kernel(
# 1. Modify the pointer arithmetic for K and V to use your `kv_head_idx`.
# 2. Reuse your working implementation for the masked online softmax
# update from your solution to Problem 4.
pass
# Load K_j
k_offsets = start_n + tl.arange(0, BLOCK_N) # (BLOCK_N,)
k_ptrs = K_ptr + batch_idx * k_stride_b + kv_head_idx * k_stride_h + \
(k_offsets[None, :] * k_stride_s + tl.arange(0, HEAD_DIM)[:, None])
k_block = tl.load(k_ptrs, mask=k_offsets[None, :] < SEQ_LEN, other=0.0)

# Load V_j
v_ptrs = V_ptr + batch_idx * v_stride_b + kv_head_idx * v_stride_h + \
(k_offsets[:, None] * v_stride_s + tl.arange(0, HEAD_DIM)[None, :])
v_block = tl.load(v_ptrs, mask=k_offsets[:, None] < SEQ_LEN, other=0.0)

s_ij = tl.dot(q_block.to(tl.float32), k_block.to(tl.float32))
s_ij *= qk_scale

v_block = v_block.to(tl.float32)

# build mask
causal = q_offsets[:, None] >= k_offsets[None, :] #Lower triangle true
valid = (q_offsets[:, None] < SEQ_LEN) & (k_offsets[None, :] < SEQ_LEN)
mask = causal & valid

# Apply mask BEFORE tile max so future tokens don't affect m_i
s_ij = tl.where(mask, s_ij, -float("inf"))

# online softmax update
m_ij = tl.max(s_ij, axis=1)
m_new = tl.maximum(m_i, m_ij)
scale_factor = tl.exp2(m_i - m_new)

p_ij = tl.exp2(s_ij - m_new[:, None])

acc = acc * scale_factor[:, None] + tl.dot(p_ij, v_block)
l_i = l_i * scale_factor + tl.sum(p_ij, axis=1)
m_i = m_new
# --- END OF STUDENT IMPLEMENTATION ---

# 4. Normalize and write the final output block.
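For checking the GQA kernel end to end, a small PyTorch reference that expands the shared KV heads with `repeat_interleave` is often handy. A minimal sketch, assuming `(batch, heads, seq_len, head_dim)` layouts; this is not the scaffold's own test harness.

```python
import torch

def gqa_reference(q, k, v, causal=True):
    # Expand KV heads so every query head in a group attends to the same shared K/V.
    group_size = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    s = (q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5
    if causal:
        seq = q.shape[-2]
        mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=q.device), 1)
        s = s.masked_fill(mask, float("-inf"))  # mask future key positions
    return torch.softmax(s, dim=-1) @ v
```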