Skip to content

Commit 93915f8

Browse files
author
KristofferC
committed
Emit aliases to FP16 conversion routines
1 parent 21a2c43 commit 93915f8

File tree

9 files changed

+98
-36
lines changed

9 files changed

+98
-36
lines changed

src/APInt-C.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
316316
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
317317
double Val;
318318
if (numbits == 16)
319-
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
319+
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
320320
else if (numbits == 32)
321321
Val = *(float*)pa;
322322
else if (numbits == 64)
@@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
391391
val = a.roundToDouble(true);
392392
}
393393
if (onumbits == 16)
394-
*(uint16_t*)pr = __gnu_f2h_ieee(val);
394+
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
395395
else if (onumbits == 32)
396396
*(float*)pr = val;
397397
else if (onumbits == 64)
@@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
408408
val = a.roundToDouble(false);
409409
}
410410
if (onumbits == 16)
411-
*(uint16_t*)pr = __gnu_f2h_ieee(val);
411+
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
412412
else if (onumbits == 32)
413413
*(float*)pr = val;
414414
else if (onumbits == 64)

src/aotcompile.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include <llvm/Support/CodeGen.h>
5252
#endif
5353

54+
#include <llvm/IR/IRBuilder.h>
5455
#include <llvm/IR/LegacyPassManagers.h>
5556
#include <llvm/Transforms/Utils/Cloning.h>
5657

@@ -276,6 +277,24 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
276277
*ci_out = codeinst;
277278
}
278279

280+
// Defines `name` in module M as a weak function that forwards to `alias`.
// A function called `alias` is looked up in M and, if absent, declared with
// external linkage and type FT. `name` is then created with the same type and
// given a single basic block that passes its own arguments on to that target
// and returns the value of the call.
// NOTE(review): the unconditional CreateRet(val) presumes FT has a non-void
// return type, as every call site visible here does -- confirm before reuse.
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
281+
{
282+
Function *target = M.getFunction(alias);
283+
if (!target) {
284+
// M has no function named `alias` yet, so add one for the call below.
target = Function::Create(FT, Function::ExternalLinkage, alias, M);
285+
}
286+
// Weak so that this does not get discarded
287+
// maybe use llvm.compiler.used instead?
288+
Function *interposer = Function::Create(FT, Function::WeakAnyLinkage, name, M);
289+
290+
// Body of the interposer: one block ("top") that calls the target with all
// of the interposer's own arguments and returns the result.
llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", interposer));
291+
SmallVector<Value *, 4> CallArgs;
292+
for (auto &arg : interposer->args())
293+
CallArgs.push_back(&arg);
294+
auto val = builder.CreateCall(target, CallArgs);
295+
builder.CreateRet(val);
296+
}
297+
279298
// takes the running content that has collected in the shadow module and dump it to disk
280299
// this builds the object file portion of the sysimage files for fast startup, and can
281300
// also be used by external consumers like GPUCompiler.jl to obtain a module containing
@@ -554,6 +573,20 @@ void jl_dump_native(void *native_code,
554573
"jl_RTLD_DEFAULT_handle_pointer"));
555574
}
556575

576+
// We would like to emit an alias or a weakref alias to redirect these symbols
577+
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
578+
// So for now we inject a definition of these functions that calls our runtime functions.
579+
injectCRTAlias(*data->M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee",
580+
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
581+
injectCRTAlias(*data->M, "__extendhfsf2", "julia__gnu_h2f_ieee",
582+
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
583+
injectCRTAlias(*data->M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee",
584+
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
585+
injectCRTAlias(*data->M, "__truncsfhf2", "julia__gnu_f2h_ieee",
586+
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
587+
injectCRTAlias(*data->M, "__truncdfhf2", "julia__truncdfhf2",
588+
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));
589+
557590
// do the actual work
558591
auto add_output = [&] (Module &M, StringRef unopt_bc_Name, StringRef bc_Name, StringRef obj_Name, StringRef asm_Name) {
559592
PM.run(M);

src/intrinsics.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,22 +1476,17 @@ static inline uint16_t float_to_half(float param)
14761476

14771477
#if !defined(_OS_DARWIN_) // xcode already links compiler-rt
14781478

1479-
extern "C" JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param)
1479+
extern "C" JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param)
14801480
{
14811481
return half_to_float(param);
14821482
}
14831483

1484-
extern "C" JL_DLLEXPORT float __extendhfsf2(uint16_t param)
1485-
{
1486-
return half_to_float(param);
1487-
}
1488-
1489-
extern "C" JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param)
1484+
extern "C" JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param)
14901485
{
14911486
return float_to_half(param);
14921487
}
14931488

1494-
extern "C" JL_DLLEXPORT uint16_t __truncdfhf2(double param)
1489+
extern "C" JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
14951490
{
14961491
return float_to_half((float)param);
14971492
}

src/jitlayers.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,12 +737,26 @@ JuliaOJIT::JuliaOJIT(TargetMachine &TM, LLVMContext *LLVMCtx)
737737
}
738738

739739
JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
740+
741+
orc::SymbolAliasMap jl_crt = {
742+
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
743+
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
744+
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
745+
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
746+
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
747+
};
748+
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
740749
}
741750

