Skip to content

Commit

Permalink
Refactor compiler code
Browse files Browse the repository at this point in the history
- Make the QueryCompiler object public.
  • Loading branch information
fsaintjacques committed Dec 15, 2019
1 parent 8be6bbf commit 9cbf398
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 63 deletions.
39 changes: 38 additions & 1 deletion include/jitmap/query/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

#include <jitmap/query/query.h>

// Forward declarations of the LLVM types named in this header. Only pointers
// to them are used here (e.g. the llvm::Function* returned by
// QueryCompiler::Compile), so an incomplete type is sufficient.
// NOTE(review): presumably this keeps LLVM headers out of the public header;
// confirm no LLVM include is needed above.
namespace llvm {
class Function;
class Module;
} // namespace llvm

namespace jitmap {
namespace query {

Expand All @@ -10,6 +15,38 @@ class CompilerException : public util::Exception {
using Exception::Exception;
};

std::string Compile(Query& query);
// Tunables for the code emitted by the query compiler.
struct CompilerOptions {
  // Number of scalars per vector value used as the aggregate in the loop,
  // i.e. the vector width. With a value of 1, a plain scalar is emitted
  // instead of a vector.
  //
  // Typically a small power of two (1, 2, 4, 8 or 16); consult the
  // documentation of your hardware platform for the optimal value.
  //
  // LLVM's optimizer can auto-vectorize the loop on its own. Change this
  // value for the cases where it cannot.
  uint8_t vector_width_{1};

  // Width, in bits, of each scalar. Consult the documentation of your
  // hardware platform for the optimal value.
  //
  // LLVM should be able to figure this out by default.
  uint8_t scalar_width_{64};
};

// Compiles queries into LLVM functions. Every function produced by one
// compiler instance is created inside that instance's module.
class QueryCompiler {
 public:
  // Creates a compiler whose module is named `module_name`. `options` tunes
  // the generated code and defaults to a value-initialized CompilerOptions.
  QueryCompiler(const std::string& module_name, CompilerOptions options = {});

  // Declared here and defined out of line, because Impl is only
  // forward-declared in this class definition.
  ~QueryCompiler();

  // Generates the function for `query` and returns a pointer to it.
  llvm::Function* Compile(const Query& query);

  // Returns the module printed as text.
  std::string DumpIR() const;

