
Move the tensor descriptor block size check. #7325


Closed

7 changes: 7 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp
@@ -253,6 +253,13 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
auto contigDimSize = blockShape.back();

if (contigDimSize * elemSize < 16) {
return op->emitError("Descriptor block shape must have at least 16 bytes "
"in the last dimension, but got ")
<< contigDimSize << " * " << elemSize << " = "
<< (contigDimSize * elemSize) << " bytes";
}
Comment on lines +256 to +261 (Collaborator):
We want this to be a front-end error, as it will then be cleanly reported to the user rather than surfacing as a failure in the backend. Also, in terms of portability, we should ideally keep the same restrictions independently of the backend.


llvm::SmallVector<Value> boxDim;
if (fp4Padded && contigDimSize != 128) {
return op->emitError(
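For illustration, here is a minimal standalone sketch of the kind of kernel the relocated check now rejects, modeled on the updated test in test_core.py below. The allocator setup, the device string, and names such as `too_small_block_kernel` are assumptions made to keep the example self-contained; they are not part of this PR.

```python
import torch
import triton
import triton.language as tl


# Device-side tensor descriptors need a host-side scratch allocator; this
# setup is an assumption about the surrounding harness, not part of the PR.
def alloc_fn(size: int, alignment: int, stream):
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(alloc_fn)


@triton.jit
def too_small_block_kernel(ptr):
    # [1, 2] blocks over int32 give 2 * 4 = 8 bytes in the contiguous
    # dimension, below the 16-byte floor the relocated check enforces.
    desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 2])
    x = desc.load([0, 0])
    desc.store([0, 0], x)


x = torch.empty((128, 128), dtype=torch.int32, device="cuda")
try:
    too_small_block_kernel[(1, )](x)
except RuntimeError:
    # Per the updated test, the "at least 16 bytes" message is emitted on
    # stderr by the compiler diagnostics rather than attached to the exception.
    print("descriptor block shape rejected by the backend check")
```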
30 changes: 7 additions & 23 deletions python/test/unit/language/test_core.py
@@ -46,7 +46,6 @@
torch_dtype_name,
to_numpy,
)
from triton.runtime.errors import InterpreterError


@contextlib.contextmanager
@@ -5184,35 +5183,20 @@ def kernel():


@pytest.mark.interpreter
def test_tma_load_block_shape_err(device):
def test_tma_block_shape_err(capfd, device):

@triton.jit
def kernel(ptr):
desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 2])
desc.load([0, 0])
x = desc.load([0, 0])
x = x + x
desc.store([0, 0], x)

input = torch.empty((128, 128), dtype=torch.int32, device=device)
errc = triton.CompilationError if not is_interpreter() else InterpreterError
with pytest.raises(errc) as e:
with pytest.raises(RuntimeError) as e:
kernel[(1, )](input)

assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__)


@pytest.mark.interpreter
def test_tma_store_block_shape_err(device):

@triton.jit
def kernel(ptr):
desc = tl.make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 4])
desc.store([0, 0], tl.zeros([8, 4], dtype=tl.int16))

input = torch.empty((128, 128), dtype=torch.int16, device=device)
errc = triton.CompilationError if not is_interpreter() else InterpreterError
with pytest.raises(errc) as e:
kernel[(1, )](input)

assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__)
_, stderr = capfd.readouterr()
assert "Descriptor block shape must have at least 16 bytes" in stderr


def test_trans_reshape(device, with_allocator):
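The 16-byte floor translates into a per-dtype minimum for the last block dimension, which is why the test's `[1, 2]` block over `int32` (8 bytes) trips the check, while a last dimension of 4 int32 elements (16 bytes) satisfies this particular constraint. A small illustrative helper mirroring that arithmetic; the function name is hypothetical and not part of Triton's API:

```python
import torch


def min_last_block_dim(dtype: torch.dtype) -> int:
    """Smallest last-dimension block size satisfying the 16-byte rule.

    Hypothetical helper for illustration; it mirrors the arithmetic of the
    moved check: contig_dim_size * elem_size >= 16.
    """
    elem_size = torch.tensor([], dtype=dtype).element_size()
    return (16 + elem_size - 1) // elem_size  # ceil(16 / elem_size)


assert min_last_block_dim(torch.int32) == 4    # 4 * 4 bytes  = 16
assert min_last_block_dim(torch.float16) == 8  # 8 * 2 bytes  = 16
assert min_last_block_dim(torch.int8) == 16    # 16 * 1 byte  = 16
```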
7 changes: 0 additions & 7 deletions python/triton/language/semantic.py
@@ -1858,13 +1858,6 @@ def make_tensor_descriptor(
raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
if len(block_shape) != ndim:
raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
assert isinstance(base.dtype, tl.pointer_type)
elem_size = base.dtype.element_ty.primitive_bitwidth // 8
contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
if contig_dim_size * elem_size < 16:
raise ValueError(
f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
)

strides[-1] = tl._unwrap_if_constexpr(strides[-1])
if strides[-1] != 1:
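With the Python-level validation above removed, callers who still want an early, pre-compilation error can replicate the deleted logic themselves. Below is a sketch that reassembles the removed check as a standalone function, kept purely for reference; the name `validate_descriptor_block_shape` and where one would call it are illustrative assumptions, not Triton API.

```python
import triton.language as tl


def validate_descriptor_block_shape(base, block_shape):
    """Standalone restatement of the check deleted from semantic.py.

    Illustrative only: `base` is expected to be a tensor of pointer type,
    as in make_tensor_descriptor; whether and where to call this is left open.
    """
    assert isinstance(base.dtype, tl.pointer_type)
    elem_size = base.dtype.element_ty.primitive_bitwidth // 8
    contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
    if contig_dim_size * elem_size < 16:
        raise ValueError(
            f"Descriptor block shape must have at least 16 bytes in the last "
            f"dimension, but got {contig_dim_size} * {elem_size} = "
            f"{contig_dim_size * elem_size} bytes")
```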