File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change 1717WITH_TMA = os .getenv ("WITH_TMA" )
1818HAS_EXPLICIT_WS = os .getenv ("ENABLE_EXPLICIT_WS" )
1919SUPPORT_GLUON = os .getenv ("WITH_GLUON" )
20+ WITH_MAXNREG = os .getenv ("WITH_MAXNREG" )
2021
2122
2223class TmaAutoTuneHelper :
Original file line number Diff line number Diff line change 1717import triton .language as tl
1818from triton .tools .tensor_descriptor import TensorDescriptor
1919
20+ from .attention_utils import WITH_MAXNREG
21+
2022from .blackwell_attention_utils import (
2123 is_blackwell ,
2224 is_cuda ,
@@ -204,12 +206,13 @@ def _host_descriptor_pre_hook(nargs):
204206 num_stages = s ,
205207 num_warps = w ,
206208 pre_hook = _host_descriptor_pre_hook ,
209+ # ir_override=f"override/_attn_fwd_persist.ttgir"
207210 )
208211 for BM in [256 ]
209212 for BN in [128 ]
210213 for s in NUM_STAGES_OPTIONS
211214 for w in [4 ]
212- for subtile in [True , False ]
215+ for subtile in [False ] # disable subtiling for now
213216 ]
214217
215218
@@ -267,7 +270,7 @@ def _attn_fwd_tma_dp(
267270 off_z = off_hz // H
268271 off_h = off_hz % H
269272
270- offset_y = off_z + off_h * N_CTX
273+ offset_y = off_z * ( N_CTX * H ) + off_h * N_CTX
271274 qo_offset_y = offset_y + start_m * BLOCK_M
272275 # initialize offsets
273276 offs_m0 = start_m * BLOCK_M + tl .arange (0 , BLOCK_M // 2 )
@@ -569,7 +572,7 @@ def grid_debug(META):
569572
570573 ctx .grid = grid
571574 persistent = baseVariant == "persistent" or baseVariant == "ws_persistent"
572- if is_blackwell () and warp_specialize :
575+ if WITH_MAXNREG and is_blackwell () and warp_specialize :
573576 if HEAD_DIM_K == 128 and (
574577 q .dtype == torch .float16 or q .dtype == torch .bfloat16
575578 ):
You can’t perform that action at this time.
0 commit comments