-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
llvm-muladd.cpp
131 lines (120 loc) · 3.8 KB
/
llvm-muladd.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// This file is a part of Julia. License is MIT: https://julialang.org/license
#define DEBUG_TYPE "combine_muladd"
#undef DEBUG
#include "llvm-version.h"
#include <llvm/IR/Value.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Operator.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/Pass.h>
#include <llvm/Support/Debug.h>
#include "julia.h"
#include "julia_assert.h"
using namespace llvm;
/**
 * Legacy function pass that lets a single-use `fmul` be fused into the
 * fast-math `fadd`/`fsub` consuming it, e.g.
 * ```
 * %v0 = fmul ... %a, %b
 * %v = fadd fast ... %v0, %c
 * ```
 * On LLVM 5.0+ the `fmul` is only marked `contract`, leaving the actual fusion
 * to the backend. On older LLVM the pair is rewritten in place to
 * `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)`
 * (with `fneg`s inserted as needed for `fsub`).
 */
struct CombineMulAdd : public FunctionPass {
    static char ID;

    CombineMulAdd() : FunctionPass(ID) {}

private:
    // Entry point invoked by the legacy pass manager for each function.
    bool runOnFunction(Function &F) override;
};
// Try to fuse the `fmul` `maybeMul` into the fast-math `fadd`/`fsub` `addOp`,
// where `addend` is the other operand of `addOp`. With `a`, `b` the operands of
// the `fmul`, the intended value of `addOp` is
//     t = a * b + (negadd ? -addend : addend)
//     result = negres ? -t : t
//
// Nothing happens unless `maybeMul` is an `fmul` instruction with exactly one
// use (which, since it is an operand of `addOp`, is `addOp` itself).
//  * LLVM 5.0+: the `fmul` is marked with the `contract` fast-math flag and the
//    backend forms the fused operation. `addOp` is kept, so this always returns
//    false and the caller may still try the other operand.
//  * Older LLVM: `addOp` is replaced by a call to `llvm.fmuladd` (plus `fneg`s
//    as required). Both `addOp` and the `fmul` are erased, and this returns true.
//
// If `modified` is non-null, `*modified` is set to true whenever the IR changes.
// Return true if this function shouldn't be called again on the other operand
// (i.e. `addOp` has been erased).
static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend,
                         bool negadd, bool negres, bool *modified = nullptr)
{
    auto mulOp = dyn_cast<Instruction>(maybeMul);
    if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
        return false;
    if (!mulOp->hasOneUse())
        return false;
#if JL_LLVM_VERSION >= 50000
    // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
    auto fmf = mulOp->getFastMathFlags();
    // Only touch (and report) the instruction when the flag is actually missing,
    // so already-processed IR is reported as unmodified.
    if (!fmf.allowContract()) {
        fmf.setAllowContract(true);
        mulOp->copyFastMathFlags(fmf);
        if (modified)
            *modified = true;
    }
    return false;
#else
    IRBuilder<> builder(m->getContext());
    builder.SetInsertPoint(addOp);
    auto mul1 = mulOp->getOperand(0);
    auto mul2 = mulOp->getOperand(1);
    Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType());
    if (negadd) {
        auto newaddend = builder.CreateFNeg(addend);
        // Might be a const
        if (auto neginst = dyn_cast<Instruction>(newaddend))
            neginst->setHasUnsafeAlgebra(true);
        addend = newaddend;
    }
    Instruction *newv = builder.CreateCall(muladdf, {mul1, mul2, addend});
    newv->setHasUnsafeAlgebra(true);
    if (negres) {
        // Shouldn't be a constant
        newv = cast<Instruction>(builder.CreateFNeg(newv));
        newv->setHasUnsafeAlgebra(true);
    }
    addOp->replaceAllUsesWith(newv);
    addOp->eraseFromParent();
    mulOp->eraseFromParent();
    if (modified)
        *modified = true;
    return true;
#endif
}

// Visit every fast-math `fadd`/`fsub` in `F` and try to combine it with a
// single-use `fmul` operand via `checkCombine`.
// Returns true only if the IR was actually changed, as the legacy pass manager
// expects (returning true unconditionally would needlessly invalidate analyses).
bool CombineMulAdd::runOnFunction(Function &F)
{
    Module *m = F.getParent();
    bool modified = false;
    for (auto &BB: F) {
        for (auto it = BB.begin(); it != BB.end();) {
            // Advance before inspecting: checkCombine may erase `I`.
            Instruction &I = *it;
            ++it;
            unsigned opcode = I.getOpcode();
            if (opcode != Instruction::FAdd && opcode != Instruction::FSub)
                continue;
#if JL_LLVM_VERSION >= 60000
            if (!I.isFast())
#else
            if (!I.hasUnsafeAlgebra())
#endif
                continue;
            bool isSub = opcode == Instruction::FSub;
            // `a*b + c` / `a*b - c`: the product is operand 0, `c` is negated for fsub.
            if (checkCombine(m, &I, I.getOperand(0), I.getOperand(1), isSub, false, &modified))
                continue;
            // `c + a*b` / `c - a*b`; the latter is computed as `-(a*b - c)`.
            checkCombine(m, &I, I.getOperand(1), I.getOperand(0), isSub, isSub, &modified);
        }
    }
    return modified;
}
char CombineMulAdd::ID = 0;

// Register the pass with the legacy pass manager under the name "CombineMulAdd".
// Neither a CFG-only pass nor an analysis pass.
static RegisterPass<CombineMulAdd> RegisterCombineMulAdd(
    "CombineMulAdd", "Combine mul and add to muladd",
    /* CFGOnly */ false,
    /* is_analysis */ false);
// Factory for the mul/add combining pass. The returned heap-allocated pass is
// presumably handed to (and owned by) an LLVM pass manager — confirm at call sites.
Pass *createCombineMulAddPass()
{
    CombineMulAdd *pass = new CombineMulAdd();
    return pass;
}