diff --git a/src/cc/frontends/clang/b_frontend_action.cc b/src/cc/frontends/clang/b_frontend_action.cc index 967f0700fc30..d8e1e8db2ca9 100644 --- a/src/cc/frontends/clang/b_frontend_action.cc +++ b/src/cc/frontends/clang/b_frontend_action.cc @@ -90,8 +90,9 @@ using namespace clang; class ProbeChecker : public RecursiveASTVisitor { public: - explicit ProbeChecker(Expr *arg, const set &ptregs) - : needs_probe_(false), is_transitive_(false), ptregs_(ptregs) { + explicit ProbeChecker(Expr *arg, const set &ptregs, bool track_helpers) + : needs_probe_(false), is_transitive_(false), ptregs_(ptregs), + track_helpers_(track_helpers) { if (arg) { TraverseStmt(arg); if (arg->getType()->isPointerType()) @@ -100,9 +101,10 @@ class ProbeChecker : public RecursiveASTVisitor { } bool VisitCallExpr(CallExpr *E) { needs_probe_ = false; - if (VarDecl *V = dyn_cast(E->getCalleeDecl())) { + if (!track_helpers_) + return false; + if (VarDecl *V = dyn_cast(E->getCalleeDecl())) needs_probe_ = V->getName() == "bpf_get_current_task"; - } return false; } bool VisitMemberExpr(MemberExpr *M) { @@ -123,6 +125,7 @@ class ProbeChecker : public RecursiveASTVisitor { bool needs_probe_; bool is_transitive_; const set &ptregs_; + bool track_helpers_; }; // Visit a piece of the AST and mark it as needing probe reads @@ -152,7 +155,7 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) { return true; if (memb_name == "update" || memb_name == "insert") { - if (ProbeChecker(Call->getArg(1), ptregs_).needs_probe()) { + if (ProbeChecker(Call->getArg(1), ptregs_, true).needs_probe()) { m_.insert(Ref->getDecl()); } } @@ -162,12 +165,12 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) { return true; } -ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set &m) : - C(C), rewriter_(rewriter), m_(m) {} +ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set &m, bool track_helpers) : + C(C), rewriter_(rewriter), m_(m), track_helpers_(track_helpers) {} bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) { if (Expr *E = Decl->getInit()) { - if (ProbeChecker(E, ptregs_).is_transitive() || IsContextMemberExpr(E)) { + if (ProbeChecker(E, ptregs_, track_helpers_).is_transitive() || IsContextMemberExpr(E)) { set_ptreg(Decl); } } @@ -178,7 +181,7 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) { if (F->hasBody()) { unsigned i = 0; for (auto arg : Call->arguments()) { - if (ProbeChecker(arg, ptregs_).needs_probe()) + if (ProbeChecker(arg, ptregs_, track_helpers_).needs_probe()) ptregs_.insert(F->getParamDecl(i)); ++i; } @@ -194,7 +197,7 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) { if (!E->isAssignmentOp()) return true; // copy probe attribute from RHS to LHS if present - if (ProbeChecker(E->getRHS(), ptregs_).is_transitive()) { + if (ProbeChecker(E->getRHS(), ptregs_, track_helpers_).is_transitive()) { ProbeSetter setter(&ptregs_); setter.TraverseStmt(E->getLHS()); } else if (E->getRHS()->getStmtClass() == Stmt::CallExprClass) { @@ -227,7 +230,7 @@ bool ProbeVisitor::VisitUnaryOperator(UnaryOperator *E) { return true; if (memb_visited_.find(E) != memb_visited_.end()) return true; - if (!ProbeChecker(E, ptregs_).needs_probe()) + if (!ProbeChecker(E, ptregs_, track_helpers_).needs_probe()) return true; memb_visited_.insert(E); Expr *sub = E->getSubExpr(); @@ -264,7 +267,7 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) { // Checks to see if the expression references something that needs to be run // through bpf_probe_read. - if (!ProbeChecker(base, ptregs_).needs_probe()) + if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe()) return true; string rhs = rewriter_.getRewrittenText(expansionRange(SourceRange(rhs_start, E->getLocEnd()))); @@ -889,44 +892,71 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, BFrontendAction &fe, : fe_(fe), map_visitor_(m), btype_visitor_(C, fe), - probe_visitor_(C, rewriter, m) {} + probe_visitor1_(C, rewriter, m, true), + probe_visitor2_(C, rewriter, m, false) {} -bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) { - for (auto D : Group) { +void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) { + DeclContext::decl_iterator it; + DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl()); + + /** + * In a first traversal, ProbeVisitor tracks external pointers identified + * through each function's arguments and replaces their dereferences with + * calls to bpf_probe_read. It also passes all identified pointers to + * external addresses to MapVisitor. + */ + for (it = DC->decls_begin(); it != DC->decls_end(); it++) { + Decl *D = *it; if (FunctionDecl *F = dyn_cast(D)) { if (fe_.is_rewritable_ext_func(F)) { for (auto arg : F->parameters()) { - if (arg != F->getParamDecl(0) && !arg->getType()->isFundamentalType()) { - map_visitor_.set_ptreg(arg); + if (arg == F->getParamDecl(0)) { + probe_visitor1_.set_ctx(arg); + } else if (!arg->getType()->isFundamentalType()) { + probe_visitor1_.set_ptreg(arg); } } - map_visitor_.TraverseDecl(D); + + probe_visitor1_.TraverseDecl(D); + for (auto ptreg : probe_visitor1_.get_ptregs()) { + map_visitor_.set_ptreg(ptreg); + } } } } - return true; -} -void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) { - DeclContext::decl_iterator it; - DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl()); + /** + * MapVisitor uses external pointers identified by the first ProbeVisitor + * traversal to identify all maps with external pointers as values. + * MapVisitor runs only after ProbeVisitor finished its traversal of the + * whole translation unit to clearly separate the role of each ProbeVisitor's + * traversal: the first tracks external pointers from function arguments, + * whereas the second tracks external pointers from maps. Without this clear + * separation, ProbeVisitor might attempt to replace several times the same + * dereferences. + */ + for (it = DC->decls_begin(); it != DC->decls_end(); it++) { + Decl *D = *it; + if (FunctionDecl *F = dyn_cast(D)) { + if (fe_.is_rewritable_ext_func(F)) { + map_visitor_.TraverseDecl(D); + } + } + } /** - * ProbeVisitor's traversal runs after an entire translation unit has been parsed. - * to make sure maps with external pointers have been identified. + * In a second traversal, ProbeVisitor tracks pointers passed through the + * maps identified by MapVisitor and replaces their dereferences with calls + * to bpf_probe_read. + * This last traversal runs after MapVisitor went through an entire + * translation unit, to ensure maps with external pointers have all been + * identified. */ for (it = DC->decls_begin(); it != DC->decls_end(); it++) { Decl *D = *it; if (FunctionDecl *F = dyn_cast(D)) { if (fe_.is_rewritable_ext_func(F)) { - for (auto arg : F->parameters()) { - if (arg == F->getParamDecl(0)) { - probe_visitor_.set_ctx(arg); - } else if (!arg->getType()->isFundamentalType()) { - probe_visitor_.set_ptreg(arg); - } - } - probe_visitor_.TraverseDecl(D); + probe_visitor2_.TraverseDecl(D); } } diff --git a/src/cc/frontends/clang/b_frontend_action.h b/src/cc/frontends/clang/b_frontend_action.h index b091a872d9c0..155a9ccfbc18 100644 --- a/src/cc/frontends/clang/b_frontend_action.h +++ b/src/cc/frontends/clang/b_frontend_action.h @@ -88,7 +88,8 @@ class BTypeVisitor : public clang::RecursiveASTVisitor { // Do a depth-first search to rewrite all pointers that need to be probed class ProbeVisitor : public clang::RecursiveASTVisitor { public: - explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter, std::set &m); + explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter, + std::set &m, bool track_helpers); bool VisitVarDecl(clang::VarDecl *Decl); bool VisitCallExpr(clang::CallExpr *Call); bool VisitBinaryOperator(clang::BinaryOperator *E); @@ -96,6 +97,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor { bool VisitMemberExpr(clang::MemberExpr *E); void set_ptreg(clang::Decl *D) { ptregs_.insert(D); } void set_ctx(clang::Decl *D) { ctx_ = D; } + std::set get_ptregs() { return ptregs_; } private: bool IsContextMemberExpr(clang::Expr *E); clang::SourceRange expansionRange(clang::SourceRange range); @@ -109,19 +111,20 @@ class ProbeVisitor : public clang::RecursiveASTVisitor { std::set ptregs_; std::set &m_; clang::Decl *ctx_; + bool track_helpers_; }; // A helper class to the frontend action, walks the decls class BTypeConsumer : public clang::ASTConsumer { public: explicit BTypeConsumer(clang::ASTContext &C, BFrontendAction &fe, clang::Rewriter &rewriter, std::set &map); - bool HandleTopLevelDecl(clang::DeclGroupRef Group) override; void HandleTranslationUnit(clang::ASTContext &Context) override; private: BFrontendAction &fe_; MapVisitor map_visitor_; BTypeVisitor btype_visitor_; - ProbeVisitor probe_visitor_; + ProbeVisitor probe_visitor1_; + ProbeVisitor probe_visitor2_; }; // Create a B program in 2 phases (everything else is normal C frontend): diff --git a/tests/python/test_clang.py b/tests/python/test_clang.py index e9d0697db6f4..413c1b0df134 100755 --- a/tests/python/test_clang.py +++ b/tests/python/test_clang.py @@ -507,6 +507,65 @@ def test_ext_ptr_maps(self): return 0; }; +int trace_exit(struct pt_regs *ctx) { + u32 pid = bpf_get_current_pid_tgid(); + struct sock **skpp; + skpp = currsock.lookup(&pid); + if (skpp) { + struct sock *skp = *skpp; + return skp->__sk_common.skc_dport; + } + return 0; +} + """ + b = BPF(text=bpf_text) + b.load_func("trace_entry", BPF.KPROBE) + b.load_func("trace_exit", BPF.KPROBE) + + def test_ext_ptr_maps_reverse(self): + bpf_text = """ +#include +#include +#include + +BPF_HASH(currsock, u32, struct sock *); + +int trace_exit(struct pt_regs *ctx) { + u32 pid = bpf_get_current_pid_tgid(); + struct sock **skpp; + skpp = currsock.lookup(&pid); + if (skpp) { + struct sock *skp = *skpp; + return skp->__sk_common.skc_dport; + } + return 0; +} + +int trace_entry(struct pt_regs *ctx, struct sock *sk) { + u32 pid = bpf_get_current_pid_tgid(); + currsock.update(&pid, &sk); + return 0; +}; + """ + b = BPF(text=bpf_text) + b.load_func("trace_entry", BPF.KPROBE) + b.load_func("trace_exit", BPF.KPROBE) + + def test_ext_ptr_maps_indirect(self): + bpf_text = """ +#include +#include +#include + +BPF_HASH(currsock, u32, struct sock *); + +int trace_entry(struct pt_regs *ctx, struct sock *sk) { + u32 pid = bpf_get_current_pid_tgid(); + struct sock **skp = &sk; + currsock.update(&pid, skp); + return 0; +}; + int trace_exit(struct pt_regs *ctx) { u32 pid = bpf_get_current_pid_tgid(); struct sock **skpp;