Skip to content

Commit

Permalink
Support both Float16 ABIs depending on LLVM and platform
Browse files Browse the repository at this point in the history
There are two Float16 ABIs in the wild: one for platforms that have a
defined register for Float16, and the original one where we used i16.

LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC.

Co-authored-by: Gabriel Baraldi <[email protected]>
  • Loading branch information
vchuravy and gbaraldi committed Apr 26, 2023
1 parent b12ddca commit ed77914
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
Expand All @@ -509,7 +510,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}

#endif
void multiversioning_preannotate(Module &M);

// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
Expand Down Expand Up @@ -942,6 +943,8 @@ struct ShardTimers {
}
};

// Forward declaration; the definition lives in src/codegen.cpp and is only
// compiled when JULIA_FLOAT16_ABI == 2. It is called from the matching
// `#else` branch in add_output_impl.
void emitFloat16Wrappers(Module &M, bool external);

// Perform the actual optimization and emission of the output files
static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *outputs, const std::string *names,
NewArchiveMember *unopt, NewArchiveMember *opt, NewArchiveMember *obj, NewArchiveMember *asm_,
Expand Down Expand Up @@ -1002,7 +1005,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
}
}
// no need to inject aliases if we have no functions

if (inject_aliases) {
#if JULIA_FLOAT16_ABI == 1
// We would like to emit an alias or a weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
Expand All @@ -1017,8 +1022,10 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
#else
emitFloat16Wrappers(M, false);
#endif
}

timers.optimize.stopTimer();

if (opt) {
Expand Down
55 changes: 55 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5809,6 +5809,7 @@ static void emit_cfunc_invalidate(
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
}

#include <iostream>
static Function* gen_cfun_wrapper(
Module *into, jl_codegen_params_t &params,
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
Expand Down Expand Up @@ -8696,6 +8697,57 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
}
return nullptr;
}
//Float16 fun
// Define `wrapperName` (signature `FTwrapper`) in `M` as a thin forwarding
// function around `calledName`.
//
// The callee is looked up in `M`; if it is not present it is declared with
// signature `FTcalled` and external linkage. Each wrapper argument is bitcast
// to the type of the corresponding callee parameter, and the callee's result
// is bitcast to the wrapper's return type.
//
// `external` selects external linkage for the wrapper; otherwise it gets
// internal linkage. The wrapper is always marked `alwaysinline`.
// Aborts the process if the wrapper and callee disagree on argument count.
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
{
    // Reuse an existing callee when the module already has one.
    Function *callee = M.getFunction(calledName);
    if (callee == nullptr)
        callee = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);

    Function *wrapper = Function::Create(
        FTwrapper,
        external ? Function::ExternalLinkage : Function::InternalLinkage,
        wrapperName, M);
    wrapper->addFnAttr(Attribute::AlwaysInline);
    llvm::IRBuilder<> irb(BasicBlock::Create(M.getContext(), "top", wrapper));

    if (wrapper->arg_size() != callee->arg_size()) {
        llvm::errs() << "FATAL ERROR: Can't match wrapper to called function";
        abort();
    }

    // Argument counts are known to match here, so walking the wrapper's
    // arguments is enough to stay in bounds on the callee side.
    SmallVector<Value *, 4> castArgs;
    auto calleeParam = callee->arg_begin();
    for (auto wrapperParam = wrapper->arg_begin(); wrapperParam != wrapper->arg_end();
         ++wrapperParam, ++calleeParam) {
        castArgs.push_back(irb.CreateBitCast(wrapperParam, calleeParam->getType()));
    }

    auto rawResult = irb.CreateCall(callee, castArgs);
    irb.CreateRet(irb.CreateBitCast(rawResult, wrapper->getReturnType()));
}

#if JULIA_FLOAT16_ABI == 2
// Emit the five Float16 conversion entry points into `M`. Each one is a
// wrapper whose signature uses `half` and which forwards to a julia__*
// function whose signature uses `i16` in place of `half`.
//
// `external` is passed through to makeCastCall and selects the linkage of
// the emitted wrappers.
void emitFloat16Wrappers(Module &M, bool external)
{
    auto &ctx = M.getContext();
    Type *halfTy = Type::getHalfTy(ctx);
    Type *i16Ty = Type::getInt16Ty(ctx);
    Type *floatTy = Type::getFloatTy(ctx);
    Type *doubleTy = Type::getDoubleTy(ctx);

    // Signatures of the wrappers (half-based) ...
    FunctionType *halfToFloat = FunctionType::get(floatTy, { halfTy }, false);
    FunctionType *floatToHalf = FunctionType::get(halfTy, { floatTy }, false);
    FunctionType *doubleToHalf = FunctionType::get(halfTy, { doubleTy }, false);
    // ... and of the functions they forward to (i16-based).
    FunctionType *i16ToFloat = FunctionType::get(floatTy, { i16Ty }, false);
    FunctionType *floatToI16 = FunctionType::get(i16Ty, { floatTy }, false);
    FunctionType *doubleToI16 = FunctionType::get(i16Ty, { doubleTy }, false);

    // half -> float
    makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", halfToFloat, i16ToFloat, external);
    makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", halfToFloat, i16ToFloat, external);
    // float -> half
    makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", floatToHalf, floatToI16, external);
    makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", floatToHalf, floatToI16, external);
    // double -> half
    makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", doubleToHalf, doubleToI16, external);
}

// Create a module named "F16Wrappers", fill it with the Float16 conversion
// wrappers (emitted with external linkage), and add it to the execution engine.
static void init_f16_funcs(void)
{
    auto tsCtx = jl_ExecutionEngine->acquireContext();
    auto wrapperTSM = jl_create_ts_module("F16Wrappers", tsCtx, imaging_default());
    // `true` requests external linkage for the wrappers.
    emitFloat16Wrappers(*wrapperTSM.getModuleUnlocked(), true);
    jl_ExecutionEngine->addModule(std::move(wrapperTSM));
}
#endif

static void init_jit_functions(void)
{
Expand Down Expand Up @@ -8935,6 +8987,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
jl_init_llvm();
// Now that the execution engine exists, initialize all modules
init_jit_functions();
#if JULIA_FLOAT16_ABI == 2
init_f16_funcs();
#endif
}

extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT
Expand Down
2 changes: 2 additions & 0 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
Expand All @@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
#endif

#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
Expand Down
10 changes: 10 additions & 0 deletions src/llvm-version.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <llvm/Config/llvm-config.h>
#include "julia_assert.h"
#include "platform.h"

// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
Expand All @@ -17,6 +18,15 @@
#define JL_LLVM_OPAQUE_POINTERS 1
#endif

// Pre GCC 12 libgcc defined the ABI for Float16->Float32
// to take an i16. GCC 12 silently changed the ABI to now pass
// Float16 in Float32 registers.
//
// JULIA_FLOAT16_ABI records which of the two conventions this build targets:
//   1 - the original ABI (Float16 passed as i16); selected for LLVM older
//       than 15 and always on PowerPC.
//   2 - the newer ABI; selected for LLVM 15 or later on every other platform.
#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
#define JULIA_FLOAT16_ABI 1
#else
#define JULIA_FLOAT16_ABI 2
#endif

#ifdef __cplusplus
#if defined(__GNUC__) && (__GNUC__ >= 9)
// Added in GCC 9, this warning is annoying
Expand Down

0 comments on commit ed77914

Please sign in to comment.