@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711
711
setOperationAction (ISD::BITCAST, MVT::f32, Custom);
712
712
}
713
713
714
+ // Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715
+ // stores in GPRs.
716
+ setOperationAction (ISD::FP16_TO_FP, MVT::f32, Expand);
717
+ setOperationAction (ISD::FP_TO_FP16, MVT::f32, Expand);
718
+ setLoadExtAction (ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
719
+ setTruncStoreAction (MVT::f32, MVT::f16, Expand);
720
+
714
721
// VASTART and VACOPY need to deal with the SystemZ-specific varargs
715
722
// structure, but VAEND is a no-op.
716
723
setOperationAction (ISD::VASTART, MVT::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784
791
return Subtarget.hasSoftFloat ();
785
792
}
786
793
794
+ MVT SystemZTargetLowering::getRegisterTypeForCallingConv (
795
+ LLVMContext &Context, CallingConv::ID CC,
796
+ EVT VT) const {
797
+ // 128-bit single-element vector types are passed like other vectors,
798
+ // not like their element type.
799
+ if (VT.isVector () && VT.getSizeInBits () == 128 &&
800
+ VT.getVectorNumElements () == 1 )
801
+ return MVT::v16i8;
802
+ // Keep f16 so that they can be recognized and handled.
803
+ if (VT == MVT::f16)
804
+ return MVT::f16;
805
+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
806
+ }
807
+
787
808
EVT SystemZTargetLowering::getSetCCResultType (const DataLayout &DL,
788
809
LLVMContext &, EVT VT) const {
789
810
if (!VT.isVector ())
@@ -1597,6 +1618,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
1597
1618
return true ;
1598
1619
}
1599
1620
1621
+ // Convert f16 to f32 (Out-arg).
1622
+ if (PartVT == MVT::f16) {
1623
+ assert (NumParts == 1 && " " );
1624
+ SDValue I16Val = DAG.getBitcast (MVT::i16, Val);
1625
+ SDValue I32Val = DAG.getAnyExtOrTrunc (I16Val, DL, MVT::i32);
1626
+ Parts[0 ] = DAG.getBitcast (MVT::f32, I32Val);
1627
+ return true ;
1628
+ }
1629
+
1600
1630
return false ;
1601
1631
}
1602
1632
@@ -1612,6 +1642,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
1612
1642
return SDValue ();
1613
1643
}
1614
1644
1645
+ // F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1646
+ // CopyFromReg was made into an f32 as required as FP32 registers are used
1647
+ // for arguments, now convert it to f16.
1648
+ static SDValue convertF32ToF16 (SDValue F32Val, SelectionDAG &DAG,
1649
+ const SDLoc &DL) {
1650
+ assert (F32Val->getOpcode () == ISD::CopyFromReg &&
1651
+ " Only expecting to handle f16 with CopyFromReg here." );
1652
+ SDValue I32Val = DAG.getBitcast (MVT::i32, F32Val);
1653
+ SDValue I16Val = DAG.getAnyExtOrTrunc (I32Val, DL, MVT::i16);
1654
+ return DAG.getBitcast (MVT::f16, I16Val);
1655
+ }
1656
+
1615
1657
SDValue SystemZTargetLowering::LowerFormalArguments (
1616
1658
SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
1617
1659
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
@@ -1651,6 +1693,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1651
1693
NumFixedGPRs += 1 ;
1652
1694
RC = &SystemZ::GR64BitRegClass;
1653
1695
break ;
1696
+ case MVT::f16:
1654
1697
case MVT::f32:
1655
1698
NumFixedFPRs += 1 ;
1656
1699
RC = &SystemZ::FP32BitRegClass;
@@ -1675,7 +1718,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1675
1718
1676
1719
Register VReg = MRI.createVirtualRegister (RC);
1677
1720
MRI.addLiveIn (VA.getLocReg (), VReg);
1678
- ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, LocVT);
1721
+ // An f16 argument is assigned an FP32 register: copy the register out as
+ // f32 and then convert the copied value to f16.
1722
+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
1723
+ ArgValue = DAG.getCopyFromReg (Chain, DL, VReg, ArgVT);
1724
+ if (VA.getLocVT () == MVT::f16)
1725
+ ArgValue = convertF32ToF16 (ArgValue, DAG, DL);
1679
1726
} else {
1680
1727
assert (VA.isMemLoc () && " Argument not register or memory" );
1681
1728
@@ -1695,9 +1742,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
1695
1742
// from this parameter. Unpromoted ints and floats are
1696
1743
// passed as right-justified 8-byte values.
1697
1744
SDValue FIN = DAG.getFrameIndex (FI, PtrVT);
1698
- if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32)
1745
+ if (VA.getLocVT () == MVT::i32 || VA.getLocVT () == MVT::f32 ||
1746
+ VA.getLocVT () == MVT::f16) {
1747
+ unsigned SlotOffs = VA.getLocVT () == MVT::f16 ? 6 : 4 ;
1699
1748
FIN = DAG.getNode (ISD::ADD, DL, PtrVT, FIN,
1700
- DAG.getIntPtrConstant (4 , DL));
1749
+ DAG.getIntPtrConstant (SlotOffs, DL));
1750
+ }
1701
1751
ArgValue = DAG.getLoad (LocVT, DL, Chain, FIN,
1702
1752
MachinePointerInfo::getFixedStack (MF, FI));
1703
1753
}
@@ -2120,10 +2170,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
2120
2170
// Copy all of the result registers out of their specified physreg.
2121
2171
for (CCValAssign &VA : RetLocs) {
2122
2172
// Copy the value out, gluing the copy to the end of the call sequence.
2173
+ // An f16 result is copied out of its register as f32 and then converted
+ // to f16.
2174
+ MVT ArgVT = VA.getLocVT () == MVT::f16 ? MVT::f32 : VA.getLocVT ();
2123
2175
SDValue RetValue = DAG.getCopyFromReg (Chain, DL, VA.getLocReg (),
2124
- VA. getLocVT () , Glue);
2176
+ ArgVT , Glue);
2125
2177
Chain = RetValue.getValue (1 );
2126
2178
Glue = RetValue.getValue (2 );
2179
+ if (VA.getLocVT () == MVT::f16)
2180
+ RetValue = convertF32ToF16 (RetValue, DAG, DL);
2127
2181
2128
2182
// Convert the value of the return register into the value that's
2129
2183
// being returned.
0 commit comments