diff --git a/src/cc/frontends/clang/b_frontend_action.cc b/src/cc/frontends/clang/b_frontend_action.cc index cec7e46946dd..e793ac958ff6 100644 --- a/src/cc/frontends/clang/b_frontend_action.cc +++ b/src/cc/frontends/clang/b_frontend_action.cc @@ -242,7 +242,7 @@ bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) { if (memb_name == "lookup" || memb_name == "lookup_or_init") { if (m_.find(Ref->getDecl()) != m_.end()) { - // Retrieved an ext. pointer from a map, mark LHS as ext. pointer. + // Retrieved an ext. pointer from a map, mark LHS as ext. pointer. // Pointers from maps always need a single dereference to get the // actual value. The value may be an external pointer but cannot // be a pointer to an external pointer as the verifier prohibits @@ -269,7 +269,20 @@ bool ProbeVisitor::VisitVarDecl(VarDecl *D) { } return true; } + bool ProbeVisitor::VisitCallExpr(CallExpr *Call) { + // Skip bpf_probe_read for the third argument if it is an AddrOf. + if (VarDecl *V = dyn_cast(Call->getCalleeDecl())) { + if (V->getName() == "bpf_probe_read" && Call->getNumArgs() >= 3) { + const Expr *E = Call->getArg(2)->IgnoreParenCasts(); + if (const UnaryOperator *UnaryExpr = dyn_cast(E)) { + if (UnaryExpr->getOpcode() == UO_AddrOf) + return false; + } + return true; + } + } + if (FunctionDecl *F = dyn_cast(Call->getCalleeDecl())) { if (F->hasBody()) { unsigned i = 0; diff --git a/tests/python/test_clang.py b/tests/python/test_clang.py index 797db9555ca9..35cabb2de061 100755 --- a/tests/python/test_clang.py +++ b/tests/python/test_clang.py @@ -76,6 +76,24 @@ def test_probe_read2(self): b = BPF(text=text, debug=0) fn = b.load_func("count_foo", BPF.KPROBE) + def test_probe_read3(self): + text = """ +#define KBUILD_MODNAME "foo" +#include +int count_tcp(struct pt_regs *ctx, struct sk_buff *skb) { + // The below define is in net/tcp.h: + // #define TCP_SKB_CB(__skb) ((struct tcp_skb_cb *)&((__skb)->cb[0])) + // Note that it has AddrOf in the macro, which will cause current rewriter + // failing below statement + // return TCP_SKB_CB(skb)->tcp_gso_size; + u16 val = 0; + bpf_probe_read(&val, sizeof(val), &(TCP_SKB_CB(skb)->tcp_gso_size)); + return val; +} +""" + b = BPF(text=text) + fn = b.load_func("count_tcp", BPF.KPROBE) + def test_probe_read_keys(self): text = """ #include diff --git a/tools/tcptop.py b/tools/tcptop.py index 58bfeab1d79c..7d2babb89d07 100755 --- a/tools/tcptop.py +++ b/tools/tcptop.py @@ -118,14 +118,10 @@ def range_check(string): } else if (family == AF_INET6) { struct ipv6_key_t ipv6_key = {.pid = pid}; - bpf_probe_read(&ipv6_key.saddr0, sizeof(ipv6_key.saddr0), - &sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[0]); - bpf_probe_read(&ipv6_key.saddr1, sizeof(ipv6_key.saddr1), - &sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[2]); - bpf_probe_read(&ipv6_key.daddr0, sizeof(ipv6_key.daddr0), - &sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[0]); - bpf_probe_read(&ipv6_key.daddr1, sizeof(ipv6_key.daddr1), - &sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[2]); + ipv6_key.saddr0 = *(u64 *)&sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[0]; + ipv6_key.saddr1 = *(u64 *)&sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[2]; + ipv6_key.daddr0 = *(u64 *)&sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[0]; + ipv6_key.daddr1 = *(u64 *)&sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[2]; ipv6_key.lport = sk->__sk_common.skc_num; dport = sk->__sk_common.skc_dport; ipv6_key.dport = ntohs(dport); @@ -165,14 +161,10 @@ def range_check(string): } else if (family == AF_INET6) { struct ipv6_key_t ipv6_key = {.pid = pid}; - bpf_probe_read(&ipv6_key.saddr0, sizeof(ipv6_key.saddr0), - &sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[0]); - bpf_probe_read(&ipv6_key.saddr1, sizeof(ipv6_key.saddr1), - &sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[2]); - bpf_probe_read(&ipv6_key.daddr0, sizeof(ipv6_key.daddr0), - &sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[0]); - bpf_probe_read(&ipv6_key.daddr1, sizeof(ipv6_key.daddr1), - &sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[2]); + ipv6_key.saddr0 = *(u64 *)&sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[0]; + ipv6_key.saddr1 = *(u64 *)&sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[2]; + ipv6_key.daddr0 = *(u64 *)&sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[0]; + ipv6_key.daddr1 = *(u64 *)&sk->__sk_common.skc_v6_daddr.in6_u.u6_addr32[2]; ipv6_key.lport = sk->__sk_common.skc_num; dport = sk->__sk_common.skc_dport; ipv6_key.dport = ntohs(dport);