From d9145bf2a9e66a2fd7815e29f35a9caf2a5f4e65 Mon Sep 17 00:00:00 2001 From: usuyus Date: Wed, 17 Sep 2025 16:27:07 +0100 Subject: [PATCH] statepoints --- clang/lib/CodeGen/BackendUtil.cpp | 12 + llvm/include/llvm/CodeGen/AsmPrinter.h | 2 +- llvm/include/llvm/CodeGen/GCMetadataPrinter.h | 2 +- .../llvm/CodeGen/LinkAllAsmWriterComponents.h | 1 + llvm/include/llvm/CodeGen/StackMaps.h | 12 +- llvm/include/llvm/IR/BuiltinGCs.h | 3 + llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 8 +- llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt | 1 + .../CodeGen/AsmPrinter/OxCamlGCPrinter.cpp | 246 ++++++++++++++++++ llvm/lib/CodeGen/StackMaps.cpp | 13 +- llvm/lib/IR/BuiltinGCs.cpp | 17 ++ .../lib/Transforms/Scalar/PlaceSafepoints.cpp | 4 +- .../Scalar/RewriteStatepointsForGC.cpp | 12 +- 13 files changed, 317 insertions(+), 16 deletions(-) create mode 100644 llvm/lib/CodeGen/AsmPrinter/OxCamlGCPrinter.cpp diff --git a/clang/lib/CodeGen/BackendUtil.cpp b/clang/lib/CodeGen/BackendUtil.cpp index 10d6bff25e6d6..57659123b08d8 100644 --- a/clang/lib/CodeGen/BackendUtil.cpp +++ b/clang/lib/CodeGen/BackendUtil.cpp @@ -84,6 +84,7 @@ #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/JumpThreading.h" #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" +#include "llvm/Transforms/Scalar/RewriteStatepointsForGC.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/CanonicalizeAliases.h" #include "llvm/Transforms/Utils/Debugify.h" @@ -954,6 +955,17 @@ void EmitAssemblyHelper::RunOptimizationPipeline( }); } + // TODO: Do this in a location that is more appropriate (LLVM instead of + // Clang). Also, determine a better place for this in the pipeline, since + // we don't want other transformations to treat values that may be relocated + // by the GC in an unsound way. + PB.registerOptimizerLastEPCallback( + [](ModulePassManager &MPM, OptimizationLevel Level) { + if (Level != OptimizationLevel::O0) { + MPM.addPass(RewriteStatepointsForGC()); + } + }); + // Register callbacks to schedule sanitizer passes at the appropriate part // of the pipeline. if (LangOpts.Sanitize.has(SanitizerKind::LocalBounds)) diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h index 33fda248120bd..e0d81a344c69a 100644 --- a/llvm/include/llvm/CodeGen/AsmPrinter.h +++ b/llvm/include/llvm/CodeGen/AsmPrinter.h @@ -524,7 +524,7 @@ class AsmPrinter : public MachineFunctionPass { void emitGlobalGOTEquivs(); /// Emit the stack maps. - void emitStackMaps(); + void emitStackMaps(Module &M); //===------------------------------------------------------------------===// // Overridable Hooks diff --git a/llvm/include/llvm/CodeGen/GCMetadataPrinter.h b/llvm/include/llvm/CodeGen/GCMetadataPrinter.h index f9527c9f8752e..ba3aff3f772b3 100644 --- a/llvm/include/llvm/CodeGen/GCMetadataPrinter.h +++ b/llvm/include/llvm/CodeGen/GCMetadataPrinter.h @@ -64,7 +64,7 @@ class GCMetadataPrinter { /// Called when the stack maps are generated. Return true if /// stack maps with a custom format are generated. Otherwise /// returns false and the default format will be used. - virtual bool emitStackMaps(StackMaps &SM, AsmPrinter &AP) { return false; } + virtual bool emitStackMaps(Module &M, StackMaps &SM, AsmPrinter &AP) { return false; } }; } // end namespace llvm diff --git a/llvm/include/llvm/CodeGen/LinkAllAsmWriterComponents.h b/llvm/include/llvm/CodeGen/LinkAllAsmWriterComponents.h index c22f9d49f374b..9b86b0809a93c 100644 --- a/llvm/include/llvm/CodeGen/LinkAllAsmWriterComponents.h +++ b/llvm/include/llvm/CodeGen/LinkAllAsmWriterComponents.h @@ -32,6 +32,7 @@ namespace { llvm::linkOcamlGCPrinter(); llvm::linkErlangGCPrinter(); + llvm::linkOxCamlGCPrinter(); } } ForceAsmWriterLinking; // Force link by creating a global definition. diff --git a/llvm/include/llvm/CodeGen/StackMaps.h b/llvm/include/llvm/CodeGen/StackMaps.h index 467e31f17bc82..e7d3a28e0ea16 100644 --- a/llvm/include/llvm/CodeGen/StackMaps.h +++ b/llvm/include/llvm/CodeGen/StackMaps.h @@ -310,23 +310,29 @@ class StackMaps { using ConstantPool = MapVector; struct FunctionInfo { + uint64_t StaticStackSize = 0; uint64_t StackSize = 0; uint64_t RecordCount = 1; FunctionInfo() = default; - explicit FunctionInfo(uint64_t StackSize) : StackSize(StackSize) {} + explicit FunctionInfo(uint64_t StaticStackSize, uint64_t StackSize) + : StaticStackSize(StaticStackSize), StackSize(StackSize) {} }; struct CallsiteInfo { + const MCSymbol *CSLabel = nullptr; const MCExpr *CSOffsetExpr = nullptr; + const FunctionInfo CSFunctionInfo; uint64_t ID = 0; LocationVec Locations; LiveOutVec LiveOuts; CallsiteInfo() = default; - CallsiteInfo(const MCExpr *CSOffsetExpr, uint64_t ID, + CallsiteInfo(const MCSymbol *CSLabel, const MCExpr *CSOffsetExpr, + const FunctionInfo CSFunctionInfo, uint64_t ID, LocationVec &&Locations, LiveOutVec &&LiveOuts) - : CSOffsetExpr(CSOffsetExpr), ID(ID), Locations(std::move(Locations)), + : CSLabel(CSLabel), CSOffsetExpr(CSOffsetExpr), + CSFunctionInfo(CSFunctionInfo), ID(ID), Locations(std::move(Locations)), LiveOuts(std::move(LiveOuts)) {} }; diff --git a/llvm/include/llvm/IR/BuiltinGCs.h b/llvm/include/llvm/IR/BuiltinGCs.h index 16aff01dbcf3d..654ab16b309ac 100644 --- a/llvm/include/llvm/IR/BuiltinGCs.h +++ b/llvm/include/llvm/IR/BuiltinGCs.h @@ -28,6 +28,9 @@ void linkOcamlGCPrinter(); /// Creates an erlang-compatible metadata printer. void linkErlangGCPrinter(); +/// Creates an oxcaml-compatible metadata printer. +void linkOxCamlGCPrinter(); + } // namespace llvm #endif // LLVM_IR_BUILTINGCS_H diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 0b1e32c87fc34..c9187cc6da8f1 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -39,6 +39,8 @@ #include "llvm/BinaryFormat/ELF.h" #include "llvm/CodeGen/GCMetadata.h" #include "llvm/CodeGen/GCMetadataPrinter.h" +#include "llvm/IR/BuiltinGCs.h" +#include "llvm/CodeGen/LinkAllAsmWriterComponents.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineDominators.h" @@ -2187,7 +2189,7 @@ bool AsmPrinter::doFinalization(Module &M) { // text sections come after debug info has been emitted. This matters for // stack maps as they are arbitrary data, and may even have a custom format // through user plugins. - emitStackMaps(); + emitStackMaps(M); // Finalize debug and EH information. for (const HandlerInfo &HI : Handlers) { @@ -3857,7 +3859,7 @@ GCMetadataPrinter *AsmPrinter::getOrCreateGCPrinter(GCStrategy &S) { report_fatal_error("no GCMetadataPrinter registered for GC: " + Twine(Name)); } -void AsmPrinter::emitStackMaps() { +void AsmPrinter::emitStackMaps(Module &M) { GCModuleInfo *MI = getAnalysisIfAvailable(); assert(MI && "AsmPrinter didn't require GCModuleInfo?"); bool NeedsDefault = false; @@ -3867,7 +3869,7 @@ void AsmPrinter::emitStackMaps() { else for (const auto &I : *MI) { if (GCMetadataPrinter *MP = getOrCreateGCPrinter(*I)) - if (MP->emitStackMaps(SM, *this)) + if (MP->emitStackMaps(M, SM, *this)) continue; // The strategy doesn't have printer or doesn't emit custom stack maps. // Use the default format. diff --git a/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt b/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt index 410e120d0e1bd..fd3babf91946b 100644 --- a/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt +++ b/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt @@ -21,6 +21,7 @@ add_llvm_component_library(LLVMAsmPrinter EHStreamer.cpp ErlangGCPrinter.cpp OcamlGCPrinter.cpp + OxCamlGCPrinter.cpp PseudoProbePrinter.cpp WinCFGuard.cpp WinException.cpp diff --git a/llvm/lib/CodeGen/AsmPrinter/OxCamlGCPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/OxCamlGCPrinter.cpp new file mode 100644 index 0000000000000..cc4ae59d04fcb --- /dev/null +++ b/llvm/lib/CodeGen/AsmPrinter/OxCamlGCPrinter.cpp @@ -0,0 +1,246 @@ +//===- OxCamlGCPrinter.cpp - OxCaml frametable emitter --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements printing the assembly code for an OxCaml frametable. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/Twine.h" +#include "llvm/CodeGen/AsmPrinter.h" +#include "llvm/CodeGen/GCMetadata.h" +#include "llvm/CodeGen/GCMetadataPrinter.h" +#include "llvm/CodeGen/StackMaps.h" +#include "llvm/IR/BuiltinGCs.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Mangler.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Statepoint.h" +#include "llvm/MC/MCContext.h" +#include "llvm/MC/MCDirectives.h" +#include "llvm/MC/MCStreamer.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Target/TargetLoweringObjectFile.h" +#include +#include +#include +#include +#include + +using namespace llvm; + +namespace { + +class OxCamlGCMetadataPrinter : public GCMetadataPrinter { +public: + void beginAssembly(Module &M, GCModuleInfo &Info, AsmPrinter &AP) override; + void finishAssembly(Module &M, GCModuleInfo &Info, AsmPrinter &AP) override; + bool emitStackMaps(Module &M, StackMaps &SM, AsmPrinter &AP) override; +}; + +} // end anonymous namespace + +static GCMetadataPrinterRegistry::Add + Y("oxcaml", "OxCaml frametable printer"); + +void llvm::linkOxCamlGCPrinter() {} + +static std::string camlGlobalSymName(const Module &M, const char *Id) { + if (Metadata *ModuleMD = M.getModuleFlag(StringRef("oxcaml_module"))) { + if (MDString *Str = dyn_cast(ModuleMD)) { + StringRef ModuleName = Str->getString(); + + std::string SymName; + SymName += "caml"; + SymName += ModuleName; + SymName += "__"; + SymName += Id; + + return SymName; + } + } + + report_fatal_error("Module name not provided for OxCaml GC!"); +} + +static void emitCamlGlobal(const Module &M, MCStreamer &OS, const char *Id) { + std::string SymName = camlGlobalSymName(M, Id); + + SmallString<128> TmpStr; + Mangler::getNameWithPrefix(TmpStr, SymName, M.getDataLayout()); + + MCSymbol *Sym = OS.getContext().getOrCreateSymbol(TmpStr); + + OS.emitSymbolAttribute(Sym, MCSA_Global); + OS.emitLabel(Sym); +} + +void OxCamlGCMetadataPrinter::beginAssembly(Module &M, GCModuleInfo &Info, + AsmPrinter &AP) { + AP.OutStreamer->switchSection(AP.getObjFileLowering().getTextSection()); + emitCamlGlobal(M, *(AP.OutStreamer), "code_begin"); + + AP.OutStreamer->switchSection(AP.getObjFileLowering().getDataSection()); + emitCamlGlobal(M, *(AP.OutStreamer), "data_begin"); +} + + +void OxCamlGCMetadataPrinter::finishAssembly(Module &M, GCModuleInfo &Info, + AsmPrinter &AP) { + AP.OutStreamer->switchSection(AP.getObjFileLowering().getTextSection()); + emitCamlGlobal(M, *(AP.OutStreamer), "code_end"); + + AP.OutStreamer->switchSection(AP.getObjFileLowering().getDataSection()); + emitCamlGlobal(M, *(AP.OutStreamer), "data_end"); +} + +/// Map LLVM DWARF register numbers to OxCaml register map. +/// * See llvm/lib/Target/X86/X86RegisterInfo.td for DWARF register numbers. +/// * See backend/amd64/proc.ml for the OxCaml register map. + +// TODO: This is target-specific and should probably live in a +// target-specific location. + +// Directly taken from [Reg_class.gpr_dwarf_reg_numbers]: +// https://github.com/oxcaml/oxcaml/blob/main/backend/amd64/reg_class.ml#L26 +// Note that R14 and R15 are added for completeness +static constexpr std::array GPR_OxCamlToDwarf = + { 0, 3, 5, 4, 1, 2, 8, 9, 12, 13, 10, 11, 6, 14, 15 }; + +static constexpr auto GPR_DwarfToOxCaml = []() { + std::array result{}; + for (size_t ocaml_idx = 0; ocaml_idx < GPR_OxCamlToDwarf.size(); ++ocaml_idx) { + unsigned dwarf_reg = GPR_OxCamlToDwarf[ocaml_idx]; + if (dwarf_reg < result.size()) { + result[dwarf_reg] = ocaml_idx; + } + } + return result; +}(); + +static const unsigned XMMBeginOxCaml = 100; +static const unsigned XMMBeginDwarf = 17; +static const unsigned XMMEndDwarf = 32; + +static unsigned mapLLVMDwarfRegToOxCamlIndex(unsigned DwarfRegNum) { + if (DwarfRegNum < GPR_DwarfToOxCaml.size()) { + return GPR_DwarfToOxCaml[DwarfRegNum]; + } else if (XMMBeginDwarf <= DwarfRegNum && DwarfRegNum <= XMMEndDwarf) { + return DwarfRegNum - XMMBeginDwarf + XMMBeginOxCaml; + } else { + report_fatal_error("Unrecognised DWARF register for use in OxCaml frametable: " + + Twine(DwarfRegNum)); + } +} + +bool OxCamlGCMetadataPrinter::emitStackMaps(Module &M, StackMaps &SM, AsmPrinter &AP) { + MCStreamer &OS = *AP.OutStreamer; + unsigned PtrSize = M.getDataLayout().getPointerSize(); // Can only be 8 for now + + OS.switchSection(AP.getObjFileLowering().getDataSection()); + + emitCamlGlobal(M, OS, "frametable"); + + // Number of records + OS.emitInt64(SM.getCSInfos().size()); + + for (const auto &CSI : SM.getCSInfos()) { + // From runtime/frame_descriptors.h: + // https://github.com/oxcaml/oxcaml/blob/main/runtime/caml/frame_descriptors.h#L63 + // + // typedef struct { + // int32_t retaddr_rel; /* offset of return address from &retaddr_rel */ + // uint16_t frame_data; /* frame size and various flags */ + // uint16_t num_live; + // uint16_t live_ofs[num_live]; + // } frame_descr; + + // retaddr_rel + MCSymbol *Here = OS.getContext().createTempSymbol(); + OS.emitLabel(Here); + const MCExpr *RelativeAddr = MCBinaryExpr::createSub( + MCSymbolRefExpr::create(CSI.CSLabel, OS.getContext()), + MCSymbolRefExpr::create(Here, OS.getContext()), + OS.getContext()); + OS.emitValue(RelativeAddr, 4); + + // frame_data + uint64_t FrameSize = CSI.CSFunctionInfo.StaticStackSize; + if (CSI.ID != StatepointDirectives::DefaultStatepointID) + FrameSize += CSI.ID; // Stack offset from OxCaml + FrameSize += PtrSize; // Return address + + if (FrameSize >= 1 << 16) + report_fatal_error("Long frames not supported for OxCaml GC: FrameSize = " + + Twine(FrameSize)); + OS.emitInt16(FrameSize); + + // num_live + uint64_t LiveCount = 0; + for (const auto &Loc : CSI.Locations) { + if (Loc.Type == StackMaps::Location::Register || + Loc.Type == StackMaps::Location::Direct || + Loc.Type == StackMaps::Location::Indirect) { + LiveCount++; + } + } + LiveCount += CSI.LiveOuts.size(); + + if (LiveCount >= 1 << 16) { + // Very rude! + report_fatal_error("Long frames not supported for OxCaml GC: LiveCount = " + + Twine(LiveCount)); + } + OS.emitInt16(LiveCount); + + // live_ofs + for (const auto &Loc : CSI.Locations) { + if (Loc.Type == StackMaps::Location::Register) { + // Register indices are tagged (2n+1) and follow the OxCaml register + // map (see `mapLLVMDwarfRegToOxCamlIndex`) + unsigned DwarfRegNum = Loc.Reg; + unsigned OxCamlIndex = mapLLVMDwarfRegToOxCamlIndex(DwarfRegNum); + uint16_t EncodedReg = (OxCamlIndex << 1) + 1; + OS.emitInt16(EncodedReg); + } else if (Loc.Type == StackMaps::Location::Direct || + Loc.Type == StackMaps::Location::Indirect) { + // For stack locations (Direct/Indirect): emit offset directly + int64_t Offset = Loc.Offset; + + // BP-relative addressing -> SP + if (Offset < 0) { + int64_t TempFrameSize = + FrameSize - PtrSize /* return address */ - PtrSize /* pushed BP */; + Offset += TempFrameSize; + } + + if (Offset < -(1 << 15) || Offset >= (1 << 15)) { + // Very rude! + report_fatal_error("Stack offset too large for OxCaml frametable: " + + Twine(Offset)); + } + OS.emitInt16(static_cast(Offset)); + } else { + // TODO: Do we need anything else here? + } + } + + for (const auto &LO : CSI.LiveOuts) { + unsigned OxCamlIndex = mapLLVMDwarfRegToOxCamlIndex(LO.DwarfRegNum); + uint16_t EncodedReg = (OxCamlIndex << 1) + 1; + OS.emitInt16(EncodedReg); + } + + OS.emitValueToAlignment(Align(PtrSize)); + } + + OS.addBlankLine(); + return true; +} diff --git a/llvm/lib/CodeGen/StackMaps.cpp b/llvm/lib/CodeGen/StackMaps.cpp index bb7a51e49edb7..1c2ce591483b6 100644 --- a/llvm/lib/CodeGen/StackMaps.cpp +++ b/llvm/lib/CodeGen/StackMaps.cpp @@ -521,21 +521,24 @@ void StackMaps::recordStackMapOpers(const MCSymbol &MILabel, MCSymbolRefExpr::create(&MILabel, OutContext), MCSymbolRefExpr::create(AP.CurrentFnSymForSize, OutContext), OutContext); - CSInfos.emplace_back(CSOffsetExpr, ID, std::move(Locations), - std::move(LiveOuts)); - // Record the stack size of the current function and update callsite count. const MachineFrameInfo &MFI = AP.MF->getFrameInfo(); const TargetRegisterInfo *RegInfo = AP.MF->getSubtarget().getRegisterInfo(); bool HasDynamicFrameSize = MFI.hasVarSizedObjects() || RegInfo->hasStackRealignment(*(AP.MF)); - uint64_t FrameSize = HasDynamicFrameSize ? UINT64_MAX : MFI.getStackSize(); + uint64_t StaticFrameSize = MFI.getStackSize(); + uint64_t FrameSize = HasDynamicFrameSize ? UINT64_MAX : StaticFrameSize; auto CurrentIt = FnInfos.find(AP.CurrentFnSym); if (CurrentIt != FnInfos.end()) CurrentIt->second.RecordCount++; else - FnInfos.insert(std::make_pair(AP.CurrentFnSym, FunctionInfo(FrameSize))); + FnInfos.insert(std::make_pair(AP.CurrentFnSym, + FunctionInfo(StaticFrameSize, FrameSize))); + + CSInfos.emplace_back(&MILabel, CSOffsetExpr, + FunctionInfo(StaticFrameSize, FrameSize), + ID, std::move(Locations), std::move(LiveOuts)); } void StackMaps::recordStackMap(const MCSymbol &L, const MachineInstr &MI) { diff --git a/llvm/lib/IR/BuiltinGCs.cpp b/llvm/lib/IR/BuiltinGCs.cpp index 163b0383e22c2..85f38cc44a109 100644 --- a/llvm/lib/IR/BuiltinGCs.cpp +++ b/llvm/lib/IR/BuiltinGCs.cpp @@ -43,6 +43,22 @@ class OcamlGC : public GCStrategy { } }; + +class OxCamlGC : public GCStrategy { +public: + OxCamlGC() { + UseStatepoints = true; + UseRS4GC = true; + NeededSafePoints = false; + UsesMetadata = true; // This is for custom frametable printing. + } + + std::optional isGCManagedPointer(const Type *Ty) const override { + const PointerType *PT = cast(Ty); + return (1 == PT->getAddressSpace()); + } +}; + /// A GC strategy for uncooperative targets. This implements lowering for the /// llvm.gc* intrinsics for targets that do not natively support them (which /// includes the C backend). Note that the code generated is not quite as @@ -127,6 +143,7 @@ static GCRegistry::Add static GCRegistry::Add D("statepoint-example", "an example strategy for statepoint"); static GCRegistry::Add E("coreclr", "CoreCLR-compatible GC"); +static GCRegistry::Add F("oxcaml", "OxCaml-compatible GC"); // Provide hook to ensure the containing library is fully loaded. void llvm::linkAllBuiltinGCs() {} diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index e1cc3fc71c3e4..5417acb6fabc4 100644 --- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -451,8 +451,10 @@ static bool shouldRewriteFunction(Function &F) { const auto &FunctionGCName = F.getGC(); const StringRef StatepointExampleName("statepoint-example"); const StringRef CoreCLRName("coreclr"); + const StringRef OxCamlName("oxcaml"); return (StatepointExampleName == FunctionGCName) || - (CoreCLRName == FunctionGCName); + (CoreCLRName == FunctionGCName) || + (OxCamlName == FunctionGCName); } else return false; } diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index bcb012b79c2e0..f83ae3d7ae021 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -356,8 +356,14 @@ static bool containsGCPtrType(Type *Ty) { return isGCPointerType(VT->getScalarType()); if (ArrayType *AT = dyn_cast(Ty)) return containsGCPtrType(AT->getElementType()); + + // Don't fail on structs - in our use case, we extract all its elements right + // after they get created. Make sure to check this assumption still holds + // with new changes. + // TODO: Make this check conditional on OxCaml GC if (StructType *ST = dyn_cast(Ty)) - return llvm::any_of(ST->elements(), containsGCPtrType); + return false; // llvm::any_of(ST->elements(), containsGCPtrType); + return false; } @@ -3034,8 +3040,10 @@ static bool shouldRewriteStatepointsIn(Function &F) { const auto &FunctionGCName = F.getGC(); const StringRef StatepointExampleName("statepoint-example"); const StringRef CoreCLRName("coreclr"); + const StringRef OxCamlName("oxcaml"); return (StatepointExampleName == FunctionGCName) || - (CoreCLRName == FunctionGCName); + (CoreCLRName == FunctionGCName) || + (OxCamlName == FunctionGCName); } else return false; }