Skip to content

Implementation of vector compute extension #1849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm-spirv/include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ EXT(SPV_INTEL_io_pipes)
EXT(SPV_INTEL_inline_assembly)
EXT(SPV_INTEL_arbitrary_precision_integers)
EXT(SPV_INTEL_optimization_hints)
EXT(SPV_INTEL_float_controls2)
EXT(SPV_INTEL_vector_compute)
1 change: 1 addition & 0 deletions llvm-spirv/lib/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(LLVMSPIRVLib
OCL21ToSPIRV.cpp
OCLTypeToSPIRV.cpp
OCLUtil.cpp
VectorComputeUtil.cpp
SPIRVLowerBool.cpp
SPIRVLowerConstExpr.cpp
SPIRVLowerMemmove.cpp
Expand Down
53 changes: 53 additions & 0 deletions llvm-spirv/lib/SPIRV/PreprocessMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "SPIRVInternal.h"
#include "SPIRVMDBuilder.h"
#include "SPIRVMDWalker.h"
#include "VectorComputeUtil.h"

#include "llvm/ADT/Triple.h"
#include "llvm/IR/IRBuilder.h"
Expand Down Expand Up @@ -69,6 +70,8 @@ class PreprocessMetadata : public ModulePass {
bool runOnModule(Module &M) override;
void visit(Module *M);
void preprocessOCLMetadata(Module *M, SPIRVMDBuilder *B, SPIRVMDWalker *W);
void preprocessVectorComputeMetadata(Module *M, SPIRVMDBuilder *B,
SPIRVMDWalker *W);

static char ID;

Expand Down Expand Up @@ -100,6 +103,7 @@ void PreprocessMetadata::visit(Module *M) {
SPIRVMDWalker W(*M);

preprocessOCLMetadata(M, &B, &W);
preprocessVectorComputeMetadata(M, &B, &W);

// Create metadata representing (empty so far) list
// of OpExecutionMode instructions
Expand Down Expand Up @@ -243,6 +247,55 @@ void PreprocessMetadata::preprocessOCLMetadata(Module *M, SPIRVMDBuilder *B,
B->eraseNamedMD(kSPIR2MD::FPContract);
}

void PreprocessMetadata::preprocessVectorComputeMetadata(Module *M,
SPIRVMDBuilder *B,
SPIRVMDWalker *W) {
using namespace VectorComputeUtil;

auto EM = B->addNamedMD(kSPIRVMD::ExecutionMode);

for (auto &F : *M) {
// Add VC float control execution modes
// RoundMode and FloatMode are always same for all types in VC
// While Denorm could be different for double, float and half
auto Attrs = F.getAttributes();
if (Attrs.hasFnAttribute(kVCMetadata::VCFloatControl)) {
SPIRVWord Mode = 0;
Attrs
.getAttribute(AttributeList::FunctionIndex,
kVCMetadata::VCFloatControl)
.getValueAsString()
.getAsInteger(0, Mode);
spv::ExecutionMode ExecRoundMode =
VCRoundModeExecModeMap::map(getVCRoundMode(Mode));
spv::ExecutionMode ExecFloatMode =
VCFloatModeExecModeMap::map(getVCFloatMode(Mode));
VCFloatTypeSizeMap::foreach (
[&](VCFloatType FloatType, unsigned TargetWidth) {
EM.addOp().add(&F).add(ExecRoundMode).add(TargetWidth).done();
EM.addOp().add(&F).add(ExecFloatMode).add(TargetWidth).done();
EM.addOp()
.add(&F)
.add(VCDenormModeExecModeMap::map(
getVCDenormPreserve(Mode, FloatType)))
.add(TargetWidth)
.done();
});
}
if (Attrs.hasFnAttribute(kVCMetadata::VCSLMSize)) {
SPIRVWord SLMSize = 0;
Attrs.getAttribute(AttributeList::FunctionIndex, kVCMetadata::VCSLMSize)
.getValueAsString()
.getAsInteger(0, SLMSize);
EM.addOp()
.add(&F)
.add(spv::ExecutionModeSharedLocalMemorySizeINTEL)
.add(SLMSize)
.done();
}
}
}

} // namespace SPIRV

INITIALIZE_PASS(PreprocessMetadata, "preprocess-metadata",
Expand Down
27 changes: 17 additions & 10 deletions llvm-spirv/lib/SPIRV/SPIRVLowerConstExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ bool SPIRVLowerConstExpr::runOnModule(Module &Module) {

void SPIRVLowerConstExpr::visit(Module *M) {
for (auto &I : M->functions()) {
std::map<ConstantExpr *, Instruction *> CMap;
std::list<Instruction *> WorkList;
for (auto &BI : I) {
for (auto &II : BI) {
Expand All @@ -124,7 +123,10 @@ void SPIRVLowerConstExpr::visit(Module *M) {
while (!WorkList.empty()) {
auto II = WorkList.front();

auto LowerOp = [&II, &FBegin, &I](ConstantExpr *CE) {
auto LowerOp = [&II, &FBegin, &I](Value *V) -> Value * {
if (isa<Function>(V))
return V;
auto *CE = cast<ConstantExpr>(V);
SPIRVDBG(dbgs() << "[lowerConstantExpressions] " << *CE;)
auto ReplInst = CE->getAsInstruction();
auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
Expand All @@ -149,25 +151,30 @@ void SPIRVLowerConstExpr::visit(Module *M) {
for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
auto Op = II->getOperand(OI);
auto *Vec = dyn_cast<ConstantVector>(Op);
if (Vec && std::all_of(Vec->op_begin(), Vec->op_end(),
[](Value *V) { return isa<ConstantExpr>(V); })) {
if (Vec && std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
return isa<ConstantExpr>(V) || isa<Function>(V);
})) {
// Expand a vector of constexprs and construct it back with series of
// insertelement instructions
std::list<Instruction *> ReplList;
std::transform(
Vec->op_begin(), Vec->op_end(), std::back_inserter(ReplList),
[LowerOp](Value *V) { return LowerOp(cast<ConstantExpr>(V)); });
std::list<Value *> OpList;
std::transform(Vec->op_begin(), Vec->op_end(),
std::back_inserter(OpList),
[LowerOp](Value *V) { return LowerOp(V); });
Value *Repl = nullptr;
unsigned Idx = 0;
for (auto V : ReplList)
std::list<Instruction *> ReplList;
for (auto V : OpList) {
if (auto *Inst = dyn_cast<Instruction>(V))
ReplList.push_back(Inst);
Repl = InsertElementInst::Create(
(Repl ? Repl : UndefValue::get(Vec->getType())), V,
ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx++), "",
II);
}
II->replaceUsesOfWith(Op, Repl);
WorkList.splice(WorkList.begin(), ReplList);
} else if (auto CE = dyn_cast<ConstantExpr>(Op))
WorkList.push_front(LowerOp(CE));
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
}
}
}
Expand Down
100 changes: 96 additions & 4 deletions llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "SPIRVType.h"
#include "SPIRVUtil.h"
#include "SPIRVValue.h"
#include "VectorComputeUtil.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/LoopInfo.h"
Expand Down Expand Up @@ -108,7 +109,7 @@ const static char *Restrict = "restrict";
const static char *Pipe = "pipe";
} // namespace kOCLTypeQualifierName

static bool isOpenCLKernel(SPIRVFunction *BF) {
static bool isKernel(SPIRVFunction *BF) {
return BF->getModule()->isEntryPoint(ExecutionModelKernel, BF->getId());
}

Expand Down Expand Up @@ -1530,14 +1531,34 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
assert(BB && "Invalid BB");
return mapValue(BV, new AllocaInst(Ty, 0, BV->getName(), BB));
}
auto AddrSpace = SPIRSPIRVAddrSpaceMap::rmap(BS);
SPIRAddressSpace AddrSpace;

bool IsVectorCompute =
BVar->hasDecorate(DecorationVectorComputeVariableINTEL);
if (IsVectorCompute) {
AddrSpace = VectorComputeUtil::getVCGlobalVarAddressSpace(BS);
if (!Initializer)
Initializer = UndefValue::get(Ty);
} else
AddrSpace = SPIRSPIRVAddrSpaceMap::rmap(BS);

auto LVar = new GlobalVariable(*M, Ty, IsConst, LinkageTy, Initializer,
BV->getName(), 0,
GlobalVariable::NotThreadLocal, AddrSpace);
LVar->setUnnamedAddr((IsConst && Ty->isArrayTy() &&
Ty->getArrayElementType()->isIntegerTy(8))
? GlobalValue::UnnamedAddr::Global
: GlobalValue::UnnamedAddr::None);

if (IsVectorCompute) {
LVar->addAttribute(kVCMetadata::VCGlobalVariable);
SPIRVWord Offset;
if (BVar->hasDecorate(DecorationGlobalVariableOffsetINTEL, 0, &Offset))
LVar->addAttribute(kVCMetadata::VCByteOffset, utostr(Offset));
if (BVar->hasDecorate(DecorationVolatile))
LVar->addAttribute(kVCMetadata::VCVolatile);
}

SPIRVBuiltinVariableKind BVKind;
if (BVar->isBuiltin(&BVKind))
BuiltinGVMap[LVar] = BVKind;
Expand Down Expand Up @@ -2404,7 +2425,7 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
if (Loc != FuncMap.end())
return Loc->second;

auto IsKernel = BM->isEntryPoint(ExecutionModelKernel, BF->getId());
auto IsKernel = isKernel(BF);
auto Linkage = IsKernel ? GlobalValue::ExternalLinkage : transLinkageType(BF);
FunctionType *FT = dyn_cast<FunctionType>(transType(BF->getFunctionType()));
Function *F = cast<Function>(
Expand Down Expand Up @@ -3218,7 +3239,7 @@ bool SPIRVToLLVM::transFPContractMetadata() {
bool ContractOff = false;
for (unsigned I = 0, E = BM->getNumFunctions(); I != E; ++I) {
SPIRVFunction *BF = BM->getFunction(I);
if (!isOpenCLKernel(BF))
if (!isKernel(BF))
continue;
if (BF->getExecutionMode(ExecutionModeContractionOff)) {
ContractOff = true;
Expand Down Expand Up @@ -3296,6 +3317,7 @@ bool SPIRVToLLVM::transMetadata() {
assert(F && "Invalid translated function");

transOCLMetadata(BF);
transVectorComputeMetadata(BF);

if (F->getCallingConv() != CallingConv::SPIR_KERNEL)
continue;
Expand Down Expand Up @@ -3436,6 +3458,76 @@ bool SPIRVToLLVM::transOCLMetadata(SPIRVFunction *BF) {
return true;
}

bool SPIRVToLLVM::transVectorComputeMetadata(SPIRVFunction *BF) {
using namespace VectorComputeUtil;
Function *F = static_cast<Function *>(getTranslatedValue(BF));
assert(F && "Invalid translated function");

if (BF->hasDecorate(DecorationStackCallINTEL))
F->addFnAttr(kVCMetadata::VCStackCall);

bool IsVectorCompute = BF->hasDecorate(DecorationVectorComputeFunctionINTEL);
if (!IsVectorCompute)
return true;
F->addFnAttr(kVCMetadata::VCFunction);

for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
++I) {
auto ArgNo = I->getArgNo();
SPIRVFunctionParameter *BA = BF->getArgument(ArgNo);
SPIRVWord Kind;
if (BA->hasDecorate(DecorationFuncParamIOKind, 0, &Kind)) {
Attribute Attr = Attribute::get(*Context, kVCMetadata::VCArgumentIOKind,
std::to_string(Kind));
F->addAttribute(ArgNo + 1, Attr);
}
}

// Do not add float control if there is no any
bool IsVCFloatControl = false;
unsigned FloatControl = 0;
// RoundMode and FloatMode are always same for all types in Cm
// While Denorm could be different for double, float and half
VCRoundModeExecModeMap::foreach ([&](VCRoundMode VCRM, ExecutionMode EM) {
if (BF->getExecutionMode(EM)) {
IsVCFloatControl = true;
FloatControl |= getVCFloatControl(VCRM);
}
});
VCFloatModeExecModeMap::foreach ([&](VCFloatMode VCFM, ExecutionMode EM) {
if (BF->getExecutionMode(EM)) {
IsVCFloatControl = true;
FloatControl |= getVCFloatControl(VCFM);
}
});
VCDenormModeExecModeMap::foreach ([&](VCDenormMode VCDM, ExecutionMode EM) {
auto ExecModes = BF->getExecutionModeRange(EM);
for (auto It = ExecModes.first; It != ExecModes.second; It++) {
IsVCFloatControl = true;
unsigned TargetWidth = (*It).second->getLiterals()[0];
VCFloatType FloatType = VCFloatTypeSizeMap::rmap(TargetWidth);
FloatControl |= getVCFloatControl(VCDM, FloatType);
}
});
if (IsVCFloatControl) {
Attribute Attr = Attribute::get(*Context, kVCMetadata::VCFloatControl,
std::to_string(FloatControl));
F->addAttribute(AttributeList::FunctionIndex, Attr);
}

if (auto EM = BF->getExecutionMode(ExecutionModeSharedLocalMemorySizeINTEL)) {
unsigned int SLMSize = EM->getLiterals()[0];
Attribute Attr = Attribute::get(*Context, kVCMetadata::VCSLMSize,
std::to_string(SLMSize));
F->addAttribute(AttributeList::FunctionIndex, Attr);
}

if (F->getCallingConv() != CallingConv::SPIR_KERNEL)
return true;

return true;
}

bool SPIRVToLLVM::transAlign(SPIRVValue *BV, Value *V) {
if (auto AL = dyn_cast<AllocaInst>(V)) {
SPIRVWord Align = 0;
Expand Down
1 change: 1 addition & 0 deletions llvm-spirv/lib/SPIRV/SPIRVReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class SPIRVToLLVM {
bool transFPContractMetadata();
bool transMetadata();
bool transOCLMetadata(SPIRVFunction *BF);
bool transVectorComputeMetadata(SPIRVFunction *BF);
Value *transAsmINTEL(SPIRVAsmINTEL *BA);
CallInst *transAsmCallINTEL(SPIRVAsmCallINTEL *BI, Function *F,
BasicBlock *BB);
Expand Down
Loading