Skip to content

Commit

Permalink
Trace all external pointers passed through a first map (iovisor#1737)
Browse files Browse the repository at this point in the history
* Trace all external pointers going through a first map

Currently, MapVisitor only detects maps with external pointers as
values if the value was directly passed from a function's argument.
For example, in the following, the rewriter is currently unable to
detect currsock has an external pointer as value because an
intermediate variable is used instead of passing directly sk as the
map's value.

    int test(struct pt_regs *ctx, struct sock *sk) {
        u32 pid = bpf_get_current_pid_tgid();
        struct sock **skp = &sk;
        currsock.update(&pid, skp);
        return 0;
    };

With this commit, MapVisitor is able to trace any external pointer
derived from the function's argument and used as a map value. This
commit breaks the ProbeVisitor traversal in two distinct traversals.
The first rewrites dereferences of external pointers originating
from function's arguments and helpers, while the second rewrites only
dereferences of external pointers passed through maps.
Maps with external pointers as values are identified between the two
ProbeVisitor traversals.

* New tests for external pointers passed through maps

test_ext_ptr_maps_reverse ensures dereferences are correctly replaced
even if the update happens after the lookup (in the order of
MapVisitor traversal).
test_ext_ptr_maps_indirect ensures the rewriter is able to trace
external pointers used as map values even if using an intermediate
variable.
  • Loading branch information
pchaigno authored and yonghong-song committed May 8, 2018
1 parent 42da08a commit ad2d0d9
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 36 deletions.
96 changes: 63 additions & 33 deletions src/cc/frontends/clang/b_frontend_action.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ using namespace clang;

class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
public:
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs)
: needs_probe_(false), is_transitive_(false), ptregs_(ptregs) {
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs, bool track_helpers)
: needs_probe_(false), is_transitive_(false), ptregs_(ptregs),
track_helpers_(track_helpers) {
if (arg) {
TraverseStmt(arg);
if (arg->getType()->isPointerType())
Expand All @@ -100,9 +101,10 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
}
bool VisitCallExpr(CallExpr *E) {
needs_probe_ = false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl())) {
if (!track_helpers_)
return false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl()))
needs_probe_ = V->getName() == "bpf_get_current_task";
}
return false;
}
bool VisitMemberExpr(MemberExpr *M) {
Expand All @@ -123,6 +125,7 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
bool needs_probe_;
bool is_transitive_;
const set<Decl *> &ptregs_;
bool track_helpers_;
};

// Visit a piece of the AST and mark it as needing probe reads
Expand Down Expand Up @@ -152,7 +155,7 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) {
return true;

if (memb_name == "update" || memb_name == "insert") {
if (ProbeChecker(Call->getArg(1), ptregs_).needs_probe()) {
if (ProbeChecker(Call->getArg(1), ptregs_, true).needs_probe()) {
m_.insert(Ref->getDecl());
}
}
Expand All @@ -162,12 +165,12 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) {
return true;
}

ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m) :
C(C), rewriter_(rewriter), m_(m) {}
ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m, bool track_helpers) :
C(C), rewriter_(rewriter), m_(m), track_helpers_(track_helpers) {}

bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
if (Expr *E = Decl->getInit()) {
if (ProbeChecker(E, ptregs_).is_transitive() || IsContextMemberExpr(E)) {
if (ProbeChecker(E, ptregs_, track_helpers_).is_transitive() || IsContextMemberExpr(E)) {
set_ptreg(Decl);
}
}
Expand All @@ -178,7 +181,7 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
if (F->hasBody()) {
unsigned i = 0;
for (auto arg : Call->arguments()) {
if (ProbeChecker(arg, ptregs_).needs_probe())
if (ProbeChecker(arg, ptregs_, track_helpers_).needs_probe())
ptregs_.insert(F->getParamDecl(i));
++i;
}
Expand All @@ -194,7 +197,7 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp())
return true;
// copy probe attribute from RHS to LHS if present
if (ProbeChecker(E->getRHS(), ptregs_).is_transitive()) {
if (ProbeChecker(E->getRHS(), ptregs_, track_helpers_).is_transitive()) {
ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS());
} else if (E->getRHS()->getStmtClass() == Stmt::CallExprClass) {
Expand Down Expand Up @@ -227,7 +230,7 @@ bool ProbeVisitor::VisitUnaryOperator(UnaryOperator *E) {
return true;
if (memb_visited_.find(E) != memb_visited_.end())
return true;
if (!ProbeChecker(E, ptregs_).needs_probe())
if (!ProbeChecker(E, ptregs_, track_helpers_).needs_probe())
return true;
memb_visited_.insert(E);
Expr *sub = E->getSubExpr();
Expand Down Expand Up @@ -264,7 +267,7 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {

// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
if (!ProbeChecker(base, ptregs_).needs_probe())
if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe())
return true;

string rhs = rewriter_.getRewrittenText(expansionRange(SourceRange(rhs_start, E->getLocEnd())));
Expand Down Expand Up @@ -889,44 +892,71 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, BFrontendAction &fe,
: fe_(fe),
map_visitor_(m),
btype_visitor_(C, fe),
probe_visitor_(C, rewriter, m) {}
probe_visitor1_(C, rewriter, m, true),
probe_visitor2_(C, rewriter, m, false) {}

bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) {
for (auto D : Group) {
void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) {
DeclContext::decl_iterator it;
DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl());

/**
* In a first traversal, ProbeVisitor tracks external pointers identified
* through each function's arguments and replaces their dereferences with
* calls to bpf_probe_read. It also passes all identified pointers to
* external addresses to MapVisitor.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
for (auto arg : F->parameters()) {
if (arg != F->getParamDecl(0) && !arg->getType()->isFundamentalType()) {
map_visitor_.set_ptreg(arg);
if (arg == F->getParamDecl(0)) {
probe_visitor1_.set_ctx(arg);
} else if (!arg->getType()->isFundamentalType()) {
probe_visitor1_.set_ptreg(arg);
}
}
map_visitor_.TraverseDecl(D);

probe_visitor1_.TraverseDecl(D);
for (auto ptreg : probe_visitor1_.get_ptregs()) {
map_visitor_.set_ptreg(ptreg);
}
}
}
}
return true;
}

void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) {
DeclContext::decl_iterator it;
DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl());
/**
* MapVisitor uses external pointers identified by the first ProbeVisitor
* traversal to identify all maps with external pointers as values.
* MapVisitor runs only after ProbeVisitor finished its traversal of the
* whole translation unit to clearly separate the role of each ProbeVisitor's
* traversal: the first tracks external pointers from function arguments,
* whereas the second tracks external pointers from maps. Without this clear
* separation, ProbeVisitor might attempt to replace several times the same
* dereferences.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
map_visitor_.TraverseDecl(D);
}
}
}

/**
* ProbeVisitor's traversal runs after an entire translation unit has been parsed.
* to make sure maps with external pointers have been identified.
* In a second traversal, ProbeVisitor tracks pointers passed through the
* maps identified by MapVisitor and replaces their dereferences with calls
* to bpf_probe_read.
* This last traversal runs after MapVisitor went through an entire
* translation unit, to ensure maps with external pointers have all been
* identified.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
for (auto arg : F->parameters()) {
if (arg == F->getParamDecl(0)) {
probe_visitor_.set_ctx(arg);
} else if (!arg->getType()->isFundamentalType()) {
probe_visitor_.set_ptreg(arg);
}
}
probe_visitor_.TraverseDecl(D);
probe_visitor2_.TraverseDecl(D);
}
}

Expand Down
9 changes: 6 additions & 3 deletions src/cc/frontends/clang/b_frontend_action.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
// Do a depth-first search to rewrite all pointers that need to be probed
class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
public:
explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter, std::set<clang::Decl *> &m);
explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter,
std::set<clang::Decl *> &m, bool track_helpers);
bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitCallExpr(clang::CallExpr *Call);
bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitUnaryOperator(clang::UnaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E);
void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
void set_ctx(clang::Decl *D) { ctx_ = D; }
std::set<clang::Decl *> get_ptregs() { return ptregs_; }
private:
bool IsContextMemberExpr(clang::Expr *E);
clang::SourceRange expansionRange(clang::SourceRange range);
Expand All @@ -109,19 +111,20 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
std::set<clang::Decl *> ptregs_;
std::set<clang::Decl *> &m_;
clang::Decl *ctx_;
bool track_helpers_;
};

// A helper class to the frontend action, walks the decls
class BTypeConsumer : public clang::ASTConsumer {
public:
explicit BTypeConsumer(clang::ASTContext &C, BFrontendAction &fe, clang::Rewriter &rewriter, std::set<clang::Decl *> &map);
bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
void HandleTranslationUnit(clang::ASTContext &Context) override;
private:
BFrontendAction &fe_;
MapVisitor map_visitor_;
BTypeVisitor btype_visitor_;
ProbeVisitor probe_visitor_;
ProbeVisitor probe_visitor1_;
ProbeVisitor probe_visitor2_;
};

// Create a B program in 2 phases (everything else is normal C frontend):
Expand Down
59 changes: 59 additions & 0 deletions tests/python/test_clang.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,65 @@ def test_ext_ptr_maps(self):
return 0;
};
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
skpp = currsock.lookup(&pid);
if (skpp) {
struct sock *skp = *skpp;
return skp->__sk_common.skc_dport;
}
return 0;
}
"""
b = BPF(text=bpf_text)
b.load_func("trace_entry", BPF.KPROBE)
b.load_func("trace_exit", BPF.KPROBE)

def test_ext_ptr_maps_reverse(self):
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>
BPF_HASH(currsock, u32, struct sock *);
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
skpp = currsock.lookup(&pid);
if (skpp) {
struct sock *skp = *skpp;
return skp->__sk_common.skc_dport;
}
return 0;
}
int trace_entry(struct pt_regs *ctx, struct sock *sk) {
u32 pid = bpf_get_current_pid_tgid();
currsock.update(&pid, &sk);
return 0;
};
"""
b = BPF(text=bpf_text)
b.load_func("trace_entry", BPF.KPROBE)
b.load_func("trace_exit", BPF.KPROBE)

def test_ext_ptr_maps_indirect(self):
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>
BPF_HASH(currsock, u32, struct sock *);
int trace_entry(struct pt_regs *ctx, struct sock *sk) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skp = &sk;
currsock.update(&pid, skp);
return 0;
};
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
Expand Down

0 comments on commit ad2d0d9

Please sign in to comment.