diff --git a/include/slang.h b/include/slang.h index 9755415b3f..777cd406b3 100644 --- a/include/slang.h +++ b/include/slang.h @@ -603,6 +603,7 @@ extern "C" SLANG_METAL_LIB, ///< Metal library SLANG_METAL_LIB_ASM, ///< Metal library assembly SLANG_HOST_SHARED_LIBRARY, ///< A shared library/Dll for host code (for hosting CPU/OS) + SLANG_WGSL, ///< WebGPU shading language SLANG_TARGET_COUNT_OF, }; @@ -735,6 +736,7 @@ extern "C" SLANG_SOURCE_LANGUAGE_CUDA, SLANG_SOURCE_LANGUAGE_SPIRV, SLANG_SOURCE_LANGUAGE_METAL, + SLANG_SOURCE_LANGUAGE_WGSL, SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index a4190992cf..9794cc90e9 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -197,6 +197,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactKind, SLANG_ARTIFACT_KIND, SLANG_ARTIFACT_KIND_E x(CUDA, Source) \ x(Metal, Source) \ x(Slang, Source) \ + x(WGSL, Source) \ x(KernelLike, Base) \ x(DXIL, KernelLike) \ x(DXBC, KernelLike) \ @@ -288,6 +289,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0); case SLANG_METAL_LIB: return Desc::make(Kind::Executable, Payload::MetalAIR, Style::Kernel, 0); case SLANG_METAL_LIB_ASM: return Desc::make(Kind::Assembly, Payload::MetalAIR, Style::Kernel, 0); + case SLANG_WGSL: return Desc::make(Kind::Source, Payload::WGSL, Style::Kernel, 0); default: break; } @@ -330,6 +332,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case Payload::Cpp: return (desc.style == Style::Host) ? SLANG_HOST_CPP_SOURCE : SLANG_CPP_SOURCE; case Payload::CUDA: return SLANG_CUDA_SOURCE; case Payload::Metal: return SLANG_METAL; + case Payload::WGSL: return SLANG_WGSL; default: break; } break; diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h index 400c85b2eb..6d65aafba4 100644 --- a/source/compiler-core/slang-artifact.h +++ b/source/compiler-core/slang-artifact.h @@ -143,6 +143,7 @@ enum class ArtifactPayload : uint8_t CUDA, ///< CUDA source Metal, ///< Metal source Slang, ///< Slang source + WGSL, ///< WGSL source KernelLike, ///< GPU Kernel like diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 9fa91abf6c..9f9deb92c8 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -63,6 +63,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = { SLANG_METAL, "metal", "metal", "Metal shader source" }, { SLANG_METAL_LIB, "metallib", "metallib", "Metal Library Bytecode" }, { SLANG_METAL_LIB_ASM, "metallib-asm" "metallib-asm", "Metal Library Bytecode assembly" }, + { SLANG_WGSL, "wgsl", "wgsl", "WebGPU shading language source" }, }; static const NamesDescriptionValue s_languageInfos[] = diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7a79525552..7226edc04c 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ b/source/slang-record-replay/util/emum-to-string.h @@ -34,6 +34,7 @@ namespace SlangRecord CASE(SLANG_METAL_LIB); CASE(SLANG_METAL_LIB_ASM); CASE(SLANG_HOST_SHARED_LIBRARY); + CASE(SLANG_WGSL); CASE(SLANG_TARGET_COUNT_OF); default: Slang::StringBuilder str; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 10a6254c15..a8241bf736 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5668,7 +5668,7 @@ vector acosh(vector x) // Test if all components are non-zero (HLSL SM 1.0) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(T x) { __target_switch @@ -5679,6 +5679,8 @@ bool all(T x) __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; + case wgsl: + __intrinsic_asm "all"; case spirv: let zero = __default(); if (__isInt()) @@ -5806,7 +5808,7 @@ int3 WorkgroupSize(); __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(T x) { __target_switch @@ -5817,6 +5819,8 @@ bool any(T x) __intrinsic_asm "any"; case metal: __intrinsic_asm "any"; + case wgsl: + __intrinsic_asm "any"; case spirv: let zero = __default(); if (__isInt()) @@ -6142,7 +6146,7 @@ vector asinh(vector x) // Reinterpret bits as an int (HLSL SM 4.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] int asint(float x) { __target_switch @@ -6152,6 +6156,7 @@ int asint(float x) case glsl: __intrinsic_asm "floatBitsToInt"; case hlsl: __intrinsic_asm "asint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$int result $x }; @@ -6285,7 +6290,7 @@ void asuint(double value, out uint lowbits, out uint highbits) // Reinterpret bits as a uint (HLSL SM 4.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] uint asuint(float x) { __target_switch @@ -6295,6 +6300,7 @@ uint asuint(float x) case glsl: __intrinsic_asm "floatBitsToUint"; case hlsl: __intrinsic_asm "asuint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$uint result $x }; @@ -7025,7 +7031,7 @@ void clip(matrix x) // Cosine __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T cos(T x) { __target_switch @@ -7035,6 +7041,7 @@ T cos(T x) case glsl: __intrinsic_asm "cos"; case hlsl: __intrinsic_asm "cos"; case metal: __intrinsic_asm "cos"; + case wgsl: __intrinsic_asm "cos"; case spirv: return spirv_asm { OpExtInst $$T result glsl450 Cos $x }; @@ -10427,7 +10434,7 @@ matrix mad(matrix mvalue, matrix avalue, matrix [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T max(T x, T y) { // Note: a stdlib implementation of `max` (or `min`) will require splitting @@ -10440,6 +10447,7 @@ T max(T x, T y) case hlsl: __intrinsic_asm "max"; case glsl: __intrinsic_asm "max"; case metal: __intrinsic_asm "max"; + case wgsl: __intrinsic_asm "max"; case cuda: __intrinsic_asm "$P_max($0, $1)"; case cpp: __intrinsic_asm "$P_max($0, $1)"; case spirv: @@ -10656,7 +10664,7 @@ vector fmax3(vector x, vector y, vector z) // minimum __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T min(T x, T y) { __target_switch @@ -10664,6 +10672,7 @@ T min(T x, T y) case hlsl: case glsl: case metal: + case wgsl: __intrinsic_asm "min"; case cuda: case cpp: @@ -11103,13 +11112,14 @@ T mul(vector x, vector y) // vector-matrix __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector mul(vector left, matrix right) { __target_switch { case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; + case wgsl: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; case spirv: return spirv_asm { OpMatrixTimesVector $$vector result $right $left @@ -12166,7 +12176,7 @@ matrix sign(matrix x) __generic [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sin(T x) { __target_switch @@ -12176,6 +12186,7 @@ T sin(T x) case glsl: __intrinsic_asm "sin"; case hlsl: __intrinsic_asm "sin"; case metal: __intrinsic_asm "sin"; + case wgsl: __intrinsic_asm "sin"; case spirv: return spirv_asm { OpExtInst $$T result glsl450 Sin $x }; diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index a173a332f4..9e9b941513 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -111,6 +111,10 @@ def metal : target + textualTarget; /// [Target] def spirv : target; +/// Represents the WebGPU shading language code generation target. +/// [Target] +def wgsl : target + textualTarget; + // Capabilities that stand for target SPIR-V versions for the GLSL backend. // These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation. @@ -228,15 +232,15 @@ def _cuda_sm_9_0 : _cuda_sm_8_0; /// All code-gen targets /// [Compound] -alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv; +alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv | wgsl; /// All non-asm code-gen targets /// [Compound] -alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda; +alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda | wgsl; /// All slang-gfx compatible code-gen targets /// [Compound] -alias any_gfx_target = hlsl | metal | glsl | spirv; +alias any_gfx_target = hlsl | metal | glsl | spirv | wgsl; /// All "cpp syntax" code-gen targets /// [Compound] @@ -266,6 +270,10 @@ alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv; /// [Compound] alias cpp_cuda_glsl_hlsl_metal_spirv = cpp | cuda | glsl | hlsl | metal | spirv; +/// CPP, CUDA, GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias cpp_cuda_glsl_hlsl_metal_spirv_wgsl = cpp | cuda | glsl | hlsl | metal | spirv | wgsl; + /// CPP, CUDA, and HLSL code-gen targets /// [Compound] alias cpp_cuda_hlsl = cpp | cuda | hlsl; @@ -1178,6 +1186,7 @@ alias sm_4_0_version = _sm_4_0 | spirv_1_0 | _cuda_sm_2_0 | metal + | wgsl | cpp ; diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 4bb420fa7a..541085b4ee 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1715,6 +1715,7 @@ namespace Slang case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: { RefPtr extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b8ee4dc9cd..62e4c5f4a8 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -94,6 +94,7 @@ namespace Slang Metal = SLANG_METAL, MetalLib = SLANG_METAL_LIB, MetalLibAssembly = SLANG_METAL_LIB_ASM, + WGSL = SLANG_WGSL, CountOf = SLANG_TARGET_COUNT_OF, }; diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index c9dd6d9c8a..e32a738c7f 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -483,6 +483,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token { return Requirement{ CodeGenTarget::Metal, targetName }; } + else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::wgsl)) + { + return Requirement{ CodeGenTarget::WGSL, targetName }; + } return Requirement{ CodeGenTarget::Unknown, String() }; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 1893929f89..caf3613a71 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -95,6 +95,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } } } @@ -151,7 +155,7 @@ void CLikeSourceEmitter::ensureTypePrelude(IRType* type) } } -void CLikeSourceEmitter::emitDeclarator(DeclaratorInfo* declarator) +void CLikeSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) { if (!declarator) return; @@ -341,13 +345,18 @@ void CLikeSourceEmitter::_emitPostfixTypeAttr(IRAttr* attr) // we may need to handle it here. } +void CLikeSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + emitSimpleType(type); + emitDeclarator(declarator); +} + void CLikeSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) { switch (type->getOp()) { default: - emitSimpleType(type); - emitDeclarator(declarator); + emitSimpleTypeAndDeclarator(type, declarator); break; case kIROp_RateQualifiedType: @@ -648,7 +657,7 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo bool needParens = (prec.leftPrecedence <= outerPrec.leftPrecedence) || (prec.rightPrecedence <= outerPrec.rightPrecedence); - // While Slang correctly removes some of parentheses, DXC prints warnings + // While Slang correctly removes some of parentheses, many compilers print warnings // for common mistakes when parentheses are not used with certain combinations // of the operations. We emit parentheses to avoid the warnings. // @@ -676,6 +685,12 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo { needParens = true; } + // a + b & c => (a + b) & c + else if (prec.rightPrecedence == EPrecedence::kEPrecedence_Additive_Right + && outerPrec.rightPrecedence == EPrecedence::kEPrecedence_BitAnd_Left) + { + needParens = true; + } if (needParens) { @@ -1657,11 +1672,16 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } +bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */) +{ + return doesTargetSupportPtrTypes(); +} + void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec) { EmitOpInfo newOuterPrec = outerPrec; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { switch (inst->getOp()) { @@ -1760,7 +1780,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec) { - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); auto newOuterPrec = outerPrec; @@ -1842,7 +1862,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitRateQualifiers(inst); - if(as(inst->getParent())) + bool isConstant(as(inst->getParent())); + if(isConstant) { // "Ordinary" instructions at module scope are constants @@ -1857,6 +1878,9 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) case SourceLanguage::Metal: m_writer->emit("constant "); break; + case SourceLanguage::WGSL: + // This is handled by emitVarKeyword, below + break; default: m_writer->emit("const "); break; @@ -1864,6 +1888,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) } + emitVarKeyword(type, isConstant); + emitType(type, getName(inst)); m_writer->emit(" = "); } @@ -2297,7 +2323,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO IRFieldAddress* ii = (IRFieldAddress*) inst; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); needClose = maybeEmitParens(outerPrec, prec); @@ -3117,6 +3143,8 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store) void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type) { + emitVarKeyword(type, /* isConstant */ false); + emitType(type, getName(inst)); // On targets that support empty initializers, we will emit it. @@ -3178,6 +3206,20 @@ void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSe emitLayoutSemanticsImpl(inst, uniformSemanticSpelling, EmitLayoutSemanticOption::kPostType); } +void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(IRBasicType *const /* switchCondition */, const SwitchRegion::Case *const currentCase, const bool isDefault) +{ + for(auto caseVal : currentCase->values) + { + m_writer->emit("case "); + emitOperand(caseVal, getInfo(EmitOp::General)); + m_writer->emit(":\n"); + } + if(isDefault) + { + m_writer->emit("default:\n"); + } +} + void CLikeSourceEmitter::emitRegion(Region* inRegion) { // We will use a loop so that we can process sequential (simple) @@ -3333,17 +3375,9 @@ void CLikeSourceEmitter::emitRegion(Region* inRegion) auto defaultCase = switchRegion->defaultCase; for(auto currentCase : switchRegion->cases) { - for(auto caseVal : currentCase->values) - { - m_writer->emit("case "); - emitOperand(caseVal, getInfo(EmitOp::General)); - m_writer->emit(":\n"); - } - if(currentCase.Ptr() == defaultCase) - { - m_writer->emit("default:\n"); - } - + const bool isDefault {currentCase.Ptr() == defaultCase}; + IRBasicType *const switchConditionType {as(switchRegion->getCondition()->getDataType())}; + emitSwitchCaseSelectors(switchConditionType, currentCase.Ptr(), isDefault); m_writer->indent(); m_writer->emit("{\n"); m_writer->indent(); @@ -3449,9 +3483,16 @@ void CLikeSourceEmitter::emitSimpleFuncParamsImpl(IRFunc* func) m_writer->emit(")"); } -void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +void CLikeSourceEmitter::emitFuncHeaderImpl(IRFunc* func) { auto resultType = func->getResultType(); + auto name = getName(func); + emitType(resultType, name); + emitSimpleFuncParamsImpl(func); +} + +void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +{ // Deal with decorations that need // to be emitted as attributes @@ -3467,12 +3508,8 @@ void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) emitFunctionPreambleImpl(func); - auto name = getName(func); - emitFuncDecorations(func); - - emitType(resultType, name); - emitSimpleFuncParamsImpl(func); + emitFuncHeader(func); emitSemantics(func); // TODO: encode declaration vs. definition @@ -3688,6 +3725,11 @@ void CLikeSourceEmitter::emitStruct(IRStructType* structType) m_writer->emit(";\n\n"); } +void CLikeSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(";"); +} + void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout) { m_writer->emit("\n{\n"); @@ -3716,11 +3758,13 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, b emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration); } } + emitStructFieldAttributes(structType, ff); emitMemoryQualifiers(fieldKey); emitType(fieldType, getName(fieldKey)); emitSemantics(fieldKey, allowOffsetLayout); emitPostDeclarationAttributesForType(fieldType); - m_writer->emit(";\n"); + emitStructDeclarationSeparator(); + m_writer->emit("\n"); } m_writer->dedent(); @@ -3931,6 +3975,8 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar emitParameterGroupImpl(varDecl, type); } +void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {} + void CLikeSourceEmitter::emitVar(IRVar* varDecl) { auto allocatedType = varDecl->getDataType(); @@ -3969,6 +4015,8 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) #endif emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ false); + emitType(varType, getName(varDecl)); emitSemantics(varDecl); @@ -4099,6 +4147,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl) emitVarModifiers(layout, varDecl, varType); emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ true); emitType(varType, getName(varDecl)); // TODO: These shouldn't be needed for ordinary @@ -4172,7 +4221,8 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitDecorationLayoutSemantics(varDecl, "register"); emitRateQualifiersAndAddressSpace(varDecl); - emitType(varType, getName(varDecl)); + emitVarKeyword(varType, /* isConstant */ false); + emitGlobalParamType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 00ad156d1d..be769f31f9 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -248,7 +248,8 @@ class CLikeSourceEmitter: public SourceEmitterBase // void ensureTypePrelude(IRType* type); - void emitDeclarator(DeclaratorInfo* declarator); + void emitDeclarator(DeclaratorInfo* declarator) {emitDeclaratorImpl(declarator);} + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator); void emitType(IRType* type, const StringSliceLoc* nameLoc) { emitTypeImpl(type, nameLoc); } void emitType(IRType* type, Name* name); @@ -256,6 +257,7 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitType(IRType* type); void emitType(IRType* type, Name* name, SourceLoc const& nameLoc); void emitType(IRType* type, NameLoc const& nameAndLoc); + virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);} bool hasExplicitConstantBufferOffset(IRInst* cbufferType); bool isSingleElementConstantBuffer(IRInst* cbufferType); bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType); @@ -368,8 +370,11 @@ class CLikeSourceEmitter: public SourceEmitterBase /// Emit high-level statements for the body of a function. void emitFunctionBody(IRGlobalValueWithCode* code); + void emitFuncHeader(IRFunc* func) { emitFuncHeaderImpl(func); } void emitSimpleFunc(IRFunc* func) { emitSimpleFuncImpl(func); } + void emitSwitchCaseSelectors(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault) {emitSwitchCaseSelectorsImpl(switchConditionType, currentCase, isDefault);} + void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); } void emitFuncDecl(IRFunc* func); @@ -394,10 +399,14 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout); void emitClass(IRClassType* structType); + void emitStructDeclarationSeparator() {emitStructDeclarationSeparatorImpl();} + virtual void emitStructDeclarationSeparatorImpl(); + /// Emit type attributes that should appear after, e.g., a `struct` keyword void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {}; + virtual void emitStructFieldAttributes(IRStructType * /* structType */, IRStructField * /* field */) {}; void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {}; @@ -421,6 +430,7 @@ class CLikeSourceEmitter: public SourceEmitterBase void emitGlobalInst(IRInst* inst); virtual void emitGlobalInstImpl(IRInst* inst); + virtual bool isPointerSyntaxRequiredImpl(IRInst* inst); void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition); @@ -486,6 +496,11 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitPreModuleImpl(); virtual void emitPostModuleImpl(); + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator); + void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);}; + virtual void emitVarKeywordImpl(IRType * type, bool isConstant); + void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);} + virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } @@ -501,6 +516,7 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameLoc); virtual void emitSimpleValueImpl(IRInst* inst); virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink); + virtual void emitFuncHeaderImpl(IRFunc* func); virtual void emitSimpleFuncImpl(IRFunc* func); virtual void emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec); virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec); @@ -511,6 +527,7 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); } virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); } virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) { SLANG_UNUSED(switchInst); } + virtual void emitSwitchCaseSelectorsImpl(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault); virtual void emitFuncDecorationImpl(IRDecoration* decoration) { SLANG_UNUSED(decoration); } virtual void emitLivenessImpl(IRInst* inst); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp new file mode 100644 index 0000000000..0a4cca407a --- /dev/null +++ b/source/slang/slang-emit-wgsl.cpp @@ -0,0 +1,1005 @@ +#include "slang-emit-wgsl.h" + +// A note on row/column "terminology reversal". +// +// This is an "terminology reversing" implementation in the sense that +// * "column" in Slang code maps to "row" in the generated WGSL code, and +// * "row" in Slang code maps to "column" in the generated WGSL code. +// +// This means that matrices in Slang code end up getting translated to +// matrices that actually represent the transpose of what the Slang matrix +// represented. +// Both API's adopt the standard matrix multiplication convention whereby the +// column count of the matrix on the left hand side needs to match row count of +// the matrix on the right hand side. +// For these reasons, and due to the fact that (M_1 ... M_n)^T = M_n^T ... M_1^T, +// the order of matrix (and vector-matrix products) products must also reversed +// in the WGSL code. +// +// This may lead to confusion (which is why this note is referenced in several +// places), but the benefit of doing this is that the generated WGSL code is +// simpler to generate and should be faster to compile. +// A "terminology preserving" implementation would have to generate lots of +// 'transpose' calls, or else perform more complicated transformations that +// end up duplicating expressions many times. + +namespace Slang { + +void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( + IRBasicType *const switchConditionType, + const SwitchRegion::Case *const currentCase, const bool isDefault + ) +{ + // WGSL has special syntax for blocks sharing case labels: + // "case 2, 3, 4: ...;" instead of the C-like syntax + // "case 2: case 3: case 4: ...;". + + m_writer->emit("case "); + for (auto caseVal : currentCase->values) + { + // TODO: Fix this in the front-end [1], remove the if-path and just do the else-path. + // We can't do that at the moment because it would break Falcor [2]. + // [1] https://github.com/shader-slang/slang/pull/5025/commits/a32156ef52f43b8503b2c77f2f1d51220ab9bdea + // [2] https://github.com/shader-slang/slang/pull/5025#issuecomment-2334495120 + if (caseVal->getOp() == kIROp_IntLit) + { + auto caseLitInst = static_cast(caseVal); + IRBasicType *const caseInstType = as(caseLitInst->getDataType()); + // WGSL doesn't allow switch condition and case type mismatches, see [1]. + // Thus we need to insert explicit conversions. + // Doing a wrapping cast will match Slang's de facto semantics, according to + // [2]. + // (This is just a bitcast, assuming a two's complement representation.) + // [1] https://www.w3.org/TR/WGSL/#switch-statement + // [2] https://github.com/shader-slang/slang/issues/4921 + const bool needBitcast = + caseInstType->getBaseType() != switchConditionType->getBaseType(); + if (needBitcast) + { + m_writer->emit("bitcast<"); + emitType(switchConditionType); + m_writer->emit(">("); + } + emitOperand(caseVal, getInfo(EmitOp::General)); + if (needBitcast) + { + m_writer->emit(")"); + } + } + else + { + emitOperand(caseVal, getInfo(EmitOp::General)); + } + m_writer->emit(", "); + } + if (isDefault) + { + m_writer->emit("default, "); + } + m_writer->emit(":\n"); +} + +void WGSLSourceEmitter::emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type +) +{ + auto varLayout = getVarLayout(varDecl); + SLANG_RELEASE_ASSERT(varLayout); + + for (auto attr : varLayout->getOffsetAttrs()) + { + + const LayoutResourceKind kind = attr->getResourceKind(); + switch (kind) + { + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + m_writer->emit("@location("); + m_writer->emit(attr->getOffset()); + m_writer->emit(")"); + if (attr->getSpace()) + { + // TODO: Not sure what 'space' should map to in WGSL + SLANG_ASSERT(false); + } + break; + + case LayoutResourceKind::SpecializationConstant: + // TODO: + // Consider moving to a differently named function. + // This is not technically an attribute, but a declaration. + // + // https://www.w3.org/TR/WGSL/#override-decls + m_writer->emit("override"); + break; + + case LayoutResourceKind::Uniform: + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + break; + + } + + } + + auto elementType = type->getElementType(); + m_writer->emit("var "); + m_writer->emit(getName(varDecl)); + m_writer->emit(" : "); + emitType(elementType); + m_writer->emit(";\n"); +} + +void WGSLSourceEmitter::emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) +{ + auto stage = entryPointDecor->getProfile().getStage(); + + switch (stage) + { + + case Stage::Fragment: + m_writer->emit("@fragment\n"); + break; + case Stage::Vertex: + m_writer->emit("@vertex\n"); + break; + + case Stage::Compute: + { + m_writer->emit("@compute\n"); + + { + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis); + + m_writer->emit("@workgroup_size("); + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) + { + if (ii != 0) + m_writer->emit(", "); + m_writer->emit(sizeAlongAxis[ii]); + } + m_writer->emit(")\n"); + } + } + break; + + default: + SLANG_ABORT_COMPILATION("unsupported stage."); + } + +} + +// This is 'function_header' from the WGSL specification +void WGSLSourceEmitter::emitFuncHeaderImpl(IRFunc* func) +{ + Slang::IRType * resultType = func->getResultType(); + auto name = getName(func); + + m_writer->emit("fn "); + m_writer->emit(name); + + emitSimpleFuncParamsImpl(func); + + // An absence of return type is expressed by skipping the optional '->' part of the + // header. + if (resultType->getOp() != kIROp_VoidType) + { + m_writer->emit(" -> "); + emitType(resultType); + } +} + +void WGSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) +{ + if (auto sysSemanticDecor = param->findDecoration()) + { + m_writer->emit("@builtin("); + m_writer->emit(sysSemanticDecor->getSemantic()); + m_writer->emit(")"); + } + + CLikeSourceEmitter::emitSimpleFuncParamImpl(param); +} + +void WGSLSourceEmitter::emitMatrixType( + IRType *const elementType, const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ) +{ + // WGSL uses CxR convention + m_writer->emit("mat"); + m_writer->emit(colCountWGSL); + m_writer->emit("x"); + m_writer->emit(rowCountWGSL); + m_writer->emit("<"); + emitType(elementType); + m_writer->emit(">"); +} + +void WGSLSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(","); +} + +static bool isPowerOf2(const uint32_t n) +{ + return (n != 0U) && ((n - 1U) & n) == 0U; +} + +void WGSLSourceEmitter::emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) +{ + // Tint emits errors unless we explicitly spell out the layout in some cases, so emit + // offset and align attribtues for all fields. + IRSizeAndAlignmentDecoration *const sizeAndAlignmentDecoration = + structType->findDecoration(); + // NullDifferential struct doesn't have size and alignment decoration + if (sizeAndAlignmentDecoration == nullptr) + return; + SLANG_ASSERT(sizeAndAlignmentDecoration->getAlignment() > IRIntegerValue{0}); + SLANG_ASSERT( + sizeAndAlignmentDecoration->getAlignment() <= IRIntegerValue{UINT32_MAX} + ); + const uint32_t structAlignment = + static_cast(sizeAndAlignmentDecoration->getAlignment()); + IROffsetDecoration *const fieldOffsetDecoration = + field->findDecoration(); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() >= IRIntegerValue{0}); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() <= IRIntegerValue{UINT32_MAX}); + SLANG_ASSERT(isPowerOf2(structAlignment)); + const uint32_t fieldOffset = + static_cast(fieldOffsetDecoration->getOffset()); + // Alignment is GCD(fieldOffset, structAlignment) + // TODO: Use builtin/intrinsic (e.g. __builtin_ffs) + uint32_t fieldAlignment = 1U; + while (((fieldAlignment & (structAlignment | fieldOffset)) == 0U)) + fieldAlignment = fieldAlignment << 1U; + + m_writer->emit("@align("); + m_writer->emit(fieldAlignment); + m_writer->emit(")"); +} + +bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst) +{ + // Structured buffers are mapped to 'array' types, which don't need dereferencing + if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return false; + + // Don't emit "->" to access fields in resource structs + if (inst->getOp() == kIROp_FieldAddress) + return false; + + // Don't emit "*" to access fields in resource structs + if (inst->getOp() == kIROp_GlobalParam) + return false; + + // Emit 'globalVar' instead of "*&globalVar" + if (inst->getOp() == kIROp_GlobalVar) + return false; + + return true; +} + +void WGSLSourceEmitter::emit(const AddressSpace addressSpace) +{ + switch (addressSpace) + { + case AddressSpace::Uniform: + m_writer->emit("uniform"); + break; + + case AddressSpace::StorageBuffer: + m_writer->emit("storage"); + break; + + case AddressSpace::Generic: + m_writer->emit("function"); + break; + + case AddressSpace::ThreadLocal: + m_writer->emit("private"); + break; + + case AddressSpace::GroupShared: + m_writer->emit("workgroup"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) +{ + switch (type->getOp()) + { + + case kIROp_HLSLRWStructuredBufferType: + { + auto structuredBufferType = as(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read_write"); + m_writer->emit(">"); + } + break; + + case kIROp_HLSLStructuredBufferType: + { + auto structuredBufferType = as(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read"); + m_writer->emit(">"); + } + break; + + case kIROp_VoidType: + { + // There is no void type in WGSL. + // A return type of "void" is expressed by skipping the end part of the + // 'function_header' term: + // " + // function_header : + // 'fn' ident '(' param_list ? ')' + // ( '->' attribute * template_elaborated_ident ) ? + // " + // In other words, in WGSL we should never even get to the point where we're + // asking to emit 'void'. + SLANG_UNEXPECTED("'void' type emitted"); + return; + } + + case kIROp_FloatType: + m_writer->emit("f32"); + break; + case kIROp_DoubleType: + // There is no "f64" type in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + break; + case kIROp_Int8Type: + case kIROp_UInt8Type: + // There is no "[i|u]8" type in WGSL + SLANG_UNEXPECTED("8 bit integer type emitted"); + break; + case kIROp_HalfType: + m_f16ExtensionEnabled = true; + m_writer->emit("f16"); + break; + case kIROp_BoolType: + m_writer->emit("bool"); + break; + case kIROp_IntType: + m_writer->emit("i32"); + break; + case kIROp_UIntType: + m_writer->emit("u32"); + break; + case kIROp_UInt64Type: + { + m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); + return; + } + case kIROp_Int16Type: + case kIROp_UInt16Type: + SLANG_UNEXPECTED("16 bit integer value emitted"); + return; + case kIROp_Int64Type: + case kIROp_IntPtrType: + m_writer->emit("i64"); + return; + case kIROp_UIntPtrType: + m_writer->emit("u64"); + return; + case kIROp_StructType: + m_writer->emit(getName(type)); + return; + + case kIROp_VectorType: + { + auto vecType = (IRVectorType*)type; + emitVectorTypeNameImpl( + vecType->getElementType(), getIntVal(vecType->getElementCount()) + ); + return; + } + case kIROp_MatrixType: + { + auto matType = (IRMatrixType*)type; + // We map matrices in Slang to WGSL matrices that represent the transpose. + // (See note on "terminology reversal".) + const IRIntegerValue colCountWGSL = getIntVal(matType->getRowCount()); + const IRIntegerValue rowCountWGSL = getIntVal(matType->getColumnCount()); + emitMatrixType(matType->getElementType(), rowCountWGSL, colCountWGSL); + return; + } + case kIROp_SamplerStateType: + { + m_writer->emit("sampler"); + return; + } + + case kIROp_SamplerComparisonStateType: + { + m_writer->emit("sampler_comparison"); + return; + } + + case kIROp_PtrType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + { + auto ptrType = cast(type); + m_writer->emit("ptr<"); + emit((AddressSpace)ptrType->getAddressSpace()); + m_writer->emit(", "); + emitType((IRType*)ptrType->getValueType()); + m_writer->emit(">"); + return; + } + + case kIROp_ArrayType: + { + m_writer->emit("array<"); + emitType((IRType*)type->getOperand(0)); + m_writer->emit(", "); + emitVal(type->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">"); + return; + } + default: + break; + + } + +} + +void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) +{ + + for (auto attr : layout->getOffsetAttrs()) + { + LayoutResourceKind kind = attr->getResourceKind(); + + // TODO: + // This is not correct. For the moment this is just here as a hack to make + // @binding and @group unique, so that we can pass WGSL compile tests. + // This will have to be revisited when we actually want to supply resources to + // shaders. + if (kind == LayoutResourceKind::DescriptorTableSlot) + { + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + + return; + } + } + +} + +void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant) +{ + if (isConstant) + m_writer->emit("const"); + else + m_writer->emit("var"); + if (type->getOp() == kIROp_HLSLRWStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read_write"); + m_writer->emit(">"); + } + else if (type->getOp() == kIROp_HLSLStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read"); + m_writer->emit(">"); + } +} + +void WGSLSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) +{ + // C-like languages bake array-ness, pointer-ness and reference-ness into the + // declarator, which happens in the default _emitType implementation. + // WGSL on the other hand, don't have special syntax -- these are just types. + switch (type->getOp()) + { + case kIROp_ArrayType: + case kIROp_AttributedType: + case kIROp_UnsizedArrayType: + emitSimpleTypeAndDeclarator(type, declarator); + break; + default: + CLikeSourceEmitter::_emitType(type, declarator); + break; + } +} + +void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) +{ + if (!declarator) return; + + m_writer->emit(" "); + + switch (declarator->flavor) + { + case DeclaratorInfo::Flavor::Name: + { + auto nameDeclarator = (NameDeclaratorInfo*)declarator; + m_writer->emitName(*nameDeclarator->nameAndLoc); + } + break; + + case DeclaratorInfo::Flavor::SizedArray: + { + // Sized arrays are just types (array) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::UnsizedArray: + { + // Unsized arrays are just types (array) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Unsized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ptr: + { + // Pointers (ptr) are just types in WGSL -- they are not supported at + // the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Pointer declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ref: + { + // References (ref) are just types in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Reference declarator"); + } + break; + + case DeclaratorInfo::Flavor::LiteralSizedArray: + { + // Sized arrays are just types (array) in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Literal-sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Attributed: + { + auto attributedDeclarator = (AttributedDeclaratorInfo*)declarator; + auto instWithAttributes = attributedDeclarator->instWithAttributes; + for (auto attr : instWithAttributes->getAllAttrs()) + { + _emitPostfixTypeAttr(attr); + } + emitDeclarator(attributedDeclarator->next); + } + break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unknown declarator flavor"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) +{ + if (declarator) + { + emitDeclarator(declarator); + m_writer->emit(" : "); + } + emitSimpleType(type); +} + +void WGSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + { + auto litInst = static_cast(inst); + + IRBasicType* type = as(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Int8: + case BaseType::UInt8: + { + SLANG_UNEXPECTED("8 bit integer value emitted"); + break; + } + case BaseType::Int16: + case BaseType::UInt16: + { + SLANG_UNEXPECTED("16 bit integer value emitted"); + break; + } + case BaseType::Int: + { + m_writer->emit("i32("); + m_writer->emit(int32_t(litInst->value.intVal)); + m_writer->emit(")"); + return; + } + case BaseType::UInt: + { + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); + break; + } + case BaseType::Int64: + { + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::UInt64: + { + m_writer->emit("u64("); + SLANG_COMPILE_TIME_ASSERT( + sizeof(litInst->value.intVal) >= sizeof(uint64_t) + ); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::IntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("i32("); + m_writer->emit(int(litInst->value.intVal)); + m_writer->emit(")"); +#endif + break; + } + case BaseType::UIntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("u64("); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); +#endif + break; + } + + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.intVal); + } + break; + } + + case kIROp_FloatLit: + { + auto litInst = static_cast(inst); + + IRBasicType* type = as(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Half: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("h"); + m_f16ExtensionEnabled = true; + } + break; + + case BaseType::Float: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("f"); + } + break; + + case BaseType::Double: + { + // There is not "f64" in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + } + break; + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.floatVal); + } + } + break; + + case kIROp_BoolLit: + { + bool val = ((IRConstant*)inst)->value.intVal != 0; + m_writer->emit(val ? "true" : "false"); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("val case for emit"); + break; + } + + +} + +void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name) +{ + emitType(type, name); +} + +bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + EmitOpInfo outerPrec = inOuterPrec; + + switch (inst->getOp()) + { + + case kIROp_MakeVectorFromScalar: + { + // In WGSL this is done by calling the vec* overloads listed in [1] + // [1] https://www.w3.org/TR/WGSL/#value-constructor-builtin-function + emitType(inst->getDataType()); + m_writer->emit("("); + auto prec = getInfo(EmitOp::Prefix); + emitOperand(inst->getOperand(0), rightSide(outerPrec, prec)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_BitCast: + { + // In WGSL there is a built-in bitcast function! + // https://www.w3.org/TR/WGSL/#bitcast-builtin + m_writer->emit("bitcast"); + m_writer->emit("<"); + emitType(inst->getDataType()); + m_writer->emit(">"); + m_writer->emit("("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_MakeArray: + case kIROp_MakeStruct: + { + // It seems there are currently no designated initializers in WGSL. + // Similarly for array initializers. + // https://github.com/gpuweb/gpuweb/issues/4210 + + // There is a constructor named like the struct/array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("( "); + UInt argCount = inst->getOperandCount(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(aa), getInfo(EmitOp::General)); + } + m_writer->emit(" )"); + + return true; + } + break; + + case kIROp_MakeArrayFromElement: + { + // It seems there are currently no array initializers in WGSL. + + // There is a constructor named like the array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("("); + UInt argCount = + (UInt)cast( + cast(inst->getDataType())->getElementCount() + )->getValue(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } + break; + + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + // Structured buffers are just arrays in WGSL + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit("["); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit("]"); + return true; + } + break; + + case kIROp_Rsh: + case kIROp_Lsh: + { + // Shift amounts must be an unsigned type in WGSL + // https://www.w3.org/TR/WGSL/#bit-expr + IRInst *const shiftAmount = inst->getOperand(1); + IRType *const shiftAmountType = shiftAmount->getDataType(); + if (shiftAmountType->getOp() == kIROp_IntType) + { + // Dawn complains about "mixing '<<' and '|' requires parenthesis", so let's + // add parenthesis. + m_writer->emit("("); + + const auto emitOp = getEmitOpForOp(inst->getOp()); + const auto info = getInfo(emitOp); + + const bool needClose = maybeEmitParens(outerPrec, info); + emitOperand(inst->getOperand(0), leftSide(outerPrec, info)); + m_writer->emit(" "); + m_writer->emit(info.op); + m_writer->emit(" "); + m_writer->emit("bitcast("); + emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); + m_writer->emit(")"); + maybeCloseParens(needClose); + + m_writer->emit(")"); + return true; + } + } + break; + + } + + return false; +} + +void WGSLSourceEmitter::emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) +{ + + if (elementCount > 1) + { + m_writer->emit("vec"); + m_writer->emit(elementCount); + m_writer->emit("<"); + emitSimpleType(elementType); + m_writer->emit(">"); + } + else + { + emitSimpleType(elementType); + } +} + +void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec) +{ + // In WGSL, the structured buffer types are converted to ptr, AM> + // everywhere, except for the global parameter declaration. + // Thus, when these globals are used in expressions, we need an ampersand. + + if (inst->getOp() == kIROp_GlobalParam) + { + switch (inst->getDataType()->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + + m_writer->emit("(&"); + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); + m_writer->emit(")"); + return; + } + } + + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); +} + +void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name) +{ + // In WGSL, the structured buffer types are converted to ptr, AM> + // everywhere, except for the global parameter declaration. + + switch (type->getOp()) + { + + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + { + StringSliceLoc nameAndLoc(name.getUnownedSlice()); + NameDeclaratorInfo nameDeclarator(&nameAndLoc); + emitDeclarator(&nameDeclarator); + m_writer->emit(" : "); + auto structuredBufferType = as(type); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + } + break; + + default: + + emitType(type, name); + break; + + } + +} + +void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */) +{ + if (m_f16ExtensionEnabled) + { + m_writer->emit("enable f16;\n"); + m_writer->emit("\n"); + } +} + +} // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h new file mode 100644 index 0000000000..dacd11c3d3 --- /dev/null +++ b/source/slang/slang-emit-wgsl.h @@ -0,0 +1,71 @@ +#pragma once + +#include "slang-emit-c-like.h" + +namespace Slang +{ + +class WGSLSourceEmitter : public CLikeSourceEmitter +{ +public: + + WGSLSourceEmitter(const Desc& desc) + : CLikeSourceEmitter(desc) + {} + + virtual void emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type + ) SLANG_OVERRIDE; + virtual void emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) SLANG_OVERRIDE; + virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE; + virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; + virtual bool tryEmitInstExprImpl( + IRInst* inst, const EmitOpInfo& inOuterPrec + ) SLANG_OVERRIDE; + virtual void emitSwitchCaseSelectorsImpl( + IRBasicType *const switchCondition, + const SwitchRegion::Case *const currentCase, + const bool isDefault + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) SLANG_OVERRIDE; + virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE; + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE; + virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; + virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; + virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE; + virtual bool isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; + virtual void emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) SLANG_OVERRIDE; + virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE; + virtual void emitOperandImpl( + IRInst* inst, const EmitOpInfo& outerPrec + ) SLANG_OVERRIDE; + + void emit(const AddressSpace addressSpace); + +private: + + // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns + void emitMatrixType( + IRType *const elementType, + const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ); + + bool m_f16ExtensionEnabled {false}; + +}; + +} // namespace Slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ed9e904627..2ccf075f39 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -31,6 +31,7 @@ #include "slang-ir-glsl-legalize.h" #include "slang-ir-hlsl-legalize.h" #include "slang-ir-metal-legalize.h" +#include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" #include "slang-ir-legalize-array-return-type.h" @@ -101,6 +102,7 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" #include "slang-emit-metal.h" +#include "slang-emit-wgsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" #include "slang-emit-torch.h" @@ -1234,6 +1236,12 @@ Result linkAndOptimizeIR( } break; + case CodeGenTarget::WGSL: + { + legalizeIRForWGSL(irModule, sink); + } + break; + default: break; } @@ -1535,15 +1543,28 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr auto targetProgram = getTargetProgram(); auto lineDirectiveMode = targetProgram->getOptionSet().getEnumOption(CompilerOptionName::LineDirectiveMode); - // To try to make the default behavior reasonable, we will - // always use C-style line directives (to give the user - // good source locations on error messages from downstream - // compilers) *unless* they requested raw GLSL as the - // output (in which case we want to maximize compatibility - // with downstream tools). - if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL) + // We will generally use C-style line directives in order to give the user good + // source locations on error messages from downstream compilers, but there are + // a few exceptions. + if (lineDirectiveMode == LineDirectiveMode::Default) { - lineDirectiveMode = LineDirectiveMode::GLSL; + + switch(targetRequest->getTarget()) + { + + case CodeGenTarget::GLSL: + // We want to maximize compatibility with downstream tools. + lineDirectiveMode = LineDirectiveMode::GLSL; + break; + + case CodeGenTarget::WGSL: + // WGSL doesn't support line directives. + // See https://github.com/gpuweb/gpuweb/issues/606. + lineDirectiveMode = LineDirectiveMode::None; + break; + + } + } ComPtr> sourceMap; @@ -1610,6 +1631,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr sourceEmitter = new MetalSourceEmitter(desc); break; } + case SourceLanguage::WGSL: + { + sourceEmitter = new WGSLSourceEmitter(desc); + break; + } default: break; } break; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 5865d5320d..01b1c20dec 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1511,6 +1511,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req) case CodeGenTarget::Metal: case CodeGenTarget::MetalLib: case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::WGSL: case CodeGenTarget::DXIL: case CodeGenTarget::DXILAssembly: case CodeGenTarget::HostCPPSource: diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index a480ae6737..d0ad7483a4 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -888,16 +888,19 @@ namespace Slang IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) { - if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) + { + if (!isKhronosTarget(target->getTargetReq())) + return IRTypeLayoutRules::getNatural(); - // If we are just emitting GLSL, we can just use the general layout rule. - if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + // If we are just emitting GLSL, we can just use the general layout rule. + if (!target->shouldEmitSPIRVDirectly()) + return IRTypeLayoutRules::getNatural(); - // If the user specified a scalar buffer layout, then just use that. - if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); + // If the user specified a scalar buffer layout, then just use that. + if (target->getOptionSet().shouldUseScalarLayout()) + return IRTypeLayoutRules::getNatural(); + } if (target->getOptionSet().shouldUseDXLayout()) { diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp new file mode 100644 index 0000000000..e05eba78c7 --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -0,0 +1,347 @@ +#include "slang-ir-wgsl-legalize.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-parameter-binding.h" +#include "slang-ir-legalize-varying-params.h" + +namespace Slang +{ + + struct EntryPointInfo + { + IRFunc* entryPointFunc; + IREntryPointDecoration* entryPointDecor; + }; + + struct SystemValLegalizationWorkItem + { + IRInst* var; + String attrName; + UInt attrIndex; + }; + + struct WGSLSystemValueInfo + { + String wgslSystemValueName; + SystemValueSemanticName wgslSystemValueNameEnum; + ShortList permittedTypes; + bool isUnsupported = false; + }; + + struct LegalizeWGSLEntryPointContext + { + LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) : + m_sink(sink), m_module(module) {} + + DiagnosticSink* m_sink; + IRModule* m_module; + + std::optional makeSystemValWorkItem(IRInst* var); + void legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ); + List collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ); + void legalizeSystemValueParameters(EntryPointInfo entryPoint); + void legalizeEntryPointForWGSL(EntryPointInfo entryPoint); + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType); + WGSLSystemValueInfo getSystemValueInfo( + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar + ); + }; + + IRInst* LegalizeWGSLEntryPointContext::tryConvertValue( + IRBuilder& builder, IRInst* val, IRType* toType + ) + { + auto fromType = val->getFullType(); + if (auto fromVector = as(fromType)) + { + if (auto toVector = as(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = + builder.getVectorType( + fromVector->getElementType(), toVector->getElementCount() + ); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); + } + + + WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo( + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar + ) + { + IRBuilder builder(m_module); + WGSLSystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex( + inSemanticName.getUnownedSlice(), semanticName, semanticIndex + ); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.wgslSystemValueNameEnum = + convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.wgslSystemValueNameEnum) + { + + case SystemValueSemanticName::DispatchThreadID: + { + result.wgslSystemValueName = toSlice("global_invocation_id"); + IRType *const vec3uType { + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + }; + result.permittedTypes.add(vec3uType); + } + break; + + case SystemValueSemanticName::GroupID: + { + result.wgslSystemValueName = toSlice("workgroup_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GroupThreadID: + { + result.wgslSystemValueName = toSlice("local_invocation_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GSInstanceID: + { + // No Geometry shaders in WGSL + result.isUnsupported = true; + } + break; + + default: + { + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, semanticName + ); + return result; + } + + } + + return result; + } + + std::optional + LegalizeWGSLEntryPointContext::makeSystemValWorkItem(IRInst* var) + { + if (auto semanticDecoration = var->findDecoration()) + { + bool svPrefix = + semanticDecoration->getSemanticName().startsWithCaseInsensitive( + toSlice("sv_") + ); + if (svPrefix) + { + return + { + { + var, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex() + } + }; + } + } + + auto layoutDecor = var->findDecoration(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return { { var, semanticName, sysAttrIndex } }; + } + + List + LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ) + { + List systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + auto maybeWorkItem = makeSystemValWorkItem(param); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; + } + + void + LegalizeWGSLEntryPointContext::legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + auto semanticName = workItem.attrName; + + auto indexAsString = String(workItem.attrIndex); + auto info = getSystemValueInfo(semanticName, &indexAsString, var); + + if (!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration( + var, info.wgslSystemValueName.getUnownedSlice() + ); + + bool varTypeIsPermitted = false; + auto varType = var->getFullType(); + for (auto& permittedType : info.permittedTypes) + { + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst() + ); + + // get uses before we `tryConvertValue` since this creates a new use + List uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString() + ); + } + } + } + } + + void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters( + EntryPointInfo entryPoint + ) + { + List systemValWorkItems = + collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + } + + void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL( + EntryPointInfo entryPoint + ) + { + legalizeSystemValueParameters(entryPoint); + } + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) + { + List entryPoints; + for (auto inst : module->getGlobalInsts()) + { + IRFunc *const func {as(inst)}; + if (!func) + continue; + IREntryPointDecoration *const entryPointDecor = + func->findDecoration(); + if (!entryPointDecor) + continue; + EntryPointInfo info; + info.entryPointDecor = entryPointDecor; + info.entryPointFunc = func; + entryPoints.add(info); + } + + LegalizeWGSLEntryPointContext context(sink, module); + for (auto entryPoint : entryPoints) + context.legalizeEntryPointForWGSL(entryPoint); + } + +} diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h new file mode 100644 index 0000000000..462f932044 --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + class DiagnosticSink; + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index 04d4f5112c..178fbddd5e 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -19,6 +19,7 @@ namespace Slang CUDA = SLANG_SOURCE_LANGUAGE_CUDA, SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV, Metal = SLANG_SOURCE_LANGUAGE_METAL, + WGSL = SLANG_SOURCE_LANGUAGE_WGSL, CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index f654135a14..2447f5787c 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1831,6 +1831,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::GLSL: case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::WGSL: return &kGLSLLayoutRulesFamilyImpl; case CodeGenTarget::HostHostCallable: @@ -2141,6 +2142,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } case CodeGenTarget::CSource: { return SourceLanguage::C; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 91ed3de5fd..c78348a869 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1838,6 +1838,10 @@ CapabilitySet TargetRequest::getTargetCaps() atoms.add(CapabilityName::metal); break; + case CodeGenTarget::WGSL: + atoms.add(CapabilityName::wgsl); + break; + default: break; }