diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td index f592ff287a0e3..c1e4b97e96bc8 100644 --- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td +++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem], "llvm.wasm.ref.is_null.exn">; +def int_wasm_ref_test_func + : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty], + [IntrNoMem]>; + //===----------------------------------------------------------------------===// // Table intrinsics //===----------------------------------------------------------------------===// diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp index 03d3e8eab35d0..768940f64ee0c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op, AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap, IsDebug, IsClone, IsCloned); } else if (ConstantSDNode *C = dyn_cast(Op)) { - MIB.addImm(C->getSExtValue()); + if (C->getAPIntValue().getSignificantBits() <= 64) { + MIB.addImm(C->getSExtValue()); + } else { + MIB.addCImm( + ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue())); + } } else if (ConstantFPSDNode *F = dyn_cast(Op)) { MIB.addFPImm(F->getConstantFPValue()); } else if (RegisterSDNode *R = dyn_cast(Op)) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp index ac819cf5c1801..a7991319be8c7 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp @@ -15,12 +15,14 @@ #include "WebAssembly.h" #include "WebAssemblyISelLowering.h" #include "WebAssemblyTargetMachine.h" +#include "WebAssemblyUtilities.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/SelectionDAGISel.h" #include "llvm/CodeGen/WasmEHFuncInfo.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" // To access function attributes. #include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/MC/MCSymbolWasm.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" @@ -118,6 +120,51 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) { return DAG->getTargetExternalSymbol(SymName, PtrVT); } +static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL, + SmallVector &Returns, + SmallVector &Params) { + auto toWasmValType = [&DAG, &DL](MVT VT) { + if (VT == MVT::i32) { + return wasm::ValType::I32; + } + if (VT == MVT::i64) { + return wasm::ValType::I64; + } + if (VT == MVT::f32) { + return wasm::ValType::F32; + } + if (VT == MVT::f64) { + return wasm::ValType::F64; + } + LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT + << "\n"); + llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func"); + }; + auto NParams = Params.size(); + auto NReturns = Returns.size(); + auto BitWidth = (NParams + NReturns + 2) * 64; + auto Sig = APInt(BitWidth, 0); + + // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will + // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we + // always emit a CImm. So xor NParams with 0x7ffffff to ensure + // getSignificantBits() > 64 + Sig |= NReturns ^ 0x7ffffff; + for (auto &Return : Returns) { + auto V = toWasmValType(Return); + Sig <<= 64; + Sig |= (int64_t)V; + } + Sig <<= 64; + Sig |= NParams; + for (auto &Param : Params) { + auto V = toWasmValType(Param); + Sig <<= 64; + Sig |= (int64_t)V; + } + return Sig; +} + void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { // If we have a custom node, we already have selected! if (Node->isMachineOpcode()) { @@ -189,6 +236,50 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { ReplaceNode(Node, TLSAlign); return; } + case Intrinsic::wasm_ref_test_func: { + // First emit the TABLE_GET instruction to convert function pointer ==> + // funcref + MachineFunction &MF = CurDAG->getMachineFunction(); + auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits()); + MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol( + MF.getContext(), Subtarget); + SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT); + SDValue FuncRef = SDValue( + CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL, + MVT::funcref, TableSym, Node->getOperand(1)), + 0); + + // Encode the signature information into the type index placeholder. + // This gets decoded and converted into the actual type signature in + // WebAssemblyMCInstLower.cpp. + SmallVector Params; + SmallVector Returns; + + bool IsParam = false; + // Operand 0 is the return register, Operand 1 is the function pointer. + // The remaining operands encode the type of the function we are testing + // for. + for (unsigned I = 2, E = Node->getNumOperands(); I < E; ++I) { + MVT VT = Node->getOperand(I).getValueType().getSimpleVT(); + if (VT == MVT::Untyped) { + IsParam = true; + continue; + } + if (IsParam) { + Params.push_back(VT); + } else { + Returns.push_back(VT); + } + } + auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params); + + auto SigOp = CurDAG->getTargetConstant( + Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth())); + MachineSDNode *RefTestNode = CurDAG->getMachineNode( + WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef}); + ReplaceNode(Node, RefTestNode); + return; + } } break; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index bf2e04caa0a61..081d09e5b9d31 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -794,6 +794,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, if (IsIndirect) { // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp MIB.addImm(0); // The table into which this call_indirect indexes. MCSymbolWasm *Table = IsFuncrefCall diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp index cc36244e63ff5..4a57669835af9 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp @@ -15,13 +15,18 @@ #include "WebAssemblyMCInstLower.h" #include "MCTargetDesc/WebAssemblyMCAsmInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" #include "TargetInfo/WebAssemblyTargetInfo.h" #include "Utils/WebAssemblyTypeUtilities.h" #include "WebAssemblyAsmPrinter.h" #include "WebAssemblyMachineFunctionInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/IR/Constants.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" @@ -152,6 +157,34 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand( return MCOperand::createExpr(Expr); } +MCOperand +WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const { + // For APInt a word is 64 bits on all architectures, see definition in APInt.h + auto NumWords = Sig.getNumWords(); + SmallVector Params; + SmallVector Returns; + + int Idx = NumWords; + auto GetWord = [&Idx, &Sig]() { + Idx--; + return Sig.extractBitsAsZExtValue(64, 64 * Idx); + }; + // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will + // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we + // always emit a CImm. So xor NParams with 0x7ffffff to ensure + // getSignificantBits() > 64 + // See encodeFunctionSignature in WebAssemblyISelDAGtoDAG.cpp + int NReturns = GetWord() ^ 0x7ffffff; + for (int I = 0; I < NReturns; I++) { + Returns.push_back(static_cast(GetWord())); + } + int NParams = GetWord(); + for (int I = 0; I < NParams; I++) { + Params.push_back(static_cast(GetWord())); + } + return lowerTypeIndexOperand(std::move(Returns), std::move(Params)); +} + static void getFunctionReturns(const MachineInstr *MI, SmallVectorImpl &Returns) { const Function &F = MI->getMF()->getFunction(); @@ -196,11 +229,29 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, MCOp = MCOperand::createReg(WAReg); break; } + case llvm::MachineOperand::MO_CImmediate: { + // Lower type index placeholder for ref.test + // Currently this is the only way that CImmediates show up so panic if we + // get confused. + unsigned DescIndex = I - NumVariadicDefs; + assert(DescIndex < Desc.NumOperands && "unexpected CImmediate operand"); + auto Operands = Desc.operands(); + const MCOperandInfo &Info = Operands[DescIndex]; + assert(Info.OperandType == WebAssembly::OPERAND_TYPEINDEX && + "unexpected CImmediate operand"); + MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue()); + break; + } case MachineOperand::MO_Immediate: { unsigned DescIndex = I - NumVariadicDefs; if (DescIndex < Desc.NumOperands) { - const MCOperandInfo &Info = Desc.operands()[DescIndex]; + auto Operands = Desc.operands(); + const MCOperandInfo &Info = Operands[DescIndex]; + // Replace type index placeholder with actual type index. The type index + // placeholders are Immediates and have an operand type of + // OPERAND_TYPEINDEX or OPERAND_SIGNATURE. if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) { + // Lower type index placeholder for a CALL_INDIRECT instruction SmallVector Returns; SmallVector Params; @@ -228,6 +279,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, break; } if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) { + // Lower type index placeholder for blocks auto BT = static_cast(MO.getImm()); assert(BT != WebAssembly::BlockType::Invalid); if (BT == WebAssembly::BlockType::Multivalue) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h index 9f08499e5cde1..34404d93434bb 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h @@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower { MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const; MCOperand lowerTypeIndexOperand(SmallVectorImpl &&, SmallVectorImpl &&) const; + MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const; public: WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer) diff --git a/llvm/test/CodeGen/WebAssembly/ref-test-func.ll b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll new file mode 100644 index 0000000000000..e3760a07c6445 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll @@ -0,0 +1,120 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s --mtriple=wasm32-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK32 %s +; RUN: llc < %s --mtriple=wasm64-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK64 %s + +define void @test_fpsig_void_void(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_void_void: +; CHK32: .functype test_fpsig_void_void (i32) -> () +; CHK64: .functype test_fpsig_void_void (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> () +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func) + tail call void @use(i32 noundef %res) #3 + ret void +} + +define void @test_fpsig_return_i32(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_return_i32: +; CHK32: .functype test_fpsig_return_i32 (i32) -> () +; CHK64: .functype test_fpsig_return_i32 (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> (i32) +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +define void @test_fpsig_return_i64(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_return_i64: +; CHK32: .functype test_fpsig_return_i64 (i32) -> () +; CHK64: .functype test_fpsig_return_i64 (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> (i64) +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i64 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +define void @test_fpsig_return_f32(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_return_f32: +; CHK32: .functype test_fpsig_return_f32 (i32) -> () +; CHK64: .functype test_fpsig_return_f32 (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> (f32) +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.) + tail call void @use(i32 noundef %res) #3 + ret void +} + +define void @test_fpsig_return_f64(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_return_f64: +; CHK32: .functype test_fpsig_return_f64 (i32) -> () +; CHK64: .functype test_fpsig_return_f64 (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> (f64) +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, double 0.) + tail call void @use(i32 noundef %res) #3 + ret void +} + + +define void @test_fpsig_param_i32(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_param_i32: +; CHK32: .functype test_fpsig_param_i32 (i32) -> () +; CHK64: .functype test_fpsig_param_i32 (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (f64) -> () +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, double 0.) + tail call void @use(i32 noundef %res) #3 + ret void +} + + +define void @test_fpsig_multiple_params_and_returns(ptr noundef %func) local_unnamed_addr #0 { +; CHECK-LABEL: test_fpsig_multiple_params_and_returns: +; CHK32: .functype test_fpsig_multiple_params_and_returns (i32) -> () +; CHK64: .functype test_fpsig_multiple_params_and_returns (i64) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (i64, f32, i64) -> (i32, i64, f32, f64) +; CHECK-NEXT: call use +; CHECK-NEXT: # fallthrough-return +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i64 0, float 0., double 0., token poison, i64 0, float 0., i64 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + + +declare void @use(i32 noundef) local_unnamed_addr #1