Skip to content

[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic #147486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem]>;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
MIB.addImm(C->getSExtValue());
if (C->getAPIntValue().getSignificantBits() <= 64) {
MIB.addImm(C->getSExtValue());
} else {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
Expand Down
91 changes: 91 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
#include "WebAssembly.h"
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/SelectionDAGISel.h"
#include "llvm/CodeGen/WasmEHFuncInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h" // To access function attributes.
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/MC/MCSymbolWasm.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -118,6 +120,51 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
return DAG->getTargetExternalSymbol(SymName, PtrVT);
}

static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
SmallVector<MVT, 4> &Returns,
SmallVector<MVT, 4> &Params) {
auto toWasmValType = [&DAG, &DL](MVT VT) {
if (VT == MVT::i32) {
return wasm::ValType::I32;
}
if (VT == MVT::i64) {
return wasm::ValType::I64;
}
if (VT == MVT::f32) {
return wasm::ValType::F32;
}
if (VT == MVT::f64) {
return wasm::ValType::F64;
}
LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT
<< "\n");
llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func");
};
auto NParams = Params.size();
auto NReturns = Returns.size();
auto BitWidth = (NParams + NReturns + 2) * 64;
auto Sig = APInt(BitWidth, 0);

// Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
// emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
// always emit a CImm. So xor NParams with 0x7ffffff to ensure
// getSignificantBits() > 64
Sig |= NReturns ^ 0x7ffffff;
for (auto &Return : Returns) {
auto V = toWasmValType(Return);
Sig <<= 64;
Sig |= (int64_t)V;
}
Sig <<= 64;
Sig |= NParams;
for (auto &Param : Params) {
auto V = toWasmValType(Param);
Sig <<= 64;
Sig |= (int64_t)V;
}
return Sig;
}

void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
// If we have a custom node, we already have selected!
if (Node->isMachineOpcode()) {
Expand Down Expand Up @@ -189,6 +236,50 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
ReplaceNode(Node, TLSAlign);
return;
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = CurDAG->getMachineFunction();
auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits());
MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol(
MF.getContext(), Subtarget);
SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT);
SDValue FuncRef = SDValue(
CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Node->getOperand(1)),
0);

// Encode the signature information into the type index placeholder.
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
SmallVector<MVT, 4> Params;
SmallVector<MVT, 4> Returns;

bool IsParam = false;
// Operand 0 is the return register, Operand 1 is the function pointer.
// The remaining operands encode the type of the function we are testing
// for.
for (unsigned I = 2, E = Node->getNumOperands(); I < E; ++I) {
MVT VT = Node->getOperand(I).getValueType().getSimpleVT();
if (VT == MVT::Untyped) {
IsParam = true;
continue;
}
if (IsParam) {
Params.push_back(VT);
} else {
Returns.push_back(VT);
}
}
auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params);

auto SigOp = CurDAG->getTargetConstant(
Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
MachineSDNode *RefTestNode = CurDAG->getMachineNode(
WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef});
ReplaceNode(Node, RefTestNode);
return;
}
}
break;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down
54 changes: 53 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -152,6 +157,34 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand(
return MCOperand::createExpr(Expr);
}

MCOperand
WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const {
// For APInt a word is 64 bits on all architectures, see definition in APInt.h
auto NumWords = Sig.getNumWords();
SmallVector<wasm::ValType, 4> Params;
SmallVector<wasm::ValType, 2> Returns;

int Idx = NumWords;
auto GetWord = [&Idx, &Sig]() {
Idx--;
return Sig.extractBitsAsZExtValue(64, 64 * Idx);
};
// Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
// emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
// always emit a CImm. So xor NParams with 0x7ffffff to ensure
// getSignificantBits() > 64
// See encodeFunctionSignature in WebAssemblyISelDAGtoDAG.cpp
int NReturns = GetWord() ^ 0x7ffffff;
for (int I = 0; I < NReturns; I++) {
Returns.push_back(static_cast<wasm::ValType>(GetWord()));
}
int NParams = GetWord();
for (int I = 0; I < NParams; I++) {
Params.push_back(static_cast<wasm::ValType>(GetWord()));
}
return lowerTypeIndexOperand(std::move(Returns), std::move(Params));
}

static void getFunctionReturns(const MachineInstr *MI,
SmallVectorImpl<wasm::ValType> &Returns) {
const Function &F = MI->getMF()->getFunction();
Expand Down Expand Up @@ -196,11 +229,29 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
assert(DescIndex < Desc.NumOperands && "unexpected CImmediate operand");
auto Operands = Desc.operands();
const MCOperandInfo &Info = Operands[DescIndex];
assert(Info.OperandType == WebAssembly::OPERAND_TYPEINDEX &&
"unexpected CImmediate operand");
MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue());
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
auto Operands = Desc.operands();
const MCOperandInfo &Info = Operands[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +279,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower {
MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const;
MCOperand lowerTypeIndexOperand(SmallVectorImpl<wasm::ValType> &&,
SmallVectorImpl<wasm::ValType> &&) const;
MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const;

public:
WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer)
Expand Down
120 changes: 120 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK32 %s
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -mcpu=mvp -mattr=+reference-types | FileCheck --check-prefixes CHECK,CHK64 %s

define void @test_fpsig_void_void(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_void_void:
; CHK32: .functype test_fpsig_void_void (i32) -> ()
; CHK64: .functype test_fpsig_void_void (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func)
tail call void @use(i32 noundef %res) #3
ret void
}

define void @test_fpsig_return_i32(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_return_i32:
; CHK32: .functype test_fpsig_return_i32 (i32) -> ()
; CHK64: .functype test_fpsig_return_i32 (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0)
tail call void @use(i32 noundef %res) #3
ret void
}

define void @test_fpsig_return_i64(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_return_i64:
; CHK32: .functype test_fpsig_return_i64 (i32) -> ()
; CHK64: .functype test_fpsig_return_i64 (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> (i64)
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i64 0)
tail call void @use(i32 noundef %res) #3
ret void
}

define void @test_fpsig_return_f32(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_return_f32:
; CHK32: .functype test_fpsig_return_f32 (i32) -> ()
; CHK64: .functype test_fpsig_return_f32 (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> (f32)
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.)
tail call void @use(i32 noundef %res) #3
ret void
}

define void @test_fpsig_return_f64(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_return_f64:
; CHK32: .functype test_fpsig_return_f64 (i32) -> ()
; CHK64: .functype test_fpsig_return_f64 (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> (f64)
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, double 0.)
tail call void @use(i32 noundef %res) #3
ret void
}


define void @test_fpsig_param_i32(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_param_i32:
; CHK32: .functype test_fpsig_param_i32 (i32) -> ()
; CHK64: .functype test_fpsig_param_i32 (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (f64) -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, double 0.)
tail call void @use(i32 noundef %res) #3
ret void
}


define void @test_fpsig_multiple_params_and_returns(ptr noundef %func) local_unnamed_addr #0 {
; CHECK-LABEL: test_fpsig_multiple_params_and_returns:
; CHK32: .functype test_fpsig_multiple_params_and_returns (i32) -> ()
; CHK64: .functype test_fpsig_multiple_params_and_returns (i64) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (i64, f32, i64) -> (i32, i64, f32, f64)
; CHECK-NEXT: call use
; CHECK-NEXT: # fallthrough-return
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i64 0, float 0., double 0., token poison, i64 0, float 0., i64 0)
tail call void @use(i32 noundef %res) #3
ret void
}


declare void @use(i32 noundef) local_unnamed_addr #1
Loading