Skip to content

Commit

Permalink
Implement Optimizer class
Browse files Browse the repository at this point in the history
The optimizer applies multiple passes on an expression.
  • Loading branch information
fsaintjacques committed Dec 31, 2019
1 parent 8a61d7e commit abfcd53
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 72 deletions.
46 changes: 39 additions & 7 deletions include/jitmap/query/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
namespace jitmap {
namespace query {

class ExprBuilder;

class Expr {
public:
enum Type {
Expand All @@ -33,6 +35,8 @@ class Expr {

template <typename Visitor>
auto Visit(Visitor&& v) const;
template <typename Visitor>
auto Visit(Visitor&& v);

// Convenience and debug operators
bool operator==(const Expr& rhs) const;
Expand All @@ -43,13 +47,17 @@ class Expr {
// Return all Reference expressions.
std::vector<std::string> Variables() const;

// Copy the expression
Expr* Copy(ExprBuilder* builder) const;

virtual ~Expr() {}

protected:
Expr(Type type) : type_(type) {}
Type type_;

private:
// Use Copy()
Expr(const Expr&) = delete;
};

Expand All @@ -68,6 +76,7 @@ class UnaryOpExpr : public OpExpr {
UnaryOpExpr(Expr* expr) : operand_(expr) {}

Expr* operand() const { return operand_; }
void SetOperand(Expr* expr) { operand_ = expr; }

protected:
Expr* operand_;
Expand All @@ -78,7 +87,10 @@ class BinaryOpExpr : public OpExpr {
BinaryOpExpr(Expr* lhs, Expr* rhs) : left_operand_(lhs), right_operand_(rhs) {}

Expr* left_operand() const { return left_operand_; }
void SetLeftOperand(Expr* left) { left_operand_ = left; }

Expr* right_operand() const { return right_operand_; }
void SetRightOperand(Expr* right) { right_operand_ = right; }

protected:
Expr* left_operand_;
Expand Down Expand Up @@ -164,19 +176,39 @@ template <typename Visitor>
auto Expr::Visit(Visitor&& v) const {
switch (type()) {
case EMPTY_LITERAL:
return v(static_cast<const EmptyBitmapExpr&>(*this));
return v(static_cast<const EmptyBitmapExpr*>(this));
case FULL_LITERAL:
return v(static_cast<const FullBitmapExpr*>(this));
case VARIABLE:
return v(static_cast<const VariableExpr*>(this));
case NOT_OPERATOR:
return v(static_cast<const NotOpExpr*>(this));
case AND_OPERATOR:
return v(static_cast<const AndOpExpr*>(this));
case OR_OPERATOR:
return v(static_cast<const OrOpExpr*>(this));
case XOR_OPERATOR:
return v(static_cast<const XorOpExpr*>(this));
}
}

template <typename Visitor>
auto Expr::Visit(Visitor&& v) {
switch (type()) {
case EMPTY_LITERAL:
return v(static_cast<EmptyBitmapExpr*>(this));
case FULL_LITERAL:
return v(static_cast<const FullBitmapExpr&>(*this));
return v(static_cast<FullBitmapExpr*>(this));
case VARIABLE:
return v(static_cast<const VariableExpr&>(*this));
return v(static_cast<VariableExpr*>(this));
case NOT_OPERATOR:
return v(static_cast<const NotOpExpr&>(*this));
return v(static_cast<NotOpExpr*>(this));
case AND_OPERATOR:
return v(static_cast<const AndOpExpr&>(*this));
return v(static_cast<AndOpExpr*>(this));
case OR_OPERATOR:
return v(static_cast<const OrOpExpr&>(*this));
return v(static_cast<OrOpExpr*>(this));
case XOR_OPERATOR:
return v(static_cast<const XorOpExpr&>(*this));
return v(static_cast<XorOpExpr*>(this));
}
}

Expand Down
41 changes: 33 additions & 8 deletions include/jitmap/query/optimizer.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#pragma once

#include <jitmap/query/expr.h>
#include <optional>

#include <jitmap/query/matcher.h>

namespace jitmap {
namespace query {

class Expr;
class ExprBuilder;

class OptimizationPass {
public:
// Indicate that no optimization was done, see `Rewrite`.
Expand Down Expand Up @@ -79,17 +83,38 @@ class NotChainFolding final : public OptimizationPass {
Expr* Rewrite(const Expr& expr);
};

struct OptimizerOptions {
enum EnabledOptimizations : uint64_t {
CONSTANT_FOLDING = 1U << 1,
SAME_OPERAND_FOLDING = 1U << 2,
NOT_CHAIN_FOLDING = 1U << 3,
};

bool HasOptimization(enum EnabledOptimizations optimization) {
return enabled_optimizations & optimization;
}

static constexpr uint64_t kDefaultOptimizations =
CONSTANT_FOLDING | SAME_OPERAND_FOLDING | NOT_CHAIN_FOLDING;

uint64_t enabled_optimizations = kDefaultOptimizations;
};

class Optimizer {
public:
struct Options {
enum Flags : uint64_t {
CONSTANT_FOLDING = 1U << 1,
SAME_OPERAND_FOLDING = 1U << 2,
NOT_CHAIN_FOLDING = 1U << 3,
};
};
Optimizer(ExprBuilder* builder, OptimizerOptions options = {});

const OptimizerOptions& options() const { return options_; }

Expr* Optimize(const Expr& input);

private:
ExprBuilder* builder_;
OptimizerOptions options_;

std::optional<ConstantFolding> constant_folding_;
std::optional<SameOperandFolding> same_operand_folding_;
std::optional<NotChainFolding> not_chain_folding_;
};

} // namespace query
Expand Down
20 changes: 10 additions & 10 deletions src/jitmap/query/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,32 @@ struct ExprCodeGenVisitor {
llvm::IRBuilder<>& builder;
llvm::Type* vector_type;

llvm::Value* operator()(const VariableExpr& e) { return FindBitmapByName(e.value()); }
llvm::Value* operator()(const VariableExpr* e) { return FindBitmapByName(e->value()); }

llvm::Value* operator()(const EmptyBitmapExpr& e) {
llvm::Value* operator()(const EmptyBitmapExpr*) {
return llvm::ConstantInt::get(vector_type, 0UL);
}

llvm::Value* operator()(const FullBitmapExpr& e) {
llvm::Value* operator()(const FullBitmapExpr*) {
return llvm::ConstantInt::get(vector_type, UINT64_MAX);
}

llvm::Value* operator()(const NotOpExpr& e) {
auto operand = e.operand()->Visit(*this);
llvm::Value* operator()(const NotOpExpr* e) {
auto operand = e->operand()->Visit(*this);
return builder.CreateNot(operand);
}

llvm::Value* operator()(const AndOpExpr& e) {
llvm::Value* operator()(const AndOpExpr* e) {
auto [lhs, rhs] = VisitBinary(e);
return builder.CreateAnd(lhs, rhs);
}

llvm::Value* operator()(const OrOpExpr& e) {
llvm::Value* operator()(const OrOpExpr* e) {
auto [lhs, rhs] = VisitBinary(e);
return builder.CreateOr(lhs, rhs);
}

llvm::Value* operator()(const XorOpExpr& e) {
llvm::Value* operator()(const XorOpExpr* e) {
auto [lhs, rhs] = VisitBinary(e);
return builder.CreateXor(lhs, rhs);
}
Expand All @@ -64,8 +64,8 @@ struct ExprCodeGenVisitor {
return result->second;
}

std::pair<llvm::Value*, llvm::Value*> VisitBinary(const BinaryOpExpr& e) {
return {e.left_operand()->Visit(*this), e.right_operand()->Visit(*this)};
std::pair<llvm::Value*, llvm::Value*> VisitBinary(const BinaryOpExpr* e) {
return {e->left_operand()->Visit(*this), e->right_operand()->Visit(*this)};
}
};

Expand Down
67 changes: 44 additions & 23 deletions src/jitmap/query/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ bool Expr::IsBinaryOperator() const {
bool Expr::operator==(const Expr& rhs) const {
// Pointer shorcut.
if (this == &rhs) return true;
return this->Visit([&](const auto& left) {
return this->Visit([&](const auto* left) {
if (type() != rhs.type()) return false;

using E = std::decay_t<decltype(left)>;
const E& right = static_cast<const E&>(rhs);
using E = std::decay_t<std::remove_pointer_t<decltype(left)>>;
const E* right = static_cast<const E*>(&rhs);

if constexpr (is_literal<E>::value) return true;
if constexpr (is_variable<E>::value) return left.value() == right.value();
if constexpr (is_unary_op<E>::value) return *left.operand() == *right.operand();
if constexpr (is_unary_op<E>::value) return *left.operand() == *right.operand();
if constexpr (is_variable<E>::value) return left->value() == right->value();
if constexpr (is_unary_op<E>::value) return *left->operand() == *right->operand();
if constexpr (is_unary_op<E>::value) return *left->operand() == *right->operand();
if constexpr (is_binary_op<E>::value) {
return (*left.left_operand() == *right.left_operand()) &&
(*left.right_operand() == *right.right_operand());
return (*left->left_operand() == *right->left_operand()) &&
(*left->right_operand() == *right->right_operand());
}

return false;
Expand All @@ -91,9 +91,9 @@ static const char* OpToChar(Expr::Type op) {
}

std::string Expr::ToString() const {
return Visit([&](const auto& e) -> std::string {
using E = std::decay_t<decltype(e)>;
auto type = e.type();
return Visit([&](const auto* e) -> std::string {
using E = std::decay_t<std::remove_pointer_t<decltype(e)>>;
auto type = e->type();
auto symbol = OpToChar(type);

std::stringstream ss;
Expand All @@ -103,48 +103,69 @@ std::string Expr::ToString() const {
}

if constexpr (is_variable<E>::value) {
ss << symbol << e.value();
ss << symbol << e->value();
}

if constexpr (is_unary_op<E>::value) {
ss << symbol << e.operand()->ToString();
ss << symbol << e->operand()->ToString();
}

if constexpr (is_binary_op<E>::value) {
auto left = e.left_operand()->ToString();
auto right = e.right_operand()->ToString();
auto left = e->left_operand()->ToString();
auto right = e->right_operand()->ToString();
ss << "(" << left << " " << symbol << " " << right << ")";
}

return ss.str();
});
}

static void CollectVariables(const Expr& expr,
static void CollectVariables(const Expr* expr,
std::unordered_set<std::string>& unique_variables,
std::vector<std::string>& variables) {
expr.Visit([&unique_variables, &variables](const auto& e) {
using E = std::decay_t<decltype(e)>;
expr->Visit([&unique_variables, &variables](const auto* e) {
using E = std::decay_t<std::remove_pointer_t<decltype(e)>>;

if constexpr (is_variable<E>::value) {
auto var = e.value();
auto var = e->value();
if (unique_variables.insert(var).second) variables.emplace_back(var);
} else if constexpr (is_unary_op<E>::value) {
CollectVariables(*e.operand(), unique_variables, variables);
CollectVariables(e->operand(), unique_variables, variables);
} else if constexpr (is_binary_op<E>::value) {
CollectVariables(*e.left_operand(), unique_variables, variables);
CollectVariables(*e.right_operand(), unique_variables, variables);
CollectVariables(e->left_operand(), unique_variables, variables);
CollectVariables(e->right_operand(), unique_variables, variables);
}
});
}

std::vector<std::string> Expr::Variables() const {
std::unordered_set<std::string> unique_variables;
std::vector<std::string> variables;
CollectVariables(*this, unique_variables, variables);
CollectVariables(this, unique_variables, variables);
return variables;
}

Expr* Expr::Copy(ExprBuilder* b) const {
return Visit([b](const auto* e) -> Expr* {
using E = std::decay_t<std::remove_pointer_t<decltype(e)>>;
if constexpr (is_variable<E>::value) {
return b->Var(e->value());
} else if constexpr (is_literal<E>::value) {
return (e->type() == FULL_LITERAL) ? b->FullBitmap() : b->EmptyBitmap();
} else if constexpr (is_not_op<E>::value) {
return b->Not(e->operand()->Copy(b));
} else if constexpr (is_and_op<E>::value) {
return b->And(e->left_operand()->Copy(b), e->right_operand()->Copy(b));
} else if constexpr (is_or_op<E>::value) {
return b->Or(e->left_operand()->Copy(b), e->right_operand()->Copy(b));
} else if constexpr (is_xor_op<E>::value) {
return b->Xor(e->left_operand()->Copy(b), e->right_operand()->Copy(b));
}

return nullptr;
});
}

std::ostream& operator<<(std::ostream& os, const Expr& e) { return os << e.ToString(); }
std::ostream& operator<<(std::ostream& os, Expr* e) { return os << *e; }

Expand Down
12 changes: 6 additions & 6 deletions src/jitmap/query/matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@ OperandMatcher::OperandMatcher(Matcher* matcher, Mode mode)
: matcher_(matcher), mode_(mode) {}

bool OperandMatcher::Match(const Expr& expr) const {
return expr.Visit([&](const auto& e) {
using E = std::decay_t<decltype(e)>;
return expr.Visit([&](const auto* e) {
using E = std::decay_t<std::remove_pointer_t<decltype(e)>>;

auto mode = this->mode_;
auto& matcher = *this->matcher_;

if constexpr (is_unary_op<E>::value) {
return matcher(static_cast<const NotOpExpr&>(e).operand());
if constexpr (is_not_op<E>::value) {
return matcher(e->operand());
}

if constexpr (is_binary_op<E>::value) {
bool left = matcher(e.left_operand());
bool left = matcher(e->left_operand());

// Short-circuit
if (left && mode == Mode::ANY) return true;
if (!left && mode == Mode::ALL) return false;

bool right = matcher(e.right_operand());
bool right = matcher(e->right_operand());
return (mode == Mode::ANY) ? left || right : left && right;
}

Expand Down
Loading

0 comments on commit abfcd53

Please sign in to comment.