[SPIR-V] Implement SPV_KHR_float_controls2 #146941
|
@MrSidims @Keenuts @michalpaszkowski please, take a look. I cannot ask for reviews directly through the tool. |
|
✅ With the latest revision this PR passed the undef deprecator. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
The failed test is known, and caused by an update to spirv-val: #144774 (comment), and is unrelated to this change. |
MrSidims
left a comment
I'll take a closer look next week. Some brief comments are listed below; the major one is a request to update the SPIR-V friendly LLVM IR doc for the backend.
llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Outdated
| return 1; | ||
| case 64: // double | ||
| return 2; | ||
| case 128: // fp128 |
How does SPIR-V BE handle 128-bit fp type?
We had an offline discussion and concluded that the 128-bit fp type is not supported by SPIR-V. Consequently, I will remove everything related to it in this PR, in both the implementation and the tests.
| !7 = !{i32 32, i32 36} | ||
| !8 = !{i32 0, i32 0} | ||
| !12 = !{i32 1, !"wchar_size", i32 4} | ||
| !13 = !{!"clang version 8.0.1"} |
I guess we can remove the unnecessary metadata.
| unsigned ExpectMDOps, int64_t DefVal) { | ||
| MCInst Inst; | ||
| Inst.setOpcode(SPIRV::OpExecutionMode); | ||
|
|
| ; CHECK-DAG: OpExecutionMode %[[#KERNEL_HALF]] FPFastMathDefault %[[#HALF_TYPE:]] 1 | ||
| !17 = !{ptr @k_float_controls_half, i32 6028, half poison, i32 1} | ||
|
|
||
| ; CHECK-DAG: OpExecutionMode %[[#KERNEL_BFLOAT]] FPFastMathDefault %[[#BFLOAT_TYPE:]] 2 |
This tingles my spider senses: the bfloat type (i.e. the encoding from the SPV_KHR_bfloat16 extension) is not supported by the BE. I assume LLVM bfloat is then lowered to just TypeFloat 32. If that is the case, why is the [[#BFLOAT_TYPE:]] entry different from the [[#FLOAT_TYPE:]] entry in the test?
bfloat is lowered to OpTypeFloat 16. [[#BFLOAT_TYPE:]] is different from [[#HALF_TYPE:]] in case it is eventually lowered to a different type.
OpTypeFloat 16
With encoding different from IEEE-754, right? Otherwise it's a bug in lowering (definitely not related to your patch, yet nice to know ahead of time).
Actually, no. There is only a single OpTypeFloat 16 without any encoding at all.
Keenuts
left a comment
Started to review but I have a more global question first:
Seems like you do some stuff multiple times:
- check the metadata on the module, and emit default decorations (outputFPFastMathDefaultInfo)
- check metadata on all the rest, and emit decorations (later in the same SPIRVAsmPrinter function)
- then merge flags you find from the OpDecorate instruction, checking for conflicts
The code seems a bit convoluted to me, with parts that could definitely be refactored (e.g. the iteration on each type and the MI construction for the decoration).
It looks like the metadata gathering, the flag conflict validation, and only then the decoration emission could be split into distinct, independent pieces.
Another question: reading the float_controls2 spec, it seems it could be enabled for shaders, since the capability is not gated by the Kernel one. But it looks like the SPIRVAsmPrinter part is gated on ST->isKernel.
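The split suggested above (gather, then validate, then emit) could look roughly like the following. This is only a sketch under stated assumptions: `ModeEntry`, `Defaults`, `gatherDefaults`, `findConflicts` and `emitDefaults` are hypothetical names, plain strings stand in for functions and emitted instructions, and the real backend works on `MDNode` and `MCInst` objects instead.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Execution mode and flag values from the SPIR-V specification.
constexpr uint32_t ContractionOffMode = 31;
constexpr uint32_t FPFastMathDefaultMode = 6028;
constexpr uint32_t AllowContractBit = 0x10000;

// Hypothetical stand-in for one operand list of !spirv.ExecutionMode.
struct ModeEntry {
  std::string Func;
  uint32_t Mode;
  unsigned BitWidth; // only meaningful for FPFastMathDefault
  uint32_t Flags;    // only meaningful for FPFastMathDefault
};

// Hypothetical per-(function, FP width) summary of what was requested.
struct Defaults {
  bool HasFastMathDefault = false;
  bool ContractionOff = false;
  uint32_t Flags = 0;
};

using DefaultsMap = std::map<std::pair<std::string, unsigned>, Defaults>;

// Step 1: gather. Only records what the metadata says; no checks, no output.
DefaultsMap gatherDefaults(const std::vector<ModeEntry> &Entries) {
  DefaultsMap Map;
  for (const ModeEntry &E : Entries) {
    if (E.Mode == FPFastMathDefaultMode) {
      Defaults &D = Map[{E.Func, E.BitWidth}];
      D.HasFastMathDefault = true;
      D.Flags = E.Flags;
    } else if (E.Mode == ContractionOffMode) {
      // ContractionOff has no type operand, so it applies to every FP width.
      for (unsigned Width : {16u, 32u, 64u})
        Map[{E.Func, Width}].ContractionOff = true;
    }
  }
  return Map;
}

// Step 2: validate. Reports conflicts without emitting anything.
std::vector<std::string> findConflicts(const DefaultsMap &Map) {
  std::vector<std::string> Errors;
  for (const auto &KV : Map) {
    const Defaults &D = KV.second;
    if (D.HasFastMathDefault && D.ContractionOff &&
        (D.Flags & AllowContractBit))
      Errors.push_back(KV.first.first +
                       ": ContractionOff conflicts with AllowContract");
  }
  return Errors;
}

// Step 3: emit. Assumes the map has already been validated.
std::vector<std::string> emitDefaults(const DefaultsMap &Map) {
  std::vector<std::string> Out;
  for (const auto &KV : Map) {
    const Defaults &D = KV.second;
    if (!D.HasFastMathDefault && !D.ContractionOff)
      continue;
    Out.push_back("OpExecutionMode %" + KV.first.first +
                  " FPFastMathDefault %float" +
                  std::to_string(KV.first.second) + " " +
                  std::to_string(D.Flags));
  }
  return Out;
}
```

Keeping the three steps separate means the conflict rules can be tested without constructing any instruction, which seems to be the point of the review comment.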
| if (Flags & SPIRV::FPFastMathMode::Fast) | ||
| report_fatal_error( | ||
| "FPFastMathMode::Fast flag is deprecated and it is not " | ||
| "valid to use anymore."); |
Looks like this is not a runtime error: only an error in this function would result in Fast being set with CanUseKHRFloatControls2, so assert would be more appropriate.
| // Error out if AllowTransform is enabled without AllowReassoc and | ||
| // AllowContract. | ||
| if ((Flags & SPIRV::FPFastMathMode::AllowTransform) && | ||
| (!(Flags & SPIRV::FPFastMathMode::AllowReassoc) || | ||
| !(Flags & SPIRV::FPFastMathMode::AllowContract))) | ||
| report_fatal_error( | ||
| "FPFastMathMode::AllowTransform flag requires AllowReassoc and " | ||
| "AllowContract flags to be enabled as well."); | ||
| } |
Same here, this seems to be a code error, not a runtime error.
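For reference, the two rules checked in the snippets above can be written as a small standalone predicate. This is a minimal sketch: the bit values follow the SPIR-V `FPFastMathMode` enumeration, while `validateFastMathFlags` is a hypothetical helper and not a function from the patch.

```cpp
#include <cstdint>

// Bit values of the SPIR-V FPFastMathMode enumeration.
namespace FPFastMathMode {
constexpr uint32_t None = 0x0;
constexpr uint32_t NotNaN = 0x1;
constexpr uint32_t NotInf = 0x2;
constexpr uint32_t NSZ = 0x4;
constexpr uint32_t AllowRecip = 0x8;
constexpr uint32_t Fast = 0x10;
constexpr uint32_t AllowContract = 0x10000;
constexpr uint32_t AllowReassoc = 0x20000;
constexpr uint32_t AllowTransform = 0x40000;
} // namespace FPFastMathMode

// Hypothetical helper: returns nullptr when Flags is acceptable with
// SPV_KHR_float_controls2 enabled, or a message describing the violation.
const char *validateFastMathFlags(uint32_t Flags) {
  // Fast is deprecated by the extension.
  if (Flags & FPFastMathMode::Fast)
    return "Fast is deprecated";

  // AllowTransform is only valid together with AllowReassoc and
  // AllowContract.
  const uint32_t Required =
      FPFastMathMode::AllowReassoc | FPFastMathMode::AllowContract;
  if ((Flags & FPFastMathMode::AllowTransform) &&
      (Flags & Required) != Required)
    return "AllowTransform requires AllowReassoc and AllowContract";

  return nullptr;
}
```

Whether a violation should be reported through `assert` or `report_fatal_error` is exactly the question raised in the review; the sketch deliberately leaves that choice to the caller.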
| %remRes = frem arcp float %1, %2 | ||
| %negRes = fneg fast float %1 | ||
| %oeqRes = fcmp nnan ninf oeq float %1, %2 | ||
| %oneRes = fcmp one float %1, %2, !spirv.Decorations !3 |
If this code were passed to the optimizer, could the metadata be removed? Would that be a problem? Are we worried about that, or is it just an acceptable loss of information?
Also, could an LLVM optimization transform the instruction in a way that is inconsistent with the float controls? I'm guessing no. Adding float controls makes things more permissive. It is the responsibility of the code generator to make the LLVM IR options at least as restrictive as the float controls.
> If this code were passed to the optimizer, could the metadata be removed? Would that be a problem? Are we worried about that, or is it just an acceptable loss of information?
I have been having some internal discussions with @MrSidims on how to reconcile the information when math flags are expressed both in LLVM IR and through metadata. I don't think there's any formal guidance on how this should be done or what should take priority, so I think we should come up with a convention ourselves. That said, the scenario you present is quite helpful for building this convention: if such code were passed to the optimizer, with both LLVM IR math flags and metadata, the optimizer knows nothing about SPIR-V's metadata (as far as I know), so it could already perform transformations that go against what the metadata expresses. As a consequence, I would say we should just emit the metadata the way we find it: by the time we see it, it might already have been violated, but we cannot tell. What do you think? Another option is not to emit it at all, since it could already have been violated, so there is little point in handling it.
> Also, could an LLVM optimization transform the instruction in a way that is inconsistent with the float controls? I'm guessing no. Adding float controls makes things more permissive. It is the responsibility of the code generator to make the LLVM IR options at least as restrictive as the float controls.
I think nothing prevents an optimization from going against what is expressed through metadata because, as I said, I don't think optimizations are even aware of SPIR-V's metadata. However, I agree that the code generator should guarantee that the LLVM IR flags are aligned with the metadata, if any. Honestly, I don't think there's much we can do on our side if they don't align. I did think about how to reconcile possible inconsistencies between LLVM IR flags and metadata, but I didn't see any good way of doing it and, again, optimizations might already have violated the metadata controls.
I think that is the place where the behaviours of the SPIR-V to LLVM IR translator and the SPIR-V BE should diverge.
It so happens that the translator translates any spirv.Decorations metadata out of the box (provided the appropriate decoration was registered), and maybe it should consider disallowing the spirv.Decorations notation for fast math flags.
Meanwhile, if the SPIR-V BE requires explicitly written handlers for the metadata to lower FP flags, just don't write them and don't handle it in the backend.
The formal guideline says:
"As one of the goals of SPIR-V is to "map easily to other IRs, including LLVM IR", most of SPIR-V entities (global variables, constants, types, functions, basic blocks, instructions) have straightforward counterparts in LLVM. Therefore the focus of this document is those entities in SPIR-V which do not map to LLVM in an obvious way."
So, as FP flags have an obvious match between LLVM IR and SPIR-V (with the exception of reassoc <-> AllowTransform, but that is unrelated to this discussion), we should be fine dropping support for spirv.Decorations for FP flags in this PR. If anybody finds it useful in the future, they can re-add this support.
Okay, so, to be clear, going forward we want to ignore every spirv.Decorations metadata entry that is related to FP flags, including:
- FPFastMathMode
- NoContraction
I will address this in this PR. Please speak up if you're against this.
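The convention agreed on here amounts to filtering out two decoration kinds before lowering. A minimal sketch, assuming a simplified in-memory form of the decoration list: `DecorationMD` and `dropFPFlagDecorations` are hypothetical names, and only the numeric decoration IDs come from the SPIR-V specification.

```cpp
#include <cstdint>
#include <vector>

// Decoration IDs from the SPIR-V specification.
constexpr uint32_t DecorationFPFastMathMode = 40;
constexpr uint32_t DecorationNoContraction = 42;

// Hypothetical stand-in for one entry of a spirv.Decorations metadata node:
// the decoration ID followed by its literal operands.
struct DecorationMD {
  uint32_t Id;
  std::vector<uint32_t> Operands;
};

// True for the decorations that duplicate LLVM IR fast-math flags.
bool isFPFlagDecoration(uint32_t Id) {
  return Id == DecorationFPFastMathMode || Id == DecorationNoContraction;
}

// Returns the decorations that should still be lowered, skipping the FP ones.
std::vector<DecorationMD>
dropFPFlagDecorations(const std::vector<DecorationMD> &In) {
  std::vector<DecorationMD> Out;
  for (const DecorationMD &D : In)
    if (!isFPFlagDecoration(D.Id))
      Out.push_back(D);
  return Out;
}
```

With this approach the fast-math information comes only from the LLVM IR flags on the instruction, which avoids having to reconcile two possibly inconsistent sources.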
|
@llvm/pr-subscribers-backend-spir-v
Author: Marcos Maronas (maarquitos14)
Changes
Implementation of SPV_KHR_float_controls2 extension, and corresponding tests. Some of the tests make use of
Patch is 66.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146941.diff 19 Files Affected:
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 1f563fbfb725a..3f6fa241da8ac 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -218,7 +218,7 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
* - ``SPV_INTEL_int4``
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
* - ``SPV_KHR_float_controls2``
- - Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
+ - Adds execution modes and decorations to control floating-point computations in both kernels and shaders. It can be used on whole modules and individual instructions.
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
@@ -585,3 +585,31 @@ Group and Subgroup Operations
For workgroup and subgroup operations, LLVM uses function calls to represent SPIR-V's
group-based instructions. These builtins facilitate group synchronization, data sharing,
and collective operations essential for efficient parallel computation.
+
+SPIR-V Instructions Mapped to LLVM Metadata
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Some SPIR-V instructions don't have a direct equivalent in the LLVM IR language. To
+address this, the SPIR-V Target uses different specific LLVM named metadata to convey
+the necessary information. The SPIR-V specification allows multiple module-scope
+instructions, where as LLVM named metadata must be unique. Therefore, the encoding of
+such instructions has the following format:
+
+.. code-block:: llvm
+
+ !spirv.<OpCodeName> = !{!<InstructionMetadata1>, !<InstructionMetadata2>, ..}
+ !<InstructionMetadata1> = !{<Operand1>, <Operand2>, ..}
+ !<InstructionMetadata2> = !{<Operand1>, <Operand2>, ..}
+
+Below, you will find the mappings between SPIR-V instruction and their corresponding
+LLVM IR representations.
+
++--------------------+---------------------------------------------------------+
+| SPIR-V instruction | LLVM IR |
++====================+=========================================================+
+| OpExecutionMode | .. code-block:: llvm |
+| | |
+| | !spirv.ExecutionMode = !{!0} |
+| | !0 = !{void @worker, i32 30, i32 262149} |
+| | ; Set execution mode with id 30 (VecTypeHint) and |
+| | ; literal `262149` operand. |
++--------------------+---------------------------------------------------------+
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 1ebfde2a603b9..24e4e390f98f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -80,6 +80,7 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
+ void outputFPFastMathDefaultInfo();
bool isHidden() {
return MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -497,11 +498,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+ // If SPV_KHR_float_controls2 is enabled and we find any of
+ // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
+ // modes, skip it, it'll be done somewhere else.
+ if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ const auto EM =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
+ ->getValue())
+ ->getZExtValue();
+ if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
+ EM == SPIRV::ExecutionMode::ContractionOff ||
+ EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
+ continue;
+ }
+
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
+ outputFPFastMathDefaultInfo();
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
@@ -551,12 +568,85 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
- MCInst Inst;
- Inst.setOpcode(SPIRV::OpExecutionMode);
- Inst.addOperand(MCOperand::createReg(FReg));
- unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
- Inst.addOperand(MCOperand::createImm(EM));
- outputMCInst(Inst);
+ if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+ // deprecated. We need to use FPFastMathDefault with the appropriate
+ // flags instead. Since FPFastMathDefault takes a target type, we need
+ // to emit it for each floating-point type to match the effect of
+ // ContractionOff. As of now, there are 4 FP types: fp16, fp32, fp64 and
+ // fp128.
+ constexpr size_t NumFPTypes = 4;
+ for (size_t i = 0; i < NumFPTypes; ++i) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM =
+ static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
+ Inst.addOperand(MCOperand::createImm(EM));
+
+ Type *TargetType = nullptr;
+ switch (i) {
+ case 0:
+ TargetType = Type::getHalfTy(M.getContext());
+ break;
+ case 1:
+ TargetType = Type::getFloatTy(M.getContext());
+ break;
+ case 2:
+ TargetType = Type::getDoubleTy(M.getContext());
+ break;
+ case 3:
+ TargetType = Type::getFP128Ty(M.getContext());
+ break;
+ }
+ assert(TargetType && "Invalid target type for FPFastMathDefault");
+
+ // Find the SPIRV type matching the target type. We'll go over all the
+ // TypeConstVars instructions in the SPIRV module and find the one
+ // that matches the target type. We know the target type is a
+ // floating-point type, so we can skip anything different than
+ // OpTypeFloat. Then, we need to check the bitwidth.
+ bool SPIRVTypeFound = false;
+ for (const MachineInstr *MI :
+ MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+ // Skip if the instruction is not OpTypeFloat.
+ if (MI->getOpcode() != SPIRV::OpTypeFloat)
+ continue;
+
+ // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which
+ // is the SPIRV type bit width.
+ if (TargetType->getScalarSizeInBits() != MI->getOperand(1).getImm())
+ continue;
+
+ SPIRVTypeFound = true;
+ const MachineFunction *MF = MI->getMF();
+ MCRegister TypeReg =
+ MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(TypeReg));
+ }
+
+ if (!SPIRVTypeFound) {
+ // The module does not contain this FP type, so we don't need to
+ // emit FPFastMathDefault for it.
+ continue;
+ }
+ // We only end up here because there is no "spirv.ExecutionMode"
+ // metadata, so that means no FPFastMathDefault. Therefore, we only
+ // need to make sure AllowContract is set to 0, as the rest of flags.
+ // We still need to emit the OpExecutionMode instruction, otherwise
+ // it's up to the client API to define the flags.
+ Inst.addOperand(MCOperand::createImm(SPIRV::FPFastMathMode::None));
+ outputMCInst(Inst);
+ }
+ } else {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM =
+ static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
+ Inst.addOperand(MCOperand::createImm(EM));
+ outputMCInst(Inst);
+ }
}
}
}
@@ -603,6 +693,80 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
}
}
+void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
+ for (const auto &[Func, FPFastMathDefaultInfoVec] :
+ MAI->FPFastMathDefaultInfoMap) {
+ for (const auto &FPFastMathDefaultInfo : FPFastMathDefaultInfoVec) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ MCRegister FuncReg = MAI->getFuncReg(Func);
+ assert(FuncReg.isValid());
+ Inst.addOperand(MCOperand::createReg(FuncReg));
+ Inst.addOperand(
+ MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
+
+ // Find the SPIRV type matching the target type. We'll go over all the
+ // TypeConstVars instructions in the SPIRV module and find the one that
+ // matches the target type. We know the target type is a floating-point
+ // type, so we can skip anything different than OpTypeFloat. Then, we
+ // need to check the bitwidth.
+ const Type *TargetTy = FPFastMathDefaultInfo.Ty;
+ assert(TargetTy && "Expected target type");
+ bool SPIRVTypeFound = false;
+ for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+ // Skip if the instruction is not OpTypeFloat.
+ if (MI->getOpcode() != SPIRV::OpTypeFloat)
+ continue;
+
+ // Skip if TargetTy bitwidth doesn't match MI->getOperand(1), which is
+ // the SPIRV type bit width.
+ if (TargetTy->getScalarSizeInBits() != MI->getOperand(1).getImm())
+ continue;
+
+ SPIRVTypeFound = true;
+ const MachineFunction *MF = MI->getMF();
+ MCRegister TypeReg =
+ MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(TypeReg));
+ }
+ if (!SPIRVTypeFound) {
+ // The module does not contain this FP type, so we don't need to emit
+ // FPFastMathDefault for it.
+ continue;
+ }
+
+ unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
+ if (FPFastMathDefaultInfo.ContractionOff &&
+ (Flags & SPIRV::FPFastMathMode::AllowContract) &&
+ FPFastMathDefaultInfo.FPFastMathDefault)
+ report_fatal_error(
+ "Conflicting FPFastMathFlags: ContractionOff and AllowContract");
+
+ if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+ !(Flags &
+ (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+ SPIRV::FPFastMathMode::NSZ))) {
+ if (FPFastMathDefaultInfo.FPFastMathDefault)
+ report_fatal_error("Conflicting FPFastMathFlags: "
+ "SignedZeroInfNanPreserve but at least one of "
+ "NotNaN/NotInf/NSZ is disabled.");
+
+ Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+ SPIRV::FPFastMathMode::NSZ;
+ }
+
+ // Don't emit if none of the execution modes was used.
+ if (Flags == SPIRV::FPFastMathMode::None &&
+ !FPFastMathDefaultInfo.ContractionOff &&
+ !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+ !FPFastMathDefaultInfo.FPFastMathDefault)
+ continue;
+ Inst.addOperand(MCOperand::createImm(Flags));
+ outputMCInst(Inst);
+ }
+ }
+}
+
void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
@@ -611,7 +775,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
MAI = &SPIRVModuleAnalysis::MAI;
assert(ST && TII && MAI && M && "Module analysis is required");
// Output instructions according to the Logical Layout of a Module:
- // 1,2. All OpCapability instructions, then optional OpExtension instructions.
+ // 1,2. All OpCapability instructions, then optional OpExtension
+ // instructions.
outputGlobalRequirements();
// 3. Optional OpExtInstImport instructions.
outputOpExtInstImports(*M);
@@ -619,7 +784,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
- // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
+ // 6. Execution-mode declarations, using OpExecutionMode or
+ // OpExecutionModeId.
outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6ec7544767c52..280a0197513c0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -697,7 +697,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
- return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
+ return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
+ Register(0));
Register ScopeRegister =
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -1125,11 +1126,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
static bool generateExtInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
- SPIRVGlobalRegistry *GR) {
+ SPIRVGlobalRegistry *GR, const CallBase &CB) {
// Lookup the extended instruction number in the TableGen records.
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
uint32_t Number =
SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
+ // fmin_common and fmax_common are now deprecated, and we should use fmin and
+ // fmax with NotInf and NotNaN flags instead. Keep original number to add
+ // later the NoNans and NoInfs flags.
+ uint32_t OrigNumber = Number;
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
+ (Number == SPIRV::OpenCLExtInst::fmin_common ||
+ Number == SPIRV::OpenCLExtInst::fmax_common)) {
+ Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
+ ? SPIRV::OpenCLExtInst::fmin
+ : SPIRV::OpenCLExtInst::fmax;
+ }
// Build extended instruction.
auto MIB =
@@ -1141,6 +1155,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
+ MIB.getInstr()->copyIRFlags(CB);
+ if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
+ OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
+ // Add NoNans and NoInfs flags to fmin/fmax instruction.
+ MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
+ MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
+ }
return true;
}
@@ -2844,7 +2865,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
- SPIRVGlobalRegistry *GR) {
+ SPIRVGlobalRegistry *GR, const CallBase &CB) {
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
// Lookup the builtin in the TableGen records.
@@ -2867,7 +2888,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
// Match the builtin with implementation based on the grouping.
switch (Call->Builtin->Group) {
case SPIRV::Extended:
- return generateExtInst(Call.get(), MIRBuilder, GR);
+ return generateExtInst(Call.get(), MIRBuilder, GR, CB);
case SPIRV::Relational:
return generateRelationalInst(Call.get(), MIRBuilder, GR);
case SPIRV::Group:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 1a8641a8328dd..f6a5234cd3c73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MachineIRBuilder &MIRBuilder,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
- SPIRVGlobalRegistry *GR);
+ SPIRVGlobalRegistry *GR, const CallBase &CB);
/// Helper function for finding a builtin function attributes
/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index a412887e51adb..1a7c02c676465 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
GR->getPointerSize()));
}
}
- if (auto Res =
- SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(),
- MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR))
+ if (auto Res = SPIRV::lowerBuiltin(
+ DemangledName, ST->getPreferredInstructionSet(), MIRBuilder,
+ ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB))
return *Res;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 83fccdc2bdba3..bc275b09674be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -794,7 +794,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
// arguments.
MDNode *GVarMD = nullptr;
if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)
- buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
+ buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD, ST);
return Reg;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index f658b67a4c2a5..357aab2f580c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -130,7 +130,8 @@ bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const {
}
}
-bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
+bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI,
+ bool KHRFloatControls2) const {
switch (MI.getOpcode()) {
case SPIRV::OpFAddS:
case SPIRV::OpFSubS:
@@ -144,6 +145,24 @@ bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const {
case SPIRV::OpFRemV:
case SPIRV::OpFMod:
return true;
+ case SPIRV::OpFNegateV:
+ case SPIRV::OpFNegate:
+ case SPIRV::OpOrdered:
+ case SPIRV::OpUnordered:
+ case SPIRV::OpFOrdEqual:
+ case SPIRV::OpFOrdNotEqual:
+ case SPIRV::OpFOrdLessThan:
+ case SPIRV::OpFOrdLessThanEqual:
+ case SPIRV::OpFOrdGreaterThan:
+ case SPIRV::OpFOrdGreaterThanEqual:
+ case SPIRV::OpFUnordEqual:
+ case SPIRV::OpFUnordNotEqual:
+ case SPIRV::OpFUnordLessThan:
+ case SPIRV::OpFUnordLessThanEqual:
+ case SPIRV::OpFUnordGreaterThan:
+ case SPIRV::OpFUnordGreaterThanEqual:
+ case SPIRV::OpExtInst:
+ return KHRFloatControls2 ? true : false;
default:
return false;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
in...
[truncated]
|
|
@Keenuts I have updated this PR to make |
| // first element is the smallest bit width, and the last element is the | ||
| // largest bit width, therefore, we will have {half, float, double} in | ||
| // the order of their bit widths. | ||
| DenseMap<const Function *, SmallVector<FPFastMathDefaultInfo, 3>> |
nit: I'd be in favor of having SmallVector<FPFastMathDefaultInfo, 3> changed into:
struct FPFastMathDefaultInfoVector : public SmallVector<FPFastMathDefaultInfo, 3> {
  static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth) {
    // the code from SPIRVUtils
  }
};
This way you keep the index magic numbers close to the object using them.
Also, it seems the FPFastMathDefaultInfo is reimplemented in the emit intrinsics, meaning this couldn't be reused there. But once this code duplication is removed, all SmallVector<Info, 3> bits can be replaced with this new class.
Added that to SPIRVUtils.
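To illustrate the shape of that suggestion outside LLVM, here is a self-contained sketch. It is not the code that landed in SPIRVUtils: `std::array` stands in for `SmallVector`, a `BitWidth` field stands in for the LLVM `Type` pointer, and `computeIndex` and `forWidth` are hypothetical names. The width-to-index mapping (16, 32, 64 to 0, 1, 2) follows the `{half, float, double}` ordering described in the quoted comment.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Simplified stand-in for the per-type record discussed in the review.
struct FPFastMathDefaultInfo {
  unsigned BitWidth = 0; // stand-in for the LLVM Type pointer
  uint32_t FastMathFlags = 0;
  bool ContractionOff = false;
  bool SignedZeroInfNanPreserve = false;
  bool FPFastMathDefault = false;
};

// One slot per supported FP type, ordered by bit width: half, float, double.
struct FPFastMathDefaultInfoVector {
  std::array<FPFastMathDefaultInfo, 3> Infos{};

  FPFastMathDefaultInfoVector() {
    Infos[0].BitWidth = 16;
    Infos[1].BitWidth = 32;
    Infos[2].BitWidth = 64;
  }

  // The index "magic numbers" live next to the container that relies on them.
  static std::size_t computeIndex(std::size_t BitWidth) {
    switch (BitWidth) {
    case 16:
      return 0;
    case 32:
      return 1;
    case 64:
      return 2;
    default:
      throw std::invalid_argument("unsupported FP bit width");
    }
  }

  // Convenience accessor so callers never index by hand.
  FPFastMathDefaultInfo &forWidth(std::size_t BitWidth) {
    return Infos[computeIndex(BitWidth)];
  }
};
```

Callers would then write something like `Vec.forWidth(32).FastMathFlags = Flags;` instead of repeating the index computation at each use site.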
| } | ||
| }; | ||
|
|
||
| struct FPFastMathDefaultInfo { |
This is an exact copy of the FPFastMathDefaultInfo struct in the ModuleAnalysis? Why not reuse it?
Moved all to SPIRVUtils.
| continue; | ||
|
|
||
| if (OpCode == SPIRV::OpConstantI) | ||
| ConstMap[MI->getOperand(2).getImm()] = MI; |
Ints could be signed, so it seems you could end up with a -1 being added as a max_uint or something like that, no?
Changed the key of the map from unsigned to int.
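The concern about signed constants can be reproduced in isolation. The sketch below assumes the immediate is a 64-bit signed value (as returned by `MachineOperand::getImm()`); `countDistinctKeys` is a hypothetical helper, and `int64_t` is used as the signed key type purely for illustration, so it is not necessarily the exact type chosen in the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical helper: inserts every immediate into a map keyed by KeyT and
// reports how many distinct keys survive the conversion to KeyT.
template <typename KeyT>
std::size_t countDistinctKeys(const std::vector<int64_t> &Imms) {
  std::map<KeyT, int> Map;
  for (int64_t Imm : Imms)
    Map[static_cast<KeyT>(Imm)] = 1; // a narrowing or sign change happens here
  return Map.size();
}
```

With a 32-bit unsigned key, the immediates `-1` and `4294967295` convert to the same key, so one constant silently overwrites the other; with a signed 64-bit key they stay distinct.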
|
@Keenuts friendly ping :) |
|
@Keenuts is on vacation for a while. You seem to have addressed his comment. You do not have to wait for him to merge. |
MrSidims
left a comment
Test adjustment in fe49f8b is LGTM
|
Test failure is unrelated, and being addressed by #158086, so I'm merging this. |
Implementation of [SPV_KHR_float_controls2](https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_float_controls2.html) extension, and corresponding tests.

Some of the tests make use of `!spirv.ExecutionMode` LLVM named metadata. This is because some SPIR-V instructions don't have a direct equivalent in LLVM IR, so the SPIR-V Target uses different LLVM named metadata to convey the necessary information. Below, you will find an example from one of the newly added tests:

```
!spirv.ExecutionMode = !{!19, !20, !21, !22, !23, !24, !25, !26, !27}
!19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079}
!20 = !{ptr @k_float_controls_all, i32 6028, float poison, i32 131079}
!21 = !{ptr @k_float_controls_float, i32 31}
!22 = !{ptr @k_float_controls_all, i32 31}
!23 = !{ptr @k_float_controls_float, i32 4461, i32 32}
!24 = !{ptr @k_float_controls_all, i32 4461, i32 16}
!25 = !{ptr @k_float_controls_all, i32 4461, i32 32}
!26 = !{ptr @k_float_controls_all, i32 4461, i32 64}
!27 = !{ptr @k_float_controls_all, i32 4461, i32 128}
```

`!spirv.ExecutionMode` contains a list of metadata nodes, and each of them specifies the required operands for expressing a particular `OpExecutionMode` instruction in SPIR-V. For example, `!19 = !{ptr @k_float_controls_float, i32 6028, float poison, i32 131079}` will be lowered to `OpExecutionMode [[k_float_controls_float_ID]] FPFastMathDefault [[float_type_ID]] 131079`.

---------

Co-authored-by: Dmitry Sidorov <[email protected]>
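The numeric operands in that example are easier to read once decoded. The following sketch maps the three execution-mode IDs used in the test to their SPIR-V names and expands a flags literal such as `131079` into individual `FPFastMathMode` bits. The numeric values come from the SPIR-V specification; `executionModeName` and `decodeFastMathFlags` are hypothetical helpers written for this illustration only.

```cpp
#include <cstdint>
#include <string>

// Names for the execution modes that appear in the example metadata.
std::string executionModeName(uint32_t Mode) {
  switch (Mode) {
  case 31:
    return "ContractionOff";
  case 4461:
    return "SignedZeroInfNanPreserve";
  case 6028:
    return "FPFastMathDefault";
  default:
    return "Unknown";
  }
}

// Expands an FPFastMathMode literal into a '|'-separated list of flag names.
std::string decodeFastMathFlags(uint32_t Flags) {
  static const struct {
    uint32_t Bit;
    const char *Name;
  } Table[] = {
      {0x1, "NotNaN"},           {0x2, "NotInf"},
      {0x4, "NSZ"},              {0x8, "AllowRecip"},
      {0x10, "Fast"},            {0x10000, "AllowContract"},
      {0x20000, "AllowReassoc"}, {0x40000, "AllowTransform"},
  };

  std::string Out;
  for (const auto &Entry : Table) {
    if (Flags & Entry.Bit) {
      if (!Out.empty())
        Out += "|";
      Out += Entry.Name;
    }
  }
  return Out.empty() ? "None" : Out;
}
```

For instance, `131079` is `0x20007`, so `decodeFastMathFlags(131079)` yields `NotNaN|NotInf|NSZ|AllowReassoc`, and node `!21` (`i32 31`) requests `ContractionOff` for `@k_float_controls_float`.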