Skip to content

Conversation

@heiher
Copy link
Member

@heiher heiher commented Sep 20, 2024

LoongArch currently lacks a hardware extension for the fp16 data type, and the ABI documentation does not explicitly define how to handle fp16. Future revsions of the LoongArch specification will include conventions to address fp16 requirements.

Previously, we maintained the 'half' type in its 16-bit format between operations. Even with the F/D ABI, the value would be passed in the lower 16 bits of a GPR in its 'half' format.

With this patch, depending on the ABI in use, the value will be passed either in an FPR or a GPR in 'half' format. This ensures consistency with the bits location when the fp16 hardware extension is enabled.

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-backend-loongarch

Author: hev (heiher)

Changes

LoongArch currently lacks a hardware extension for the fp16 data type, and the ABI documentation does not explicitly define how to handle fp16. Future revsions of the LoongArch specification will include conventions to address fp16 requirements.

Previously, we maintained the 'half' type in its 16-bit format between operations. Regardless of whether the F extension is enabled, the value would be passed in the lower 16 bits of a GPR in its 'half' format.

With this patch, depending on the ABI in use, the value will be passed either in an FPR or a GPR in 'half' format. This ensures consistency with the bits location when the fp16 hardware extension is enabled.


Patch is 62.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109368.diff

4 Files Affected:

  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+140-5)
  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.h (+24)
  • (added) llvm/test/CodeGen/LoongArch/calling-conv-half.ll (+924)
  • (modified) llvm/test/CodeGen/LoongArch/fp16-promote.ll (+126-66)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index bfafb331752108..b9dbb435f33c59 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
     setOperationAction(ISD::FPOW, MVT::f32, Expand);
     setOperationAction(ISD::FREM, MVT::f32, Expand);
-    setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
-    setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
+    setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
+    setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
 
     if (Subtarget.is64Bit())
       setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FPOW, MVT::f64, Expand);
     setOperationAction(ISD::FREM, MVT::f64, Expand);
     setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
-    setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
+    setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
 
     if (Subtarget.is64Bit())
       setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
     return lowerBUILD_VECTOR(Op, DAG);
   case ISD::VECTOR_SHUFFLE:
     return lowerVECTOR_SHUFFLE(Op, DAG);
+  case ISD::FP_TO_FP16:
+    return lowerFP_TO_FP16(Op, DAG);
+  case ISD::FP16_TO_FP:
+    return lowerFP16_TO_FP(Op, DAG);
   }
   return SDValue();
 }
@@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
   return SDValue();
 }
 
+SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  // Custom lower to ensure the libcall return is passed in an FPR on hard
+  // float ABIs.
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  SDValue Op0 = Op.getOperand(0);
+  SDValue Chain = SDValue();
+  RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
+  SDValue Res;
+  std::tie(Res, Chain) =
+      makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
+  if (Subtarget.is64Bit())
+    return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
+  return DAG.getBitcast(MVT::i32, Res);
+}
+
+SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  // Custom lower to ensure the libcall argument is passed in an FPR on hard
+  // float ABIs.
+  SDLoc DL(Op);
+  MakeLibCallOptions CallOptions;
+  SDValue Op0 = Op.getOperand(0);
+  SDValue Chain = SDValue();
+  SDValue Arg = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
+                                                  DL, MVT::f32, Op0)
+                                    : DAG.getBitcast(MVT::f32, Op0);
+  SDValue Res;
+  std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
+                                     CallOptions, DL, Chain);
+  return Res;
+}
+
 static bool isConstantOrUndef(const SDValue Op) {
   if (Op->isUndef())
     return true;
@@ -1656,16 +1694,20 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
                                                  SelectionDAG &DAG) const {
 
   SDLoc DL(Op);
+  SDValue Op0 = Op.getOperand(0);
+
+  if (Op0.getValueType() == MVT::f16)
+    Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
 
   if (Op.getValueSizeInBits() > 32 && Subtarget.hasBasicF() &&
       !Subtarget.hasBasicD()) {
     SDValue Dst =
-        DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op.getOperand(0));
+        DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op0);
     return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Dst);
   }
 
   EVT FPTy = EVT::getFloatingPointVT(Op.getValueSizeInBits());