742-
void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
751+
orc::SymbolStringPtr JuliaOJIT::mangle(StringRef Name)
743752
{
744753
std::string MangleName = getMangledName(Name);
745-
cantFail(JD.define(orc::absoluteSymbols({{ES.intern(MangleName), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
754+
return ES.intern(MangleName);
755+
}
756+
757+
void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
758+
{
759+
cantFail(JD.define(orc::absoluteSymbols({{mangle(Name), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
746760
}
747761

748762
void JuliaOJIT::addModule(std::unique_ptr<Module> M)

src/jitlayers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class JuliaOJIT {
185185
const object::ObjectFile &Obj,
186186
const RuntimeDyld::LoadedObjectInfo &LoadedObjectInfo);
187187
#endif
188+
orc::SymbolStringPtr mangle(StringRef Name);
188189
void addGlobalMapping(StringRef Name, uint64_t Addr);
189190
void addModule(std::unique_ptr<Module> M);
190191
#if JL_LLVM_VERSION < 120000

src/julia.expmap

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,6 @@
4242
environ;
4343
__progname;
4444

45-
/* compiler run-time intrinsics */
46-
__gnu_h2f_ieee;
47-
__extendhfsf2;
48-
__gnu_f2h_ieee;
49-
__truncdfhf2;
50-
5145
local:
5246
*;
5347
};

src/julia_internal.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,8 +1363,9 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
13631363
#define JL_GC_ASSERT_LIVE(x) (void)(x)
13641364
#endif
13651365

1366-
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1367-
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1366+
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1367+
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1368+
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
13681369

13691370
#ifdef __cplusplus
13701371
}

src/runtime_intrinsics.c

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT
169169
}
170170

171171
#define fp_select(a, func) \
172-
sizeof(a) == sizeof(float) ? func##f((float)a) : func(a)
172+
sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a)
173173
#define fp_select2(a, b, func) \
174-
sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b)
174+
sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b)
175175

176176
// fast-function generators //
177177

@@ -215,11 +215,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
215215
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
216216
{ \
217217
uint16_t a = *(uint16_t*)pa; \
218-
float A = __gnu_h2f_ieee(a); \
218+
float A = julia__gnu_h2f_ieee(a); \
219219
if (osize == 16) { \
220220
float R; \
221221
OP(&R, A); \
222-
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
222+
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
223223
} else { \
224224
OP((uint16_t*)pr, A); \
225225
} \
@@ -243,11 +243,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
243243
{ \
244244
uint16_t a = *(uint16_t*)pa; \
245245
uint16_t b = *(uint16_t*)pb; \
246-
float A = __gnu_h2f_ieee(a); \
247-
float B = __gnu_h2f_ieee(b); \
246+
float A = julia__gnu_h2f_ieee(a); \
247+
float B = julia__gnu_h2f_ieee(b); \
248248
runtime_nbits = 16; \
249249
float R = OP(A, B); \
250-
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
250+
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
251251
}
252252

253253
// float or integer inputs, bool output
@@ -268,8 +268,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
268268
{ \
269269
uint16_t a = *(uint16_t*)pa; \
270270
uint16_t b = *(uint16_t*)pb; \
271-
float A = __gnu_h2f_ieee(a); \
272-
float B = __gnu_h2f_ieee(b); \
271+
float A = julia__gnu_h2f_ieee(a); \
272+
float B = julia__gnu_h2f_ieee(b); \
273273
runtime_nbits = 16; \
274274
return OP(A, B); \
275275
}
@@ -309,12 +309,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
309309
uint16_t a = *(uint16_t*)pa; \
310310
uint16_t b = *(uint16_t*)pb; \
311311
uint16_t c = *(uint16_t*)pc; \
312-
float A = __gnu_h2f_ieee(a); \
313-
float B = __gnu_h2f_ieee(b); \
314-
float C = __gnu_h2f_ieee(c); \
312+
float A = julia__gnu_h2f_ieee(a); \
313+
float B = julia__gnu_h2f_ieee(b); \
314+
float C = julia__gnu_h2f_ieee(c); \
315315
runtime_nbits = 16; \
316316
float R = OP(A, B, C); \
317-
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
317+
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
318318
}
319319

320320

@@ -832,7 +832,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \
832832
fpiseq_n(float, 32)
833833
fpiseq_n(double, 64)
834834
#define fpiseq(a,b) \
835-
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
835+
sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
836836

837837
#define fpislt_n(c_type, nbits) \
838838
static inline int fpislt##nbits(c_type a, c_type b) JL_NOTSAFEPOINT \
@@ -903,7 +903,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui)
903903
if (!(osize < 8 * sizeof(a))) \
904904
jl_error("fptrunc: output bitsize must be < input bitsize"); \
905905
else if (osize == 16) \
906-
*(uint16_t*)pr = __gnu_f2h_ieee(a); \
906+
*(uint16_t*)pr = julia__gnu_f2h_ieee(a); \
907907
else if (osize == 32) \
908908
*(float*)pr = a; \
909909
else if (osize == 64) \

test/intrinsics.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,27 @@ end
152152
@test_intrinsic Core.Intrinsics.fptosi Int Float16(3.3) 3
153153
@test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3)
154154
end
155+
156+
if Sys.ARCH == :aarch64
157+
# On AArch64 we are following the `_Float16` ABI. But these functions expect `Int16`.
158+
# TODO: Should we have `Chalf == Int16` and `Cfloat16 == Float16`?
159+
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
160+
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
161+
truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, Int16, (Float32,), x))
162+
gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, Int16, (Float32,), x))
163+
truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, Int16, (Float64,), x))
164+
else
165+
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x)
166+
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x)
167+
truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x)
168+
gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x)
169+
truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x)
170+
end
171+
172+
@testset "Float16 intrinsics (crt)" begin
173+
@test extendhfsf2(Float16(3.3)) == 3.3007812f0
174+
@test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0
175+
@test truncsfhf2(3.3f0) == Float16(3.3)
176+
@test gnu_f2h_ieee(3.3f0) == Float16(3.3)
177+
@test truncdfhf2(3.3) == Float16(3.3)
178+
end

0 commit comments

Comments
 (0)