 private:
  // Implementation details are hidden behind this forward-declared class.
  class Impl;
  std::unique_ptr<Impl> impl_;
};

} // namespace query
} // namespace jitmap
116 changes: 57 additions & 59 deletions src/jitmap/query/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@ namespace query {
using InputsOutputArguments = std::pair<std::vector<llvm::Argument*>, llvm::Argument*>;
InputsOutputArguments PartitionFunctionArguments(llvm::Function* fn);

// Code-generation knobs (vector and scalar widths) plus a derived word count.
struct CompilerOptions {
// Number of scalar_width_-sized words in kBitsPerContainer (integer division).
// NOTE(review): this divides by zero when scalar_width_ is 0 — confirm callers
// never configure that value.
uint64_t word_size() const { return kBitsPerContainer / scalar_width_; }

// Vector width; defaults to 1.
uint8_t vector_width_ = 1;
// Scalar width; defaults to 64.
uint8_t scalar_width_ = 64;
};

// Generate the hot section of the loop. Takes an expression and reduces it to a
// single (scalar or vector) value.
struct ExprCodeGenVisitor {
llvm::Value* operator()(const VariableExpr& e) { return FindBitmapByName(e.value()); }

Expand Down Expand Up @@ -76,38 +71,28 @@ struct ExprCodeGenVisitor {
llvm::Type* vector_type;
};

class QueryCompiler {
class QueryCompiler::Impl {
public:
QueryCompiler(Query& query)
: query_(query), ctx_(), builder_(ctx_), module_(query_.name(), ctx_) {}
Impl(const std::string& module_name, CompilerOptions options)
: ctx_(),
builder_(ctx_),
module_(module_name, ctx_),
options_(std::move(options)) {}

void Compile() {
module_.setSourceFileName(query_.expr().ToString());
auto fn = FunctionDeclForQuery(query_.parameters().size(), query_.name(), &module_);
llvm::Function* Compile(const Query& query) { return FunctionDeclForQuery(query); }
llvm::Module* Module() { return &module_; }

private:
void FunctionCodeGen(const Query& query, llvm::Function* fn) {
auto entry_block = llvm::BasicBlock::Create(ctx_, "entry", fn);
builder_.SetInsertPoint(entry_block);

LoopCodeGen(fn);

builder_.CreateRetVoid();
}

std::string DebugIR() const {
std::string buffer;
llvm::raw_string_ostream ss{buffer};
module_.print(ss, nullptr);
return ss.str();
}

private:
void LoopCodeGen(llvm::Function* fn) {
// Constants
auto induction_type = llvm::Type::getInt64Ty(ctx_);
auto zero = llvm::ConstantInt::get(induction_type, 0);
auto step = llvm::ConstantInt::get(induction_type, vector_width());
auto n_words = llvm::ConstantInt::get(induction_type, word_size());

auto entry_block = builder_.GetInsertBlock();
auto loop_block = llvm::BasicBlock::Create(ctx_, "loop", fn);
auto after_block = llvm::BasicBlock::Create(ctx_, "after_loop", fn);

Expand All @@ -123,7 +108,7 @@ class QueryCompiler {
auto i = builder_.CreatePHI(induction_type, 2, "i");
i->addIncoming(zero, entry_block);

LoopBodyCodeGen(fn, i);
LoopBodyCodeGen(query, fn, i);

// i += step
auto next_i = builder_.CreateAdd(i, step, "next_i");
Expand All @@ -134,40 +119,40 @@ class QueryCompiler {
}

builder_.SetInsertPoint(after_block);
builder_.CreateRetVoid();
}

void LoopBodyCodeGen(llvm::Function* fn, llvm::Value* loop_var) {
void LoopBodyCodeGen(const Query& query, llvm::Function* fn, llvm::Value* loop_idx) {
auto [inputs, output] = PartitionFunctionArguments(fn);

auto namify = [](std::string key, size_t i) { return key + "_" + std::to_string(i); };

auto load_vector_inst = [&](auto input, size_t i) {
auto gep = builder_.CreateInBoundsGEP(input, {loop_var}, namify("gep", i));
// Cast previous reference as a vector-type.
auto bitcast = builder_.CreateBitCast(gep, VectorPtrType(), namify("bitcast", i));
// Load in a vector register
return builder_.CreateLoad(bitcast, namify("load", i));
// Load scalar at index for given bitmap
auto load_vector_inst = [&](auto bitmap_addr, size_t i) {
auto namify = [&i](std::string key) { return key + "_" + std::to_string(i); };
// Compute the address to load
auto gep = builder_.CreateInBoundsGEP(bitmap_addr, {loop_idx}, namify("gep"));
// Cast previous address as a vector-type.
auto bitcast = builder_.CreateBitCast(gep, VectorPtrType(), namify("bitcast"));
// Load in a register
return builder_.CreateLoad(bitcast, namify("load"));
};

std::vector<llvm::Value*> bitmaps;
for (size_t i = 0; i < inputs.size(); i++) {
bitmaps.push_back(load_vector_inst(inputs[i], i));
}

auto result = ExprCodeGen(bitmaps);

auto gep = builder_.CreateInBoundsGEP(output, {loop_var}, "gep_output");
auto bitcast = builder_.CreateBitCast(gep, VectorPtrType(), "bitcast_output");
builder_.CreateStore(result, bitcast);
}

llvm::Value* ExprCodeGen(std::vector<llvm::Value*>& bitmaps) {
std::unordered_map<std::string, llvm::Value*> keyed_bitmaps;
const auto& parameters = query_.parameters();
const auto& parameters = query.parameters();
for (size_t i = 0; i < bitmaps.size(); i++) {
keyed_bitmaps.emplace(parameters[i], bitmaps[i]);
}
return query_.expr().Visit(ExprCodeGenVisitor{keyed_bitmaps, builder_, VectorType()});

ExprCodeGenVisitor visitor{keyed_bitmaps, builder_, VectorType()};
auto result = query.expr().Visit(visitor);

auto gep = builder_.CreateInBoundsGEP(output, {loop_idx}, "gep_output");
auto bitcast = builder_.CreateBitCast(gep, VectorPtrType(), "bitcast_output");
builder_.CreateStore(result, bitcast);
}

llvm::FunctionType* FunctionTypeForArguments(size_t n_arguments) {
Expand All @@ -177,15 +162,17 @@ class QueryCompiler {
return llvm::FunctionType::get(return_type, argument_types, is_var_args);
}

llvm::Function* FunctionDeclForQuery(size_t n_arguments, const std::string& query_name,
llvm::Module* module) {
llvm::Function* FunctionDeclForQuery(const Query& query) {
size_t n_arguments = query.parameters().size();
auto query_name = query.name();

auto fn_type = FunctionTypeForArguments(n_arguments);

// The generated function will be exposed as an external symbol, i.e. the
// symbol will be globally visible. This would be equivalent to defining a
// symbol with the `extern` storage class specifier.
auto linkage = llvm::Function::ExternalLinkage;
auto fn = llvm::Function::Create(fn_type, linkage, query_name, module);
auto fn = llvm::Function::Create(fn_type, linkage, query_name, &module_);

// The generated objects are accessed by taking the symbol address and
// casting to a function type. Thus, we must use the C calling convention.
Expand All @@ -205,6 +192,8 @@ class QueryCompiler {
output->setName("out");
output->addAttr(llvm::Attribute::NoCapture);

FunctionCodeGen(query, fn);

return fn;
}

Expand All @@ -224,17 +213,32 @@ class QueryCompiler {
}
llvm::Type* VectorPtrType() { return VectorType()->getPointerTo(); }

uint32_t word_size() const { return options_.word_size(); }
uint8_t vector_width() const { return options_.vector_width_; }
uint8_t scalar_width() const { return options_.scalar_width_; }
uint32_t word_size() const { return kBitsPerContainer / scalar_width(); }

Query& query_;
llvm::LLVMContext ctx_;
llvm::IRBuilder<> builder_;
llvm::Module module_;
CompilerOptions options_;
};

QueryCompiler::QueryCompiler(const std::string& module_name, CompilerOptions options)
: impl_(std::make_unique<QueryCompiler::Impl>(module_name, std::move(options))) {}

// Forwards to the implementation, which builds the function for `query`, and
// hands back the resulting function pointer unchanged.
llvm::Function* QueryCompiler::Compile(const Query& query) {
  llvm::Function* compiled = impl_->Compile(query);
  return compiled;
}

std::string QueryCompiler::DumpIR() const {
std::string buffer;
llvm::raw_string_ostream ss{buffer};
impl_->Module()->print(ss, nullptr);
return ss.str();
}

// Defined in this translation unit, where Impl is a complete type, so that
// std::unique_ptr<Impl> can destroy it.
QueryCompiler::~QueryCompiler() = default;

InputsOutputArguments PartitionFunctionArguments(llvm::Function* fn) {
InputsOutputArguments io;

Expand All @@ -251,11 +255,5 @@ InputsOutputArguments PartitionFunctionArguments(llvm::Function* fn) {
return io;
}

// Convenience wrapper: compiles `query` with a temporary QueryCompiler and
// returns the text produced by the compiler's DebugIR().
std::string Compile(Query& query) {
  QueryCompiler compiler(query);
  compiler.Compile();
  return compiler.DebugIR();
}

} // namespace query
} // namespace jitmap
5 changes: 3 additions & 2 deletions tests/query/compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace query {
// Test fixture for the query compiler; adds nothing on top of QueryTest.
class QueryCompilerTest : public QueryTest {};

// Smoke test: builds a query for Not(Var("a")) and runs it through the
// compiler. There are no assertions on the generated code here.
TEST_F(QueryCompilerTest, Basic) {
auto query = Query::Make("test_query", Not(Var("a")));
Compile(*query);
auto query = Query::Make("not_a", Not(Var("a")));

QueryCompiler("my-module", {}).Compile(*query);
}

} // namespace query
Expand Down
4 changes: 3 additions & 1 deletion tools/jitmap_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ int main(int argc, char** argv) {
ExprBuilder builder;
auto expr = Parse(argv[1], &builder);
auto query = Query::Make("query", expr);
std::cout << Compile(*query) << "\n";
auto compiler = QueryCompiler("jitmap-ir-module", {});
compiler.Compile(*query);
std::cout << compiler.DumpIR() << "\n";

return 0;
}

0 comments on commit 9cbf398

Please sign in to comment.