Skip to content

Commit ce05311

Browse files
mshelegosys-cmllvm
authored andcommitted
Kernel argument attributes fix for opaque pointers
In case when frontend uses opaque pointers, types used in attributes like byval are lost and substituted by i8. Try to restore the original type by analyzing the memory instructions that uses the argument.
1 parent f8055b7 commit ce05311

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

GenXIntrinsics/lib/GenXIntrinsics/AdaptorsCommon.cpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ SPDX-License-Identifier: MIT
88

99
#include "AdaptorsCommon.h"
1010

11+
#include "llvm/ADT/SmallPtrSet.h"
1112
#include "llvm/IR/Function.h"
1213
#include "llvm/IR/Instructions.h"
1314

@@ -30,6 +31,45 @@ static void legalizeAttribute(Argument &Arg, Type *NewType,
3031

3132
#endif
3233

34+
Type *getPtrElemType(Value *V) {
35+
#if VC_INTR_LLVM_VERSION_MAJOR < 14
36+
return VCINTR::Type::getNonOpaquePtrEltTy(V->getType());
37+
#else // VC_INTR_LLVM_VERSION_MAJOR < 14
38+
#if VC_INTR_LLVM_VERSION_MAJOR < 17
39+
auto *PtrTy = cast<PointerType>(V->getType());
40+
if (!PtrTy->isOpaque())
41+
return VCINTR::Type::getNonOpaquePtrEltTy(PtrTy);
42+
#endif // VC_INTR_LLVM_VERSION_MAJOR < 17
43+
SmallPtrSet<Type *, 2> ElemTys;
44+
SmallVector<Value *, 4> Stack;
45+
Stack.push_back(V);
46+
while (!Stack.empty()) {
47+
auto* Current = Stack.back();
48+
Stack.pop_back();
49+
for (auto *U : Current->users()) {
50+
if (ElemTys.size() > 1)
51+
return nullptr;
52+
auto *I = dyn_cast<Instruction>(U);
53+
if (!I)
54+
continue;
55+
if (auto *LI = dyn_cast<LoadInst>(I)) {
56+
if (Current == LI->getPointerOperand())
57+
ElemTys.insert(LI->getType());
58+
} else if (auto *SI = dyn_cast<StoreInst>(I)) {
59+
if (Current == SI->getPointerOperand())
60+
ElemTys.insert(SI->getValueOperand()->getType());
61+
} else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) {
62+
if (Current == GEPI->getPointerOperand())
63+
ElemTys.insert(GEPI->getSourceElementType());
64+
} else if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) {
65+
Stack.push_back(I);
66+
}
67+
}
68+
}
69+
return ElemTys.empty() ? nullptr : *ElemTys.begin();
70+
#endif // VC_INTR_LLVM_VERSION_MAJOR < 14
71+
}
72+
3373
void legalizeParamAttributes(Function *F) {
3474
assert(F && "Valid function ptr must be passed");
3575

@@ -39,15 +79,12 @@ void legalizeParamAttributes(Function *F) {
3979
if (!PTy)
4080
continue;
4181

82+
auto *ElemType = getPtrElemType(&Arg);
4283
#if VC_INTR_LLVM_VERSION_MAJOR >= 13
43-
#if VC_INTR_LLVM_VERSION_MAJOR < 17
44-
if (PTy->isOpaque())
45-
#endif // VC_INTR_LLVM_VERSION_MAJOR < 18
84+
if (!ElemType)
4685
continue;
4786
#endif // VC_INTR_LLVM_VERSION_MAJOR >= 13
4887

49-
auto *ElemType = VCINTR::Type::getNonOpaquePtrEltTy(PTy);
50-
5188
legalizeAttribute(Arg, ElemType, Attribute::ByVal);
5289

5390
#if VC_INTR_LLVM_VERSION_MAJOR >= 11

GenXIntrinsics/lib/GenXIntrinsics/GenXSPIRVReaderAdaptor.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,7 @@ transformKernelSignature(Function &F, const std::vector<SPIRVArgDesc> &Descs) {
481481
return PointerType::get(Type::getInt8Ty(Ctx), AddrSpace);
482482
}
483483
#endif
484-
if (!VCINTR::Type::isOpaquePointerTy(ArgTy) &&
485-
Arg.hasByValAttr())
484+
if (Arg.hasByValAttr())
486485
return OrigTy;
487486

488487
return ArgTy;
@@ -523,8 +522,6 @@ transformKernelSignature(Function &F, const std::vector<SPIRVArgDesc> &Descs) {
523522
NewF->addParamAttr(i, Attr);
524523
}
525524

526-
legalizeParamAttributes(NewF);
527-
528525
return NewF;
529526
}
530527

@@ -606,6 +603,8 @@ static void rewriteKernelArguments(Function &F) {
606603
}
607604
}
608605

606+
legalizeParamAttributes(NewF);
607+
609608
F.mutateType(NewF->getType());
610609
F.replaceAllUsesWith(NewF);
611610
F.eraseFromParent();

GenXIntrinsics/test/Adaptors/opaque_ptrs/args_attributes_transform_reader.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,15 @@ define spir_kernel void @test(ptr addrspace(1) byval(%foo) %arg) #0 {
2121
ret void
2222
}
2323

24+
; CHECK: define dllexport spir_kernel void @test_restore(
25+
; CHECK-SAME: ptr byval(%foo)
26+
; CHECK-SAME: [[ARG:%[^)]+]])
27+
define spir_kernel void @test_restore(ptr addrspace(1) byval(i8) %arg) #0 {
28+
%conv = call ptr @llvm.genx.address.convert.p0foo.p1(ptr addrspace(1) %arg)
29+
%gep = getelementptr %foo, ptr %conv, i64 0, i32 0
30+
ret void
31+
}
32+
33+
declare ptr @llvm.genx.address.convert.p0foo.p1(ptr addrspace(1)) #0
34+
2435
attributes #0 = { "VCFunction" }

0 commit comments

Comments
 (0)