diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 63a2f7e211..34c5d4e351 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -259,8 +259,10 @@ def forward( scale_k_count = ( k_dim + scale_gran_k - 1 ) // scale_gran_k # k dimension - scale_a_expanded = scale_a.view(1, 1).expand( - scale_m_count, scale_k_count + scale_a_expanded = ( + scale_a.view(1, 1) + .expand(scale_m_count, scale_k_count) + .contiguous() ) else: scale_a_expanded = scale_a @@ -273,8 +275,10 @@ def forward( scale_k_count = ( k_dim + scale_gran_k - 1 ) // scale_gran_k # k dimension - scale_b_expanded = scale_b.view(1, 1).expand( - scale_n_count, scale_k_count + scale_b_expanded = ( + scale_b.view(1, 1) + .expand(scale_n_count, scale_k_count) + .contiguous() ) else: scale_b_expanded = scale_b