From b6b73da143464cfc4b5786d74a26baceb364cf97 Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Wed, 3 Jun 2026 14:57:27 +0200 Subject: [PATCH 1/3] [Metal] Add ray tracing pipeline, SBT, and DispatchRays bring-up MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DXR-style ray tracing reaches Metal through metal_irconverter: each RT entry point is lowered from DXIL to a Metal IR function, raygen is emitted as a kernel (IRRayGenerationCompilationKernel) so it can be dispatched directly, and miss / closest-hit / any-hit / intersection / callable functions are emitted as visible functions and pulled into a MTLVisibleFunctionTable. Implements the three virtuals the foundation PR left stubbed on Metal: • MTLDevice::createPipelineRT compiles every Shaders[] entry against a single IRRayTracingPipelineConfiguration (max attribute/recursion from the YAML RTConfig), builds one MTL::Library per entry, hands the raygen function to the compute pipeline as the kernel, and registers the rest as LinkedFunctions. The freshly-built pipeline then mints a MTLVisibleFunctionTable and resolves each callable function's handle into a slot index that the SBT builder reuses. • MTLDevice::createShaderBindingTable lays the four SBT regions out via the shared computeSBTLayout helper sized for IRShaderIdentifier records, looks up each region entry's ShaderName in the pipeline's name → IRShaderIdentifier map, and memcpys the records into a shared-storage MTL::Buffer the runtime will dereference at dispatch. • MTLComputeEncoder::dispatchRays binds the raygen pipeline and runs dispatchThreads(Width, Height, Depth) on the encoder. The caller (createRayTracingCommands) is responsible for binding the global descriptor heap, top-level argument buffer, IRDispatchRaysArgument (slot 3), and marking the SBT buffer + function tables resident. The IRDispatchRaysArgument struct is built per-dispatch in createRayTracingCommands: SBT region addresses + sizes (read off the MTLShaderBindingTable), GRS / ResDescHeap GPU pointers, and the visible / intersection function table resourceIDs. It's parked in a shared MTL::Buffer kept alive on the command buffer's KeepAlive list and bound at kIRRayDispatchArgumentsBindPoint so callees reached via TraceRay() inherit the same dispatch state through that pointer. Plumbs the existing executeProgram RT branch on Metal the same way the VK / DX backends already do (validate Shaders / SBT / RTConfig, build RayTracingPipelineCreateDesc from the YAML pipeline, create PSO, build SBT, record commands), and adds the raytracing-pipeline lit feature on Metal so test/Feature/RT/raygen-roundtrip.test drops Metal from its XFAIL list and passes natively on Apple Silicon (the 0xBEEF payload roundtrip matches the DX / VK references, verified locally on macOS 15 / metal-irconverter 3.1.1). This PR1 bring-up only handles Triangle hit groups whose only member is a ClosestHit shader — any-hit / intersection / procedural / local root signatures land in follow-ups; createPipelineRT now returns a clear unsupported error for those shapes instead of silently producing wrong output. Co-Authored-By: Claude Opus 4.7 (1M context) --- lib/API/MTL/MTLDevice.cpp | 610 +++++++++++++++++++++- lib/API/MTL/MTLTopLevelArgumentBuffer.cpp | 4 + lib/API/MTL/MTLTopLevelArgumentBuffer.h | 6 + test/Feature/RT/raygen-roundtrip.test | 2 +- test/lit.cfg.py | 4 + 5 files changed, 610 insertions(+), 16 deletions(-) diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 615621091..313b4ff5b 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -149,12 +149,17 @@ static IRShaderStage getShaderStage(Stages Stage) { case Stages::Mesh: return IRShaderStageMesh; case Stages::RayGeneration: + return IRShaderStageRayGeneration; case Stages::Miss: + return IRShaderStageMiss; case Stages::ClosestHit: + return IRShaderStageClosestHit; case Stages::AnyHit: + return IRShaderStageAnyHit; case Stages::Intersection: + return IRShaderStageIntersection; case Stages::Callable: - llvm_unreachable("RayTracing shaders take a different path on Metal."); + return IRShaderStageCallable; } llvm_unreachable("All cases handled"); } @@ -287,6 +292,10 @@ class MTLPipelineState : public offloadtest::PipelineState { MTL::Size MeshThreadsPerThreadgroup{1, 1, 1}; MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1}; + // True for pipelines created via createPipelineRT; mirrors the VK / DX + // backends' IsRayTracing flag so classof can downcast safely. + bool IsRayTracing = false; + MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig, std::unique_ptr ArgBuffer, MTL::ComputePipelineState *ComputePipeline, @@ -321,6 +330,102 @@ class MTLPipelineState : public offloadtest::PipelineState { static bool classof(const offloadtest::PipelineState *B) { return B->getAPI() == GPUAPI::Metal; } + +protected: + // RT subclass constructor — keeps Compute/RenderPipeline null while sharing + // the rest of the layout (Name, root signature, argument buffer). + MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig, + std::unique_ptr ArgBuffer, + bool IsRT) + : offloadtest::PipelineState(GPUAPI::Metal), Name(Name), + RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)), + IsRayTracing(IsRT) {} +}; + +/// Ray tracing pipeline state. Layered on top of MTLPipelineState so the +/// existing argument-buffer / root-signature plumbing keeps working; adds the +/// raygen compute pipeline (held in ComputePipeline) plus the +/// IRShaderIdentifier records the SBT builder needs to populate per-record +/// entries. +/// +/// The Metal RT path goes through `metal_irconverter`: +/// • each entry point is compiled to a Metal IR function; +/// • raygen is compiled as a kernel (IRRayGenerationCompilationKernel) so +/// it becomes the compute function of the pipeline; +/// • miss / closest-hit / any-hit / intersection / callable are compiled as +/// visible functions, attached to the pipeline via LinkedFunctions, and +/// looked up by name in a MTLVisibleFunctionTable; +/// • the SBT records IRShaderIdentifier values whose `shaderHandle` is the +/// slot in that visible function table. +class MTLRayTracingPipelineState : public MTLPipelineState { +public: + // ResourceID-based callable tables wired into IRDispatchRaysArgument. + MTL::VisibleFunctionTable *VFT = nullptr; + MTL::IntersectionFunctionTable *IFT = nullptr; + + // Per shader entry / hit-group: pre-built IRShaderIdentifier the SBT + // builder memcpys into each record. Keys are EntryPoint strings for + // raygen / miss / callable shaders and HitGroup.Name for hit groups. + llvm::StringMap ShaderIdentifiers; + + // Keep the per-stage Metal libraries / functions alive as long as the + // pipeline owns the visible-function-table indices that reference them. + llvm::SmallVector Libraries; + llvm::SmallVector Functions; + + MTLRayTracingPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig, + std::unique_ptr ArgBuf) + : MTLPipelineState(Name, std::move(RootSig), std::move(ArgBuf), + /*IsRT=*/true) {} + + ~MTLRayTracingPipelineState() override { + if (VFT) + VFT->release(); + if (IFT) + IFT->release(); + for (MTL::Function *F : Functions) + if (F) + F->release(); + for (MTL::Library *L : Libraries) + if (L) + L->release(); + } + + static bool classof(const offloadtest::PipelineState *B) { + if (B->getAPI() != GPUAPI::Metal) + return false; + return static_cast(B)->IsRayTracing; + } +}; + +/// Metal-side shader binding table. There is no `MTLShaderBindingTable` in +/// the Metal API — the irconverter runtime expects the four SBT regions to be +/// laid out as `IRShaderIdentifier` records in a single buffer whose ranges +/// are referenced from an `IRDispatchRaysArgument` struct at dispatch time. +class MTLShaderBindingTable : public offloadtest::ShaderBindingTable { +public: + MTL::Buffer *Buffer = nullptr; + IRVirtualAddressRange RayGenRegion{}; + IRVirtualAddressRangeAndStride MissRegion{}; + IRVirtualAddressRangeAndStride HitGroupRegion{}; + IRVirtualAddressRangeAndStride CallableRegion{}; + + MTLShaderBindingTable(MTL::Buffer *Buf, IRVirtualAddressRange RG, + IRVirtualAddressRangeAndStride MS, + IRVirtualAddressRangeAndStride HG, + IRVirtualAddressRangeAndStride CL) + : offloadtest::ShaderBindingTable(GPUAPI::Metal), Buffer(Buf), + RayGenRegion(RG), MissRegion(MS), HitGroupRegion(HG), + CallableRegion(CL) {} + + ~MTLShaderBindingTable() override { + if (Buffer) + Buffer->release(); + } + + static bool classof(const offloadtest::ShaderBindingTable *S) { + return S->getAPI() == GPUAPI::Metal; + } }; class MTLBuffer : public offloadtest::Buffer { @@ -682,10 +787,39 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { // MTL::Device handle (used to allocate scratch and instance buffers). llvm::Error batchBuildAS(llvm::ArrayRef Items) override; - llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &, - uint32_t, uint32_t, uint32_t) override { - return llvm::createStringError( - "RayTracing dispatchRays not yet supported on Metal"); + // Dispatch threads using a raygen compute kernel synthesized by the + // irconverter. All bindings (descriptor heap, top-level argument buffer, + // IRDispatchRaysArgument at slot 3, visible/intersection function tables, + // and the SBT buffer) must already be set on the active compute encoder by + // the caller — this method only binds the pipeline state and issues the + // dispatch. + llvm::Error dispatchRays(const PipelineState &PSO, const ShaderBindingTable &, + uint32_t Width, uint32_t Height, + uint32_t Depth) override { + if (!llvm::isa(&PSO)) + return llvm::createStringError( + std::errc::invalid_argument, + "dispatchRays requires a RayTracing PipelineState."); + const auto &RTPSO = llvm::cast(PSO); + if (!RTPSO.ComputePipeline) + return llvm::createStringError( + std::errc::invalid_argument, + "RayTracing PipelineState has no compute pipeline state."); + if (auto Err = ensureComputeEncoder()) + return Err; + flushBarrier(); + insertDebugSignpost( + llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth) + .str()); + ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline); + + // DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen + // kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal + // ceil-divide by ThreadsPerGroup to compute threadgroup count. + const MTL::Size GridSize(Width, Height, Depth); + ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup); + addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures); + return llvm::Error::success(); } /// Lazily transition into an AccelerationStructureCommandEncoder; mirrors @@ -1048,6 +1182,7 @@ class MTLDevice : public offloadtest::Device { std::unique_ptr DepthStencil; std::unique_ptr CB; std::unique_ptr Pipeline; + std::unique_ptr SBT; std::unique_ptr RenderPass; llvm::SmallVector DescTables; @@ -1227,6 +1362,61 @@ class MTLDevice : public offloadtest::Device { return MetalIR{std::move(MetalLib), std::move(Reflection)}; } + // Compile a single ray-tracing entry point out of a DXIL library to a Metal + // library + reflection. The compiler is configured with the global root + // signature and a IRRayTracingPipelineConfiguration that mirrors the + // pipeline's RTConfig — raygen is forced to kernel-mode compilation so it + // becomes the compute function on the pipeline state, while every other + // RT stage is emitted as a visible function callable from the raygen kernel + // through a MTLVisibleFunctionTable. + llvm::Expected + convertRTShaderToMetalIR(Stages Stage, IRRootSignature *RootSig, + const IRRayTracingPipelineConfiguration *RTConfig, + llvm::StringRef Entry, + const llvm::MemoryBuffer &Library) { + IRCompilerPtr Compiler(IRCompilerCreate()); + if (!Compiler) + return llvm::createStringError(std::errc::not_supported, + "Failed to create IR compiler instance."); + if (!RootSig) + return llvm::createStringError( + std::errc::invalid_argument, + "Root signature must be created before converting to Metal IR."); + + IRCompilerSetEntryPointName(Compiler.get(), std::string(Entry).c_str()); + IRCompilerSetGlobalRootSignature(Compiler.get(), RootSig); + IRCompilerSetRayTracingPipelineConfiguration(Compiler.get(), RTConfig); + + const llvm::StringRef Bytes = Library.getBuffer(); + IRObjectPtr DXIL( + IRObjectCreateFromDXIL(reinterpret_cast(Bytes.data()), + Bytes.size(), IRBytecodeOwnershipNone)); + + IRError *Err = nullptr; + IRObjectPtr ResultIR(IRCompilerAllocCompileAndLink( + Compiler.get(), std::string(Entry).c_str(), DXIL.get(), &Err)); + if (Err) + return toError(IRErrorPtr(Err).get(), + "Failed to compile RT shader to Metal IR"); + + const IRShaderStage ShaderStage = getShaderStage(Stage); + auto MetalLib = IRMetalLibBinaryPtr(IRMetalLibBinaryCreate()); + if (!IRObjectGetMetalLibBinary(ResultIR.get(), ShaderStage, MetalLib.get())) + return llvm::createStringError( + std::errc::not_supported, + "Failed to retrieve Metal library binary for RT entry '%s'", + std::string(Entry).c_str()); + + auto Reflection = IRShaderReflectionPtr(IRShaderReflectionCreate()); + if (!IRObjectGetReflection(ResultIR.get(), ShaderStage, Reflection.get())) + return llvm::createStringError( + std::errc::not_supported, + "Failed to retrieve RT shader reflection for entry '%s'", + std::string(Entry).c_str()); + + return MetalIR{std::move(MetalLib), std::move(Reflection)}; + } + // Creates a Metal resource (buffer or texture) for the given Resource at the // specified array index. llvm::Expected @@ -1591,6 +1781,101 @@ class MTLDevice : public offloadtest::Device { return llvm::Error::success(); } + llvm::Error createRayTracingCommands(Pipeline &P, InvocationState &IS) { + auto EncoderOrErr = IS.CB->createComputeEncoder(); + if (!EncoderOrErr) + return EncoderOrErr.takeError(); + auto &Encoder = llvm::cast(*EncoderOrErr.get()); + MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative(); + + const auto &RTPSO = + llvm::cast(*IS.Pipeline.get()); + const auto &SBT = llvm::cast(*IS.SBT.get()); + + // Bind the global descriptor heap + top-level argument buffer the same + // way the compute path does; the raygen kernel and any visible-function + // callees consume them at the same slots (kIRDescriptorHeapBindPoint and + // kIRArgumentBufferBindPoint). + MTLGPUDescriptorHandle Handle = {}; + if (IS.DescHeap) { + IS.DescHeap->bind(NativeEncoder); + Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); + } + for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) { + RTPSO.ArgBuffer->setRootDescriptorTable(Idx, Handle); + Handle.addOffset(P.Sets[Idx].Resources.size()); + } + RTPSO.ArgBuffer->bind(NativeEncoder); + + // Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses + // (RayGen / Miss / HitGroup / Callable), GPU pointers to the global + // root-signature argument buffer + descriptor heaps, plus resource IDs + // for the visible / intersection function tables. The raygen kernel + // reads this struct from the buffer bound at kIRRayDispatchArgumentsBind- + // Point and any visible-function callees inherit it through the same + // pointer. + IRDispatchRaysArgument Args{}; + Args.DispatchRaysDesc.RayGenerationShaderRecord = SBT.RayGenRegion; + Args.DispatchRaysDesc.MissShaderTable = SBT.MissRegion; + Args.DispatchRaysDesc.HitGroupTable = SBT.HitGroupRegion; + Args.DispatchRaysDesc.CallableShaderTable = SBT.CallableRegion; + Args.DispatchRaysDesc.Width = P.DispatchParameters.DispatchGroupCount[0]; + Args.DispatchRaysDesc.Height = P.DispatchParameters.DispatchGroupCount[1]; + Args.DispatchRaysDesc.Depth = P.DispatchParameters.DispatchGroupCount[2]; + Args.GRS = RTPSO.ArgBuffer->getGPUAddress(); + Args.ResDescHeap = + IS.DescHeap ? IS.DescHeap->getGPUDescriptorHandleForHeapStart().Ptr : 0; + Args.SmpDescHeap = 0; + Args.VisibleFunctionTable = + RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTable = + RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTables = 0; + + MTL::Buffer *ArgsBuf = Device->newBuffer( + &Args, sizeof(IRDispatchRaysArgument), MTL::ResourceStorageModeShared); + if (!ArgsBuf) + return llvm::createStringError( + std::errc::not_enough_memory, + "Failed to allocate IRDispatchRaysArgument buffer."); + IS.CB->KeepAliveMTLBuffers.push_back(ArgsBuf); + NativeEncoder->setBuffer(ArgsBuf, 0, kIRRayDispatchArgumentsBindPoint); + NativeEncoder->useResource(ArgsBuf, MTL::ResourceUsageRead); + + // Mark every dispatch-side resource resident: descriptor-table bundles, + // acceleration structures + their irconverter header/contribution + // buffers (so RayQuery/TraceRay can read them), the SBT buffer (the + // raygen kernel dereferences SBT addresses), and the visible / + // intersection function tables. + for (const auto &Table : IS.DescTables) + for (const auto &ResPair : Table.Resources) + for (const auto &ResSet : ResPair.second) + NativeEncoder->useResource(ResSet.Resource.get(), + MTL::ResourceUsageRead | + MTL::ResourceUsageWrite); + for (auto &AS : IS.AccelStructs) { + auto *MTLAS = llvm::cast(AS.get()); + NativeEncoder->useResource(MTLAS->AccelStruct, MTL::ResourceUsageRead); + } + for (MTL::Buffer *B : IS.ASDescriptorBuffers) + NativeEncoder->useResource(B, MTL::ResourceUsageRead); + if (SBT.Buffer) + NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead); + if (RTPSO.VFT) + NativeEncoder->useResource(RTPSO.VFT, MTL::ResourceUsageRead); + if (RTPSO.IFT) + NativeEncoder->useResource(RTPSO.IFT, MTL::ResourceUsageRead); + + if (auto Err = + Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(), + P.DispatchParameters.DispatchGroupCount[0], + P.DispatchParameters.DispatchGroupCount[1], + P.DispatchParameters.DispatchGroupCount[2])) + return Err; + Encoder.endEncoding(); + return llvm::Error::success(); + } + llvm::Error createRenderTarget(Pipeline &P, InvocationState &IS) { if (!P.Bindings.RTargetBufferPtr) return llvm::createStringError( @@ -1770,17 +2055,283 @@ class MTLDevice : public offloadtest::Device { Queue &getGraphicsQueue() override { return GraphicsQueue; } llvm::Expected> - createPipelineRT(llvm::StringRef, const BindingsDesc &, - const RayTracingPipelineCreateDesc &) override { - return llvm::createStringError( - "RayTracing pipeline state not yet supported on Metal"); + createPipelineRT(llvm::StringRef Name, const BindingsDesc &BD, + const RayTracingPipelineCreateDesc &Desc) override { + if (!Device->supportsRaytracing()) + return llvm::createStringError( + std::errc::not_supported, + "Ray tracing is not supported on this Metal device."); + if (!Desc.Library) + return llvm::createStringError(std::errc::invalid_argument, + "RayTracingPipelineCreateDesc.Library is " + "null — backend needs a DXIL blob."); + + IRRootSignaturePtr RootSig; + std::unique_ptr ArgBuffer; + if (auto Err = + createRootSignature(BD, /*IsGraphics=*/false, RootSig, ArgBuffer)) + return Err; + + // Configure the irconverter ray tracing pipeline. Raygen is compiled as a + // kernel so it can be dispatched directly; miss / closest-hit / any-hit / + // intersection / callable shaders are compiled as visible functions and + // looked up via a MTLVisibleFunctionTable at runtime. + auto RTConfig = + std::unique_ptr>( + IRRayTracingPipelineConfigurationCreate()); + if (!RTConfig) + return llvm::createStringError( + std::errc::not_supported, + "Failed to create IRRayTracingPipelineConfiguration."); + IRRayTracingPipelineConfigurationSetMaxAttributeSizeInBytes( + RTConfig.get(), Desc.Config.MaxAttributeSizeInBytes); + IRRayTracingPipelineConfigurationSetMaxRecursiveDepth( + RTConfig.get(), static_cast(Desc.Config.MaxTraceRecursionDepth)); + IRRayTracingPipelineConfigurationSetRayGenerationCompilationMode( + RTConfig.get(), IRRayGenerationCompilationKernel); + IRRayTracingPipelineConfigurationSetIntersectionFunctionCompilationMode( + RTConfig.get(), IRIntersectionFunctionCompilationVisibleFunction); + + auto State = std::make_unique( + Name, std::move(RootSig), std::move(ArgBuffer)); + + // Compile each entry point. Raygen lands in `RaygenFn` (becomes the + // compute function); everything else gets linked in via LinkedFunctions + // and indexed by visible-function-table slot. + MTLPtr RaygenFn; + std::string RaygenEntry; + llvm::SmallVector VisibleFns; + llvm::StringMap EntryToVFTIndex; + MTL::Size RaygenThreadsPerGroup(1, 1, 1); + for (const auto &Sh : Desc.Shaders) { + auto IROrErr = convertRTShaderToMetalIR(Sh.Stage, State->RootSig.get(), + RTConfig.get(), Sh.EntryPoint, + *Desc.Library); + if (!IROrErr) + return IROrErr.takeError(); + + dispatch_data_t Data = IRMetalLibGetBytecodeData(IROrErr->Binary.get()); + NS::Error *Error = nullptr; + MTL::Library *Lib = Device->newLibrary(Data, &Error); + if (Error) + return toError(Error); + State->Libraries.push_back(Lib); + + MTL::Function *Fn = Lib->newFunction( + NS::String::string(Sh.EntryPoint.c_str(), NS::UTF8StringEncoding)); + if (!Fn) + return llvm::createStringError( + std::errc::invalid_argument, + "Failed to find RT entry point '%s' in compiled Metal library.", + Sh.EntryPoint.c_str()); + + if (Sh.Stage == Stages::RayGeneration) { + RaygenFn.reset(Fn); + RaygenEntry = Sh.EntryPoint; + IRVersionedCSInfo Info; + if (IRShaderReflectionCopyComputeInfo(IROrErr->Reflection.get(), + IRReflectionVersion_1_0, &Info)) { + RaygenThreadsPerGroup = + MTL::Size(Info.info_1_0.tg_size[0], Info.info_1_0.tg_size[1], + Info.info_1_0.tg_size[2]); + IRShaderReflectionReleaseComputeInfo(&Info); + } + } else { + const uint32_t Slot = static_cast(VisibleFns.size()); + VisibleFns.push_back(Fn); + EntryToVFTIndex[Sh.EntryPoint] = Slot; + State->Functions.push_back(Fn); + } + } + if (!RaygenFn) + return llvm::createStringError( + std::errc::invalid_argument, + "RayTracing pipeline requires at least one RayGeneration shader."); + + // Pre-build IRShaderIdentifier records for every name the SBT can + // reference. Raygen records carry no shader handle (the kernel is + // dispatched directly); miss / closest-hit / callable carry their + // visible-function-table index; hit groups reuse the closest-hit + // index since this PR1 bring-up only supports HitGroupType::Triangles + // without AnyHit/Intersection. + IRShaderIdentifier RaygenIdent{}; + IRShaderIdentifierInit(&RaygenIdent, /*shaderHandle=*/0); + State->ShaderIdentifiers[RaygenEntry] = RaygenIdent; + + for (const auto &Sh : Desc.Shaders) { + if (Sh.Stage == Stages::RayGeneration) + continue; + auto It = EntryToVFTIndex.find(Sh.EntryPoint); + assert(It != EntryToVFTIndex.end() && "missing visible-function index"); + IRShaderIdentifier Ident{}; + IRShaderIdentifierInit(&Ident, It->second); + State->ShaderIdentifiers[Sh.EntryPoint] = Ident; + } + + for (const auto &HG : Desc.HitGroups) { + if (HG.AnyHit || HG.Intersection) + return llvm::createStringError( + std::errc::not_supported, + "Metal RT bring-up only supports Triangle hit groups with a " + "ClosestHit shader; AnyHit/Intersection support is not " + "implemented yet."); + auto It = EntryToVFTIndex.find(HG.ClosestHit); + if (It == EntryToVFTIndex.end()) + return llvm::createStringError( + std::errc::invalid_argument, + "Hit group '%s' references unknown ClosestHit shader '%s'.", + HG.Name.c_str(), HG.ClosestHit.c_str()); + IRShaderIdentifier Ident{}; + IRShaderIdentifierInit(&Ident, It->second); + State->ShaderIdentifiers[HG.Name] = Ident; + } + + // Pipeline descriptor: raygen as the compute function, everything else as + // linked functions reachable from the visible function table. + MTLPtr Desc2( + MTL::ComputePipelineDescriptor::alloc()->init()); + Desc2->setComputeFunction(RaygenFn.get()); + Desc2->setLabel( + NS::String::string(std::string(Name).c_str(), NS::UTF8StringEncoding)); + // setMaxCallStackDepth defaults to 1, allowing the raygen kernel exactly + // one level of visible-function call. Nested TraceRay (depth ≥ 2) needs + // the call stack to nest at least that deep — match what the YAML + // RTConfig declares so recursive RT pipelines work. + Desc2->setMaxCallStackDepth(Desc.Config.MaxTraceRecursionDepth); + if (!VisibleFns.empty()) { + MTLPtr Linked( + MTL::LinkedFunctions::alloc()->init()); + NS::Array *FnArr = NS::Array::array( + reinterpret_cast(VisibleFns.data()), + VisibleFns.size()); + Linked->setFunctions(FnArr); + Desc2->setLinkedFunctions(Linked.get()); + } + + NS::Error *Error = nullptr; + MTL::ComputePipelineState *PSO = Device->newComputePipelineState( + Desc2.get(), MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error); + if (Error) + return toError(Error); + + // Populate the visible function table from function handles obtained on + // the freshly-created pipeline. + if (!VisibleFns.empty()) { + MTLPtr VFTDesc( + MTL::VisibleFunctionTableDescriptor::alloc()->init()); + VFTDesc->setFunctionCount(VisibleFns.size()); + MTL::VisibleFunctionTable *VFT = + PSO->newVisibleFunctionTable(VFTDesc.get()); + if (!VFT) { + PSO->release(); + return llvm::createStringError( + std::errc::device_or_resource_busy, + "Failed to create MTL::VisibleFunctionTable for RT pipeline."); + } + for (uint32_t I = 0; I < VisibleFns.size(); ++I) { + MTL::FunctionHandle *H = PSO->functionHandle(VisibleFns[I]); + if (!H) { + VFT->release(); + PSO->release(); + return llvm::createStringError( + std::errc::not_supported, + "Pipeline has no FunctionHandle for linked function index %u.", + I); + } + VFT->setFunction(H, I); + } + State->VFT = VFT; + } + + State->ComputePipeline = PSO; + State->ThreadsPerGroup = RaygenThreadsPerGroup; + return State; } llvm::Expected> - createShaderBindingTable(const PipelineState &, - const ShaderBindingTableDesc &) override { - return llvm::createStringError( - "RayTracing shader binding table not yet supported on Metal"); + createShaderBindingTable(const PipelineState &PSO, + const ShaderBindingTableDesc &Desc) override { + if (!llvm::isa(&PSO)) + return llvm::createStringError( + std::errc::invalid_argument, + "createShaderBindingTable requires a RayTracing PipelineState."); + const auto &RTPSO = llvm::cast(PSO); + + // Layout: four concatenated regions of IRShaderIdentifier-sized records. + // computeSBTLayout aligns record stride/region size to the values we + // pass. Metal does not expose explicit record/table alignment knobs the + // way D3D12 does — pick natural alignment (16 bytes) so the irconverter + // runtime's pointer reads stay aligned. + constexpr uint32_t IdSize = + static_cast(sizeof(IRShaderIdentifier)); + constexpr uint32_t RecordAlign = 16; + constexpr uint32_t BaseAlign = 16; + const SBTLayout Layout = + computeSBTLayout(IdSize, RecordAlign, BaseAlign, Desc); + const uint32_t TotalSize = Layout.TotalSize; + const llvm::ArrayRef RGEntries(&Desc.RayGen, 1); + + MTL::Buffer *Buffer = + Device->newBuffer(TotalSize, MTL::ResourceStorageModeShared); + if (!Buffer) + return llvm::createStringError(std::errc::not_enough_memory, + "Failed to allocate Metal SBT buffer."); + auto *Mapped = static_cast(Buffer->contents()); + std::memset(Mapped, 0, TotalSize); + + auto WriteEntries = [&](uint8_t *Region, llvm::ArrayRef Entries, + uint32_t Stride) -> llvm::Error { + for (size_t I = 0; I < Entries.size(); ++I) { + const auto &E = Entries[I]; + auto It = RTPSO.ShaderIdentifiers.find(E.ShaderName); + if (It == RTPSO.ShaderIdentifiers.end()) { + Buffer->release(); + return llvm::createStringError( + std::errc::invalid_argument, + "SBT references unknown shader/hit-group name: '%s'", + E.ShaderName.c_str()); + } + uint8_t *Dst = Region + I * Stride; + std::memcpy(Dst, &It->second, sizeof(IRShaderIdentifier)); + if (!E.LocalRootData.empty()) + std::memcpy(Dst + sizeof(IRShaderIdentifier), E.LocalRootData.data(), + E.LocalRootData.size()); + } + return llvm::Error::success(); + }; + + if (auto Err = WriteEntries(Mapped + Layout.RayGen.Offset, RGEntries, + Layout.RayGen.Stride)) + return Err; + if (auto Err = WriteEntries(Mapped + Layout.Miss.Offset, Desc.Miss, + Layout.Miss.Stride)) + return Err; + if (auto Err = WriteEntries(Mapped + Layout.HitGroup.Offset, Desc.HitGroup, + Layout.HitGroup.Stride)) + return Err; + if (auto Err = WriteEntries(Mapped + Layout.Callable.Offset, Desc.Callable, + Layout.Callable.Stride)) + return Err; + + const uint64_t Base = Buffer->gpuAddress(); + auto MakeRange = [&](const SBTRegionLayout &R) { + IRVirtualAddressRange V{}; + V.StartAddress = R.Size ? Base + R.Offset : 0; + V.SizeInBytes = R.Size; + return V; + }; + auto MakeRangeAndStride = [&](const SBTRegionLayout &R) { + IRVirtualAddressRangeAndStride V{}; + V.StartAddress = R.Size ? Base + R.Offset : 0; + V.SizeInBytes = R.Size; + V.StrideInBytes = R.Stride; + return V; + }; + return std::make_unique( + Buffer, MakeRange(Layout.RayGen), MakeRangeAndStride(Layout.Miss), + MakeRangeAndStride(Layout.HitGroup), + MakeRangeAndStride(Layout.Callable)); } llvm::Expected> @@ -2503,8 +3054,37 @@ class MTLDevice : public offloadtest::Device { if (auto Err = createGraphicsCommands(P, IS)) return Err; } else if (P.isRayTracing()) { - return llvm::createStringError( - "RayTracing pipeline not yet supported on Metal"); + if (P.Shaders.empty() || !P.SBT || !P.RTConfig) + return llvm::createStringError( + std::errc::invalid_argument, + "RayTracing pipeline requires Shaders, " + "ShaderBindingTable, and RayTracingPipelineConfig."); + + RayTracingPipelineCreateDesc RTDesc{}; + // All RT shader entries share the single DXIL library blob fanned out + // by the offloader CLI for RT pipelines. + RTDesc.Library = P.Shaders.front().Shader.get(); + RTDesc.HitGroups = P.HitGroups; + RTDesc.Config = *P.RTConfig; + RTDesc.Shaders.reserve(P.Shaders.size()); + for (const auto &Sh : P.Shaders) + RTDesc.Shaders.push_back({Sh.Stage, Sh.Entry}); + + auto PSOOrErr = + createPipelineRT("RayTracing Pipeline State", Bindings, RTDesc); + if (!PSOOrErr) + return PSOOrErr.takeError(); + IS.Pipeline = std::move(*PSOOrErr); + llvm::outs() << "RayTracing Pipeline created.\n"; + + auto SBTOrErr = createShaderBindingTable(*IS.Pipeline, *P.SBT); + if (!SBTOrErr) + return SBTOrErr.takeError(); + IS.SBT = std::move(*SBTOrErr); + llvm::outs() << "Shader Binding Table created.\n"; + + if (auto Err = createRayTracingCommands(P, IS)) + return Err; } auto SubmitResult = GraphicsQueue.submit(std::move(IS.CB)); diff --git a/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp b/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp index cd240d433..063c55053 100644 --- a/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp +++ b/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp @@ -188,3 +188,7 @@ void MTLTopLevelArgumentBuffer::bind( Encoder->useResource(Buffer, MTL::ResourceUsageRead); Encoder->setBuffer(Buffer, 0, kIRArgumentBufferBindPoint); } + +uint64_t MTLTopLevelArgumentBuffer::getGPUAddress() const { + return Buffer ? Buffer->gpuAddress() : 0; +} diff --git a/lib/API/MTL/MTLTopLevelArgumentBuffer.h b/lib/API/MTL/MTLTopLevelArgumentBuffer.h index 42bc1abf0..50c878603 100644 --- a/lib/API/MTL/MTLTopLevelArgumentBuffer.h +++ b/lib/API/MTL/MTLTopLevelArgumentBuffer.h @@ -64,6 +64,12 @@ class MTLTopLevelArgumentBuffer { void bind(MTL::RenderCommandEncoder *Encoder) const; // Bind the argument buffer to the compute command encoder. void bind(MTL::ComputeCommandEncoder *Encoder) const; + + // GPU address of the underlying buffer. Returns 0 when the root signature is + // empty and no buffer was allocated. Needed by the ray-tracing path so the + // synthesized IRDispatchRaysArgument can point shader-record callees back at + // the GRS (global root signature top-level argument buffer). + uint64_t getGPUAddress() const; }; } // namespace offloadtest diff --git a/test/Feature/RT/raygen-roundtrip.test b/test/Feature/RT/raygen-roundtrip.test index 6baff2921..86eb01469 100644 --- a/test/Feature/RT/raygen-roundtrip.test +++ b/test/Feature/RT/raygen-roundtrip.test @@ -106,7 +106,7 @@ Results: # REQUIRES: raytracing-pipeline # Unimplemented https://github.com/llvm/offload-test-suite/issues/1268 -# XFAIL: Clang, Metal +# XFAIL: Clang # RUN: split-file %s %t # RUN: %dxc_target_lib -T lib_6_5 -Fo %t.o %t/source.hlsl diff --git a/test/lit.cfg.py b/test/lit.cfg.py index d23d351ae..a87f3e9b0 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -197,6 +197,10 @@ def setDeviceFeatures(config, device, compiler): config.available_features.add("MeshShader") if device["Features"].get("supportsRaytracing", False): config.available_features.add("acceleration-structure") + # The Metal RT pipeline path lowers DXIL → Metal IR via + # metal_irconverter; gate on the same Raytracing-tier capability + # as inline RT, matching the DX / VK plumbing-vs-impl split. + config.available_features.add("raytracing-pipeline") if device["API"] == "Vulkan": if device["Features"].get("shaderInt16", False): From b832a78080c8ee6daba75dda7cc934f27a2841ad Mon Sep 17 00:00:00 2001 From: EmilioLaiso Date: Tue, 23 Jun 2026 17:30:24 +0200 Subject: [PATCH 2/3] fix ci --- lib/API/MTL/MTLDevice.cpp | 41 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 313b4ff5b..2607440d6 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -1832,15 +1832,19 @@ class MTLDevice : public offloadtest::Device { RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; Args.IntersectionFunctionTables = 0; - MTL::Buffer *ArgsBuf = Device->newBuffer( - &Args, sizeof(IRDispatchRaysArgument), MTL::ResourceStorageModeShared); - if (!ArgsBuf) - return llvm::createStringError( - std::errc::not_enough_memory, - "Failed to allocate IRDispatchRaysArgument buffer."); - IS.CB->KeepAliveMTLBuffers.push_back(ArgsBuf); - NativeEncoder->setBuffer(ArgsBuf, 0, kIRRayDispatchArgumentsBindPoint); - NativeEncoder->useResource(ArgsBuf, MTL::ResourceUsageRead); + const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); + auto ArgsBufOrErr = offloadtest::createBufferWithData( + *CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, + sizeof(IRDispatchRaysArgument), nullptr, nullptr); + if (!ArgsBufOrErr) + return ArgsBufOrErr.takeError(); + + auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); + CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); + + NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0, + kIRRayDispatchArgumentsBindPoint); + NativeEncoder->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead); // Mark every dispatch-side resource resident: descriptor-table bundles, // acceleration structures + their irconverter header/contribution @@ -1853,12 +1857,19 @@ class MTLDevice : public offloadtest::Device { NativeEncoder->useResource(ResSet.Resource.get(), MTL::ResourceUsageRead | MTL::ResourceUsageWrite); - for (auto &AS : IS.AccelStructs) { - auto *MTLAS = llvm::cast(AS.get()); - NativeEncoder->useResource(MTLAS->AccelStruct, MTL::ResourceUsageRead); - } - for (MTL::Buffer *B : IS.ASDescriptorBuffers) - NativeEncoder->useResource(B, MTL::ResourceUsageRead); + auto MarkASResident = + [&](std::unique_ptr &AS) { + auto *MTLAS = llvm::cast(AS.get()); + NativeEncoder->useResource(MTLAS->AccelStruct, + MTL::ResourceUsageRead); + }; + for (auto &AS : IS.BLASes) + MarkASResident(AS); + for (auto &Entry : IS.TLASes) + MarkASResident(Entry.second); + for (auto &B : IS.ASDescriptorBuffers) + NativeEncoder->useResource(llvm::cast(B.get())->Buf, + MTL::ResourceUsageRead); if (SBT.Buffer) NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead); if (RTPSO.VFT) From 48fdc902c091c2cbc48786d1926323ee3a4149f7 Mon Sep 17 00:00:00 2001 From: EmilioLaiso Date: Tue, 23 Jun 2026 17:49:20 +0200 Subject: [PATCH 3/3] fix ci --- lib/API/MTL/MTLDevice.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 2607440d6..0b3b35509 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -1834,13 +1834,13 @@ class MTLDevice : public offloadtest::Device { const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); auto ArgsBufOrErr = offloadtest::createBufferWithData( - *CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, + *IS.CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, sizeof(IRDispatchRaysArgument), nullptr, nullptr); if (!ArgsBufOrErr) return ArgsBufOrErr.takeError(); auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); - CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); + IS.CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0, kIRRayDispatchArgumentsBindPoint);