-  SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op.getOperand(0));
+  SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op0);
   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Trunc);
 }
 
@@ -2848,6 +2890,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
     EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
     if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
         TargetLowering::TypeSoftenFloat) {
+      if (!isTypeLegal(Src.getValueType()))
+        return;
+      if (Src.getValueType() == MVT::f16)
+        Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
       SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
       Results.push_back(DAG.getNode(ISD::BITCAST, DL, VT, Dst));
       return;
@@ -4229,6 +4275,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const LoongArchSubtarget &Subtarget) {
+  // If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
+  // conversion is unnecessary and can be replaced with the
+  // MOVFR2GR_S_LA64 operand.
+  SDValue Op0 = N->getOperand(0);
+  if (Op0.getOpcode() == LoongArchISD::MOVFR2GR_S_LA64)
+    return Op0.getOperand(0);
+  return SDValue();
+}
+
+static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        const LoongArchSubtarget &Subtarget) {
+  // If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
+  // conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
+  // operand.
+  SDValue Op0 = N->getOperand(0);
+  MVT VT = N->getSimpleValueType(0);
+  if (Op0->getOpcode() == LoongArchISD::MOVGR2FR_W_LA64) {
+    assert(Op0.getOperand(0).getValueType() == VT && "Unexpected value type!");
+    return Op0.getOperand(0);
+  }
+  return SDValue();
+}
+
 SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
                                                    DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -4247,6 +4320,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
     return performBITREV_WCombine(N, DAG, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
     return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
+  case LoongArchISD::MOVGR2FR_W_LA64:
+    return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
+  case LoongArchISD::MOVFR2GR_S_LA64:
+    return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
   }
   return SDValue();
 }
@@ -6260,3 +6337,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,
 
   return true;
 }
+
+bool LoongArchTargetLowering::splitValueIntoRegisterParts(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
+    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
+  bool IsABIRegCopy = CC.has_value();
+  EVT ValueVT = Val.getValueType();
+
+  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+    // Cast the f16 to i16, extend to i32, pad with ones to make a float
+    // nan, and cast to f32.
+    Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
+    Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
+    Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
+                      DAG.getConstant(0xFFFF0000, DL, MVT::i32));
+    Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
+    Parts[0] = Val;
+    return true;
+  }
+
+  return false;
+}
+
+SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
+    SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
+    MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
+  bool IsABIRegCopy = CC.has_value();
+
+  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+    SDValue Val = Parts[0];
+
+    // Cast the f32 to i32, truncate to i16, and cast back to f16.
+    Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
+    Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
+    Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+    return Val;
+  }
+
+  return SDValue();
+}
+
+MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+                                                           CallingConv::ID CC,
+                                                           EVT VT) const {
+  // Use f32 to pass f16.
+  if (VT == MVT::f16 && Subtarget.hasBasicF())
+    return MVT::f32;
+
+  return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
+}
+
+unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
+    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
+  // Use f32 to pass f16.
+  if (VT == MVT::f16 && Subtarget.hasBasicF())
+    return 1;
+
+  return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
+}
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 6177884bd19501..5636f0d8b3d601 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -315,6 +315,8 @@ class LoongArchTargetLowering : public TargetLowering {
   SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
 
   bool isFPImmLegal(const APFloat &Imm, EVT VT,
                     bool ForCodeSize) const override;
@@ -339,6 +341,28 @@ class LoongArchTargetLowering : public TargetLowering {
       const SmallVectorImpl<CCValAssign> &ArgLocs) const;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool
+  splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
+                              SDValue *Parts, unsigned NumParts, MVT PartVT,
+                              std::optional<CallingConv::ID> CC) const override;
+
+  SDValue
+  joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
+                             const SDValue *Parts, unsigned NumParts,
+                             MVT PartVT, EVT ValueVT,
+                             std::optional<CallingConv::ID> CC) const override;
+
+  /// Return the register type for a given MVT, ensuring vectors are treated
+  /// as a series of gpr sized integers.
+  MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
+                                    EVT VT) const override;
+
+  /// Return the number of registers for a given MVT, ensuring vectors are
+  /// treated as a series of gpr sized integers.
+  unsigned getNumRegistersForCallingConv(LLVMContext &Context,
+                                         CallingConv::ID CC,
+                                         EVT VT) const override;
 };
 
 } // end namespace llvm
diff --git a/llvm/test/CodeGen/LoongArch/calling-conv-half.ll b/llvm/test/CodeGen/LoongArch/calling-conv-half.ll
new file mode 100644
index 00000000000000..1f825fe5b62200
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/calling-conv-half.ll
@@ -0,0 +1,924 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc --mtriple=loongarch32 --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA32S
+; RUN: llc --mtriple=loongarch32 --mattr=+f --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA32F
+; RUN: llc --mtriple=loongarch32 --mattr=+d --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA32D
+; RUN: llc --mtriple=loongarch64 --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA64S
+; RUN: llc --mtriple=loongarch64 --mattr=+f --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA64F
+; RUN: llc --mtriple=loongarch64 --mattr=+d --verify-machineinstrs < %s | FileCheck %s --check-prefix=LA64D
+
+define i32 @callee_half_in_fregs(i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g, i32 %h, half %i) nounwind {
+; LA32S-LABEL: callee_half_in_fregs:
+; LA32S:       # %bb.0:
+; LA32S-NEXT:    addi.w $sp, $sp, -16
+; LA32S-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32S-NEXT:    st.w $fp, $sp, 8 # 4-byte Folded Spill
+; LA32S-NEXT:    ld.hu $a1, $sp, 16
+; LA32S-NEXT:    move $fp, $a0
+; LA32S-NEXT:    move $a0, $a1
+; LA32S-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA32S-NEXT:    bl %plt(__fixsfsi)
+; LA32S-NEXT:    add.w $a0, $fp, $a0
+; LA32S-NEXT:    ld.w $fp, $sp, 8 # 4-byte Folded Reload
+; LA32S-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32S-NEXT:    addi.w $sp, $sp, 16
+; LA32S-NEXT:    ret
+;
+; LA32F-LABEL: callee_half_in_fregs:
+; LA32F:       # %bb.0:
+; LA32F-NEXT:    addi.w $sp, $sp, -16
+; LA32F-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32F-NEXT:    st.w $fp, $sp, 8 # 4-byte Folded Spill
+; LA32F-NEXT:    move $fp, $a0
+; LA32F-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA32F-NEXT:    ftintrz.w.s $fa0, $fa0
+; LA32F-NEXT:    movfr2gr.s $a0, $fa0
+; LA32F-NEXT:    add.w $a0, $fp, $a0
+; LA32F-NEXT:    ld.w $fp, $sp, 8 # 4-byte Folded Reload
+; LA32F-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32F-NEXT:    addi.w $sp, $sp, 16
+; LA32F-NEXT:    ret
+;
+; LA32D-LABEL: callee_half_in_fregs:
+; LA32D:       # %bb.0:
+; LA32D-NEXT:    addi.w $sp, $sp, -16
+; LA32D-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32D-NEXT:    st.w $fp, $sp, 8 # 4-byte Folded Spill
+; LA32D-NEXT:    move $fp, $a0
+; LA32D-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA32D-NEXT:    ftintrz.w.s $fa0, $fa0
+; LA32D-NEXT:    movfr2gr.s $a0, $fa0
+; LA32D-NEXT:    add.w $a0, $fp, $a0
+; LA32D-NEXT:    ld.w $fp, $sp, 8 # 4-byte Folded Reload
+; LA32D-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32D-NEXT:    addi.w $sp, $sp, 16
+; LA32D-NEXT:    ret
+;
+; LA64S-LABEL: callee_half_in_fregs:
+; LA64S:       # %bb.0:
+; LA64S-NEXT:    addi.d $sp, $sp, -16
+; LA64S-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64S-NEXT:    st.d $fp, $sp, 0 # 8-byte Folded Spill
+; LA64S-NEXT:    ld.hu $a1, $sp, 16
+; LA64S-NEXT:    move $fp, $a0
+; LA64S-NEXT:    move $a0, $a1
+; LA64S-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA64S-NEXT:    bl %plt(__fixsfdi)
+; LA64S-NEXT:    add.w $a0, $fp, $a0
+; LA64S-NEXT:    ld.d $fp, $sp, 0 # 8-byte Folded Reload
+; LA64S-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64S-NEXT:    addi.d $sp, $sp, 16
+; LA64S-NEXT:    ret
+;
+; LA64F-LABEL: callee_half_in_fregs:
+; LA64F:       # %bb.0:
+; LA64F-NEXT:    addi.d $sp, $sp, -16
+; LA64F-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64F-NEXT:    st.d $fp, $sp, 0 # 8-byte Folded Spill
+; LA64F-NEXT:    move $fp, $a0
+; LA64F-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA64F-NEXT:    ftintrz.w.s $fa0, $fa0
+; LA64F-NEXT:    movfr2gr.s $a0, $fa0
+; LA64F-NEXT:    add.w $a0, $fp, $a0
+; LA64F-NEXT:    ld.d $fp, $sp, 0 # 8-byte Folded Reload
+; LA64F-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64F-NEXT:    addi.d $sp, $sp, 16
+; LA64F-NEXT:    ret
+;
+; LA64D-LABEL: callee_half_in_fregs:
+; LA64D:       # %bb.0:
+; LA64D-NEXT:    addi.d $sp, $sp, -16
+; LA64D-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64D-NEXT:    st.d $fp, $sp, 0 # 8-byte Folded Spill
+; LA64D-NEXT:    move $fp, $a0
+; LA64D-NEXT:    bl %plt(__gnu_h2f_ieee)
+; LA64D-NEXT:    ftintrz.l.s $fa0, $fa0
+; LA64D-NEXT:    movfr2gr.d $a0, $fa0
+; LA64D-NEXT:    add.w $a0, $fp, $a0
+; LA64D-NEXT:    ld.d $fp, $sp, 0 # 8-byte Folded Reload
+; LA64D-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64D-NEXT:    addi.d $sp, $sp, 16
+; LA64D-NEXT:    ret
+  %1 = fptosi half %i to i32
+  %2 = add i32 %a, %1
+  ret i32 %2
+}
+
+define i32 @caller_half_in_fregs() nounwind {
+; LA32S-LABEL: caller_half_in_fregs:
+; LA32S:       # %bb.0:
+; LA32S-NEXT:    addi.w $sp, $sp, -16
+; LA32S-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32S-NEXT:    lu12i.w $t0, 4
+; LA32S-NEXT:    ori $a0, $zero, 1
+; LA32S-NEXT:    ori $a1, $zero, 2
+; LA32S-NEXT:    ori $a2, $zero, 3
+; LA32S-NEXT:    ori $a3, $zero, 4
+; LA32S-NEXT:    ori $a4, $zero, 5
+; LA32S-NEXT:    ori $a5, $zero, 6
+; LA32S-NEXT:    ori $a6, $zero, 7
+; LA32S-NEXT:    ori $a7, $zero, 8
+; LA32S-NEXT:    st.w $t0, $sp, 0
+; LA32S-NEXT:    bl %plt(callee_half_in_fregs)
+; LA32S-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32S-NEXT:    addi.w $sp, $sp, 16
+; LA32S-NEXT:    ret
+;
+; LA32F-LABEL: caller_half_in_fregs:
+; LA32F:       # %bb.0:
+; LA32F-NEXT:    addi.w $sp, $sp, -16
+; LA32F-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32F-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI1_0)
+; LA32F-NEXT:    fld.s $fa0, $a0, %pc_lo12(.LCPI1_0)
+; LA32F-NEXT:    ori $a0, $zero, 1
+; LA32F-NEXT:    ori $a1, $zero, 2
+; LA32F-NEXT:    ori $a2, $zero, 3
+; LA32F-NEXT:    ori $a3, $zero, 4
+; LA32F-NEXT:    ori $a4, $zero, 5
+; LA32F-NEXT:    ori $a5, $zero, 6
+; LA32F-NEXT:    ori $a6, $zero, 7
+; LA32F-NEXT:    ori $a7, $zero, 8
+; LA32F-NEXT:    bl %plt(callee_half_in_fregs)
+; LA32F-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32F-NEXT:    addi.w $sp, $sp, 16
+; LA32F-NEXT:    ret
+;
+; LA32D-LABEL: caller_half_in_fregs:
+; LA32D:       # %bb.0:
+; LA32D-NEXT:    addi.w $sp, $sp, -16
+; LA32D-NEXT:    st.w $ra, $sp, 12 # 4-byte Folded Spill
+; LA32D-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI1_0)
+; LA32D-NEXT:    fld.s $fa0, $a0, %pc_lo12(.LCPI1_0)
+; LA32D-NEXT:    ori $a0, $zero, 1
+; LA32D-NEXT:    ori $a1, $zero, 2
+; LA32D-NEXT:    ori $a2, $zero, 3
+; LA32D-NEXT:    ori $a3, $zero, 4
+; LA32D-NEXT:    ori $a4, $zero, 5
+; LA32D-NEXT:    ori $a5, $zero, 6
+; LA32D-NEXT:    ori $a6, $zero, 7
+; LA32D-NEXT:    ori $a7, $zero, 8
+; LA32D-NEXT:    bl %plt(callee_half_in_fregs)
+; LA32D-NEXT:    ld.w $ra, $sp, 12 # 4-byte Folded Reload
+; LA32D-NEXT:    addi.w $sp, $sp, 16
+; LA32D-NEXT:    ret
+;
+; LA64S-LABEL: caller_half_in_fregs:
+; LA64S:       # %bb.0:
+; LA64S-NEXT:    addi.d $sp, $sp, -16
+; LA64S-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64S-NEXT:    lu12i.w $t0, 4
+; LA64S-NEXT:    ori $a0, $zero, 1
+; LA64S-NEXT:    ori $a1, $zero, 2
+; LA64S-NEXT:    ori $a2, $zero, 3
+; LA64S-NEXT:    ori $a3, $zero, 4
+; LA64S-NEXT:    ori $a4, $zero, 5
+; LA64S-NEXT:    ori $a5, $zero, 6
+; LA64S-NEXT:    ori $a6, $zero, 7
+; LA64S-NEXT:    ori $a7, $zero, 8
+; LA64S-NEXT:    st.d $t0, $sp, 0
+; LA64S-NEXT:    bl %plt(callee_half_in_fregs)
+; LA64S-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64S-NEXT:    addi.d $sp, $sp, 16
+; LA64S-NEXT:    ret
+;
+; LA64F-LABEL: caller_half_in_fregs:
+; LA64F:       # %bb.0:
+; LA64F-NEXT:    addi.d $sp, $sp, -16
+; LA64F-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64F-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI1_0)
+; LA64F-NEXT:    fld.s $fa0, $a0, %pc_lo12(.LCPI1_0)
+; LA64F-NEXT:    ori $a0, $zero, 1
+; LA64F-NEXT:    ori $a1, $zero, 2
+; LA64F-NEXT:    ori $a2, $zero, 3
+; LA64F-NEXT:    ori $a3, $zero, 4
+; LA64F-NEXT:    ori $a4, $zero, 5
+; LA64F-NEXT:    ori $a5, $zero, 6
+; LA64F-NEXT:    ori $a6, $zero, 7
+; LA64F-NEXT:    ori $a7, $zero, 8
+; LA64F-NEXT:    bl %plt(callee_half_in_fregs)
+; LA64F-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64F-NEXT:    addi.d $sp, $sp, 16
+; LA64F-NEXT:    ret
+;
+; LA64D-LABEL: caller_half_in_fregs:
+; LA64D:       # %bb.0:
+; LA64D-NEXT:    addi.d $sp, $sp, -16
+; LA64D-NEXT:    st.d $ra, $sp, 8 # 8-byte Folded Spill
+; LA64D-NEXT:    pcalau12i $a0, %pc_hi20(.LCPI1_0)
+; LA64D-NEXT:    fld.s $fa0, $a0, %pc_lo12(.LCPI1_0)
+; LA64D-NEXT:    ori $a0, $zero, 1
+; LA64D-NEXT:    ori $a1, $zero, 2
+; LA64D-NEXT:    ori $a2, $zero, 3
+; LA64D-NEXT:    ori $a3, $zero, 4
+; LA64D-NEXT:    ori $a4, $zero, 5
+; LA64D-NEXT:    ori $a5, $zero, 6
+; LA64D-NEXT:    ori $a6, $zero, 7
+; LA64D-NEXT:    ori $a7, $zero, 8
+; LA64D-NEXT:    bl %plt(callee_half_in_fregs)
+; LA64D-NEXT:    ld.d $ra, $sp, 8 # 8-byte Folded Reload
+; LA64D-NEXT:    addi.d $sp, $sp, 16
+; LA64D-NEXT:    ret
+  %1 = call i32 @callee_half_in_fregs(i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, half 2.0)
+  ret i32 %1
+}
+
+define i32 @callee_half_in_gregs(half %a, half %b, half %c, half %d, half %...
[truncated]

@github-actions
Copy link

github-actions bot commented Sep 20, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

…xtension is enabled

LoongArch currently lacks a hardware extension for the fp16 data type, and the
ABI documentation does not explicitly define how to handle fp16. Future revsions
of the LoongArch specification will include conventions to address fp16 requirements.

Previously, we maintained the 'half' type in its 16-bit format between operations.
Regardless of whether the F extension is enabled, the value would be passed in the
lower 16 bits of a GPR in its 'half' format.

With this patch, depending on the ABI in use, the value will be passed either in
an FPR or a GPR in 'half' format. This ensures consistency with the bits location
when the fp16 hardware extension is enabled.
@heiher
Copy link
Member Author

heiher commented Sep 20, 2024

cc @xen0n @xry111

Comment on lines +1380 to +1381
// Custom lower to ensure the libcall argument is passed in an FPR on hard
// float ABIs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need to custom lower the casts to change the ABI. Are you trying to special case the ABI for this one call in this one instance? That seems bad

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not trying to special case the ABI, but rather ensuring compliance with the ABI rules for floating-point operations. Specifically, the argument for f32 __gnu_h2f_ieee(f16) needs to be passed via the FPR, as per the floating-point ABI, rather than the GPR. Custom lowering ensures that the argument is correctly passed through the FPR in cases where the default behavior doesn't align with this requirement. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change does not accomplish this. The cast opcodes have nothing to do with the ABI, other than ABI code may result in inserting them

Copy link
Member Author

@heiher heiher Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default behavior of softPromoteHalf does not align with the expectations of the architecture, the custom lowering as referenced by the approach used in RISC-V.

Ref: https://reviews.llvm.org/D151284

This legalisation produces ISD::FP_TO_FP16 and ISD::FP16_TO_FP nodes which (as described in ISDOpcodes.h) provide a "semi-softened interface for dealing with f16 (as an i16)". i.e. the return type of the FP_TO_FP16 is an integer rather than a float (and the arg of FP16_TO_FP is an integer). The remainder of the description focuses primarily on FP_TO_FP16 for ease of explanation.

In Rust's implementation, the argument of __gnu_h2f_ieee is f16, not i16.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

Is there a better approach to achieve this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words, you are special casing the ABI for this one libcall that happens to be used for legalization of the conversions. If you are fixing the ABI, as the title suggests, you need to make changes to the calling convention lowering code and possibly use addRegisterClass for f16.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default the half is going to be be pre-promoted to a legal f32 type and you need to intervene before that. I would start by overriding getRegisterTypeForCallingConv, and then see how that goes. You may need to just custom hack on the argument lists in each of the Lower* functions.

Yes SelectionDAG makes this more difficult than it should be. It would be much easier if calling convention code operated on the raw IR signature instead of going through type legalization first

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your guidance. I haven't quite managed to get it done yet. To pass fp16 in FPR (excluding libcall), splitValueIntoRegisterParts and joinRegisterPartsIntoValue insert all ones in the upper bits and extract the lower bits by casting to an integer type. I guess this is the key point where the fp16 value is promoted to an integer when it reaches the FP16_TO_FP operation, but I'm not sure how to bypass the integer type to achieve this.

Additionally, it seems that custom-lowering FP16_TO_FP and FP_TO_FP16 to generate a libcall while keeping it passed in FPR works quite well and is fairly easy to implement, RISC-V is already using this approach. Can we go ahead with this? Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really correct. It only happens to work out, and MOVFR2GR_S_LA64 is a hack

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust's implementation, the argument of __gnu_h2f_ieee is f16, not i16.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

Just to address this (quite late) - aiui __gnu_h2f_ieee is only called on platforms where f16 is passed as an integer so that doesn't matter here. If there is a float ABI then __extendhfsf2/__truncsfhf2 is used instead.

There is something weird with the return value though, Rust's compiler-builtins and LLVM's compiler-rt both use f32 but GCC uses int. (I think maybe GCC just never emits that libcall except on ARM, LLVM seems to use it a lot more).

Copy link
Member Author

@heiher heiher Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my view, __gnu_h2f_ieee is primarily designed for converting f16 values to f32, particularly in scenarios where hardware lacks native f16 support, requiring software emulation instead. This function does not inherently define any specific argument-passing conventions; rather, those are determined by the architecture's ABI. For instance, on RISC-V, the lower 16 bits of a floating point argument register are used (given that hardware support for f16 is not enabled), while some other architectures use integer registers. In terms of implementation, __extendhfsf2 serves as an alias for __gnu_h2f_ieee and are used by arm.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L87-L89

pub extern "C" fn __extendhfsf2(a: f16) -> f32 {
    extend(a)
}

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

https://github.com/llvm/llvm-project/blob/llvmorg-19.1.2/compiler-rt/lib/builtins/extendhfsf2.c#L13-L19

// Use a forwarding definition and noinline to implement a poor man's alias,
// as there isn't a good cross-platform way of defining one.
COMPILER_RT_ABI NOINLINE float __extendhfsf2(src_t a) {
  return __extendXfYf2__(a);
}

COMPILER_RT_ABI float __gnu_h2f_ieee(src_t a) { return __extendhfsf2(a); }

@xry111
Copy link
Contributor

xry111 commented Sep 20, 2024

We don't have an "F extension." That's a concept of RISC-V and we call it "basic floating-point instructions."

And the ABI isn't necessarily aligning to the ISA capability (i.e. you can still use LP64S ABI on a CPU with floating-point instructions) so it's better to re-title this Pass 'half' in the lower 16 bits of an f32 value with F/D ABI.

@heiher heiher changed the title [LoongArch] Pass 'half' in the lower 16 bits of an f32 value when F extension is enabled [LoongArch] Pass 'half' in the lower 16 bits of an f32 value with F/D ABI Sep 20, 2024
@nikic
Copy link
Contributor

nikic commented Sep 20, 2024

Are you aware of #97981? This was recently fixed by #107791 and it sounds like you want to un-fix it? It's my understanding that for half it's preferred to use GPR rather than FPR for the call ABI, if you do not have ABI requirements to the contrary.

Note that there is also the useFPRegsForHalfType() hook.

@heiher
Copy link
Member Author

heiher commented Sep 20, 2024

Are you aware of #97981? This was recently fixed by #107791 and it sounds like you want to un-fix it? It's my understanding that for half it's preferred to use GPR rather than FPR for the call ABI, if you do not have ABI requirements to the contrary.

Note that there is also the useFPRegsForHalfType() hook.

This doesn't undo the fix. The motivation is to prepare for future LoongArch FPU support for fp16 (similar to RISC-V's Zfh extension). If we pass fp16 via GPR now, but future hardware uses the lower 16 bits of FPR for fp16, we'd introduce unnecessary data transfers between GPR and FPR. Passing software fp16 via FPR now aligns the ABI with future hardware fp16, avoiding issues similar to -mfloat-abi=softfp. Given software fp16's low performance, I'd recommend prioritizing future hardware fp16 support. Thanks.

@heiher
Copy link
Member Author

heiher commented Sep 20, 2024

Note that there is also the useFPRegsForHalfType() hook.

I'll try it, Thanks.

Update: It doesn't work using FPR for libcall of FP16_TO_FP/FP_TO_FP16.

@heiher
Copy link
Member Author

heiher commented Oct 19, 2024

Update: The design team behind the LoongArch ISA currently has no plans to add fp16 hardware extensions. Instead, they are leaning towards leveraging GPU and other hardware acceleration modules for fp16 computations. As a result, software calling conventions will continue to pass fp16 data using general-purpose registers.

I'll be reopening PR #109093 to ensure that fp16 support is functional on the release/19.x.

Thanks, everyone!

@heiher heiher closed this Oct 19, 2024
@heiher heiher deleted the fp16 branch October 19, 2024 13:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants