Skip to content

Commit

Permalink
Handle intra-symbol jmps, including call optimisation jmps.
Browse files Browse the repository at this point in the history
  • Loading branch information
gm281 committed Aug 5, 2014
1 parent 910535e commit 009970c
Showing 1 changed file with 170 additions and 80 deletions.
250 changes: 170 additions & 80 deletions trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ def __init__(self, wait_event, notify_event, listener, process, result):
self.process = process
self.result = result
self.exiting = False
self.wait_timeout = False

def isExiting(self):
    """Return the ``exiting`` flag stored on this object."""
    exiting = self.exiting
    return exiting
def wait_timed_out(self):
    """Return the ``wait_timeout`` flag stored on this object.

    The flag starts out False and is set to True by ``run`` when its wait
    for a listener event returns without an event.
    """
    timed_out = self.wait_timeout
    return timed_out

def exit(self):
self.exiting = True
Expand All @@ -33,17 +34,17 @@ def run(self):
return
while True:
event = lldb.SBEvent()
print >>self.result, 'Listener thread waiting for an event'
#print >>self.result, 'Listener thread waiting for an event'
wait_result = self.listener.WaitForEvent(10, event)

if not wait_result:
print >>self.result, 'Listener thread timed out waiting for notification'
self.exiting = True
self.wait_timeout = True
self.notify_event.set()
return
print >>self.result, '=== YEY'
print >>self.result, 'Event data flavor:', event.GetDataFlavor()
print >>self.result, 'Event string:', lldb.SBEvent.GetCStringFromEvent(event)
break
#print >>self.result, '=== YEY'
#print >>self.result, 'Event data flavor:', event.GetDataFlavor()
#print >>self.result, 'Event string:', lldb.SBEvent.GetCStringFromEvent(event)
if self.process.GetState() == lldb.eStateStopped:
break
print >>self.result, 'Process not stopped, listening for the next event'
Expand All @@ -57,85 +58,119 @@ def __init__(self, target, thread, frame, result):
self.frame = frame
self.result = result
self.return_breakpoint = None
self.call_breakpoints = {}
self.jmp_breakpoints = {}
self.subsequent_instruction = {}

def update_frame(self, frame):
    """Replace the frame tracked by this object with *frame*."""
    self.frame = frame

def instrument_calls(self):
def instrument_calls_and_jmps(self):
# TODO: symbols vs functions
print >>self.result, self.frame.GetFunction()
symbol = self.frame.GetSymbol()
print >>self.result, "=========> Instrumenting symbol:"
print >>self.result, symbol
start_address = symbol.GetStartAddress()
print >>self.result, start_address
print >>self.result, '0x%x' % start_address.GetLoadAddress(self.target)
start_address = symbol.GetStartAddress().GetLoadAddress(self.target)
#print >>self.result, '0x%x' % start_address
end_address = symbol.GetEndAddress().GetLoadAddress(self.target)
#print >>self.result, '0x%x' % end_address
instruction_list = symbol.GetInstructions(self.target)
#print >>self.result, instruction_list
breakpoints = {}
subsequent_instruction = {}
previous_breakpoint_address = 0L
for i in instruction_list:
address = i.GetAddress().GetLoadAddress(self.target)
print >>self.result, '0x%x' % address
print >>self.result, '{}, {}, {}'.format(i.GetMnemonic(self.target), i.GetOperands(self.target), i.GetComment(self.target))
#print >>self.result, '0x%x' % address
#print >>self.result, '{}, {}, {}'.format(i.GetMnemonic(self.target), i.GetOperands(self.target), i.GetComment(self.target))
if address in self.call_breakpoints or address in self.jmp_breakpoints:
#print >>self.result, 'There already is a breakpoint for this address'
continue
if previous_breakpoint_address != 0L:
subsequent_instruction[previous_breakpoint_address] = address
self.subsequent_instruction[previous_breakpoint_address] = address
previous_breakpoint_address = 0L
mnemonic = i.GetMnemonic(self.target)
if mnemonic != None and mnemonic.startswith('call'):
print >>self.result, 'Putting breakpoint at 0x%lx' % address
#print >>self.result, 'Putting breakpoint at 0x%lx' % address
breakpoint = self.target.BreakpointCreateByAddress(address)
breakpoint.SetThreadID(self.thread.GetThreadID())
breakpoints[address] = breakpoint
self.call_breakpoints[address] = breakpoint
previous_breakpoint_address = address
print >>self.result, breakpoints
self.breakpoints = breakpoints
self.subsequent_instruction = subsequent_instruction
if mnemonic != None and mnemonic.startswith('jmp'):
try:
jmp_destination = int(i.GetOperands(self.target), 16)
except:
jmp_destination = 0L;

if jmp_destination < start_address or jmp_destination >= end_address:
#print >>self.result, 'Non-Local call'
breakpoint = self.target.BreakpointCreateByAddress(address)
breakpoint.SetThreadID(self.thread.GetThreadID())
self.jmp_breakpoints[address] = breakpoint

def clear_calls_instrumentation(self):
for breakpoint in self.breakpoints.itervalues():
print >>self.result, 'Deleting breakpoint %d' % breakpoint.GetID()
for breakpoint in self.call_breakpoints.itervalues():
#print >>self.result, 'Deleting breakpoint %d' % breakpoint.GetID()
self.target.BreakpointDelete(breakpoint.GetID())
self.call_breakpoints = {}
self.subsequent_instruction = {}

def clear_jmps_instrumentation(self):
for breakpoint in self.jmp_breakpoints.itervalues():
#print >>self.result, 'Deleting breakpoint %d' % breakpoint.GetID()
self.target.BreakpointDelete(breakpoint.GetID())
self.breakpoints = None
self.subsequent_instruction = None
self.jmp_breakpoints = {}

def clear_return_breakpoint(self):
    """Delete the return breakpoint from the target and forget it.

    Does nothing when no return breakpoint is currently recorded, so the
    method can safely be called more than once.
    """
    if self.return_breakpoint is None:
        return
    self.target.BreakpointDelete(self.return_breakpoint.GetID())
    # Must be an assignment: the previous `== None` was a comparison whose
    # result was discarded, so the stale breakpoint object was kept and
    # later `return_breakpoint != None` checks still saw it.
    self.return_breakpoint = None

def instrument_return(self, return_address):
    """Create a breakpoint at *return_address*, limited to self.thread.

    Stores the address in ``self.return_address`` and the new breakpoint in
    ``self.return_breakpoint``.
    """
    self.return_address = return_address
    return_bp = self.target.BreakpointCreateByAddress(return_address)
    self.return_breakpoint = return_bp
    return_bp.SetThreadID(self.thread.GetThreadID())

def is_stopped_on_call_and_instrument_return(self, frame):
def is_stopped_on_call(self, frame):
if frame.GetFrameID() != self.frame.GetFrameID():
print >>self.result, "Frames don't match"
print >>self.result, "A Frames don't match, ours: {}, valid: {}, submitted: {}".format(self.frame.GetFrameID(), self.frame.IsValid(), frame.GetFrameID())
return False

stop_address = frame.GetPC()
print >>self.result, 'Stop address: 0x%lx' % stop_address
if not stop_address in self.breakpoints:
return False
return stop_address in self.call_breakpoints

breakpoint = self.breakpoints[stop_address]
print >>self.result, 'Stopped on breakpoint:'
print >>self.result, breakpoint
if not stop_address in self.subsequent_instruction:
print >>self.result, "Couldn't find subsequent instruction"
def is_stopped_on_jmp(self, frame, validate_saved_frame):
if validate_saved_frame and frame.GetFrameID() != self.frame.GetFrameID():
print >>self.result, "B Frames don't match, ours: {}, valid: {}, submitted: {}".format(self.frame.GetFrameID(), self.frame.IsValid(), frame.GetFrameID())
return False
self.instrument_return(self.subsequent_instruction[stop_address])
self.clear_calls_instrumentation()
return True

stop_address = frame.GetPC()
return stop_address in self.jmp_breakpoints

def is_stopped_on_return(self, frame):
if frame.GetFrameID() != self.frame.GetFrameID():
print >>self.result, "Frames don't match"
print >>self.result, "C Frames don't match, ours: {}, valid: {}, submitted: {}".format(self.frame.GetFrameID(), self.frame.IsValid(), frame.GetFrameID())
return False

if self.return_breakpoint == None:
return False

stop_address = frame.GetPC()
return self.return_address == stop_address

def instrument_return(self, return_address):
    """Create a breakpoint at *return_address*, limited to self.thread.

    Stores the address in ``self.return_address`` and the new breakpoint in
    ``self.return_breakpoint``.
    """
    self.return_address = return_address
    return_bp = self.target.BreakpointCreateByAddress(return_address)
    self.return_breakpoint = return_bp
    return_bp.SetThreadID(self.thread.GetThreadID())

def clear_calls_and_jmps_and_instrument_return(self, frame):
    """Replace the call/jmp breakpoints with a single return breakpoint.

    Looks up the instruction recorded as following the one at *frame*'s PC,
    instruments it as the return address, then removes all call and jmp
    breakpoints.

    Returns True on success. Returns False, after printing a message to
    self.result, when no following instruction is recorded for that PC; in
    that case nothing is instrumented or cleared.
    """
    pc = frame.GetPC()
    if pc in self.subsequent_instruction:
        # Read the return address before clearing the call instrumentation.
        self.instrument_return(self.subsequent_instruction[pc])
        self.clear_calls_instrumentation()
        self.clear_jmps_instrumentation()
        return True
    print >>self.result, "Couldn't find subsequent instruction"
    return False

def clear(self):
if self.breakpoints != None:
if self.call_breakpoints != None:
self.clear_calls_instrumentation()
if self.jmp_breakpoints != None:
self.clear_jmps_instrumentation()
if self.return_breakpoint != None:
self.clear_return_breakpoint()

Expand All @@ -150,7 +185,7 @@ def continue_and_wait_for_breakpoint(process, thread, listening_thread, wait_eve
print >>result, 'Got notification, sanity checks follow'
print >>result, process.GetState()
# Some sanity checking
if listening_thread.isExiting():
if listening_thread.wait_timed_out():
print >>result, 'Listener thread exited unexpectedly'
return False
if thread.GetStopReason() != lldb.eStopReasonBreakpoint:
Expand All @@ -159,6 +194,28 @@ def continue_and_wait_for_breakpoint(process, thread, listening_thread, wait_eve
return False
return True

def get_pc_addresses(thread):
    """Collect the PC address object of every frame in *thread*.

    Entry i of the result is ``thread.GetFrameAtIndex(i).GetPCAddress()``,
    for i from 0 up to ``thread.GetNumFrames() - 1``.
    """
    return [thread.GetFrameAtIndex(index).GetPCAddress()
            for index in range(thread.GetNumFrames())]

def print_stacktrace(result, target, thread):
    """Print one line per frame of *thread* to *result*.

    Each line shows the frame index and the PC's load address in *target*.
    Frames that have a function are labelled with the function name; the
    others are labelled with the symbol name plus the PC's offset from the
    symbol start, computed from file addresses.
    """
    frame_count = thread.GetNumFrames()
    pc_addresses = get_pc_addresses(thread)
    for index in range(frame_count):
        frame = thread.GetFrameAtIndex(index)
        function = frame.GetFunction()
        pc = pc_addresses[index]
        load_addr = pc.GetLoadAddress(target)

        if function:
            print >>result, ' frame #{num}: {addr:#016x} `{func}'.format(num=index, addr=load_addr, func=frame.GetFunctionName())
            continue

        # No function for this frame: fall back to the symbol and show how
        # far the PC is from the symbol's start.
        file_addr = pc.GetFileAddress()
        start_addr = frame.GetSymbol().GetStartAddress().GetFileAddress()
        symbol_offset = file_addr - start_addr
        print >>result, ' frame #{num}: {addr:#016x} `{symbol} + {offset}'.format(num=index, addr=load_addr, symbol=frame.GetSymbol().GetName(), offset=symbol_offset)

def trace(debugger, command, result, internal_dict):
wait_event = threading.Event()
wait_event.clear()
Expand Down Expand Up @@ -194,47 +251,80 @@ def trace(debugger, command, result, internal_dict):
while True:
if instrumented_frame == None:
instrumented_frame = InstrumentedFrame(target, thread, frame, result)
instrumented_frame.instrument_calls_and_jmps()

instrumented_frame.instrument_calls()
print >>result, 'Instrumented all calls, running the process'

success = continue_and_wait_for_breakpoint(process, thread, my_thread, wait_event, notify_event, result)
if not success:
print >>result, "Failed to continue+stop the process"
return
print >>result, 'Running the process'
# Continue running until next breakpoint is hit, _unless_ PC is already on a breakpoint address
if not instrumented_frame.is_stopped_on_call(frame) and not instrumented_frame.is_stopped_on_jmp(frame, True):
success = continue_and_wait_for_breakpoint(process, thread, my_thread, wait_event, notify_event, result)
if not success:
print >>result, "Failed to continue+stop the process"
break

frame = thread.GetFrameAtIndex(0)
print >>result, "==================================================================================="
print >>result, "=================== Stopped at: ===================="
print >>result, frame
print >>result, frame.GetSymbol()
print >>result, "0x%lx" % frame.GetPC()
print_stacktrace(result, target, thread)
print >>result, "===="

if len(instrumented_frames) > 0:
parent_instrumented_frame = instrumented_frames[-1]
else:
parent_instrumented_frame = None

success = instrumented_frame.is_stopped_on_call_and_instrument_return(frame)
if not success:
print >>result, "Failed to instrument call. Trying popping instrumented frames."
# Clear current frame of all (call) breakpoints
# Check for return from call first, then for call and finally for jmp.
# That way, we can be lenient about checking whether the frame saved
# in the jmp instrumented frame is still valid.
# This is difficult in case of optimised calls, where call instruction
# is replaced with:
# popq %rbp
# jmpq $destination
# (this optimisation is used in tail recursion optimisation and
# tail returns of the same type, where the compiler can squash
# one frame away)
# Since this optimisation pops %rbp (which then gets pushed in the
# preamble of $destination), at the time of jmp, the caller frame
# isn't really present. This has the effect of invalidating SBFrame
# stored by the current instrumented_frame.
# Taking the above into account, the best we can do is to check for
# return first and if that's not the case, we know we must be in the
# same logical frame, therefore when checking for jmps, it's enough
# to verify the address.
if parent_instrumented_frame != None and parent_instrumented_frame.is_stopped_on_return(frame):
print >>result, "Stopped on return, popping a frame"
instrumented_frame.clear()
while len(instrumented_frames) > 0:
instrumented_frame = instrumented_frames.pop()
print >>result, "Checking whether to pop a frame"
if instrumented_frame.is_stopped_on_return(frame):
print >>result, "Found return frame"
instrumented_frame.clear_return_breakpoint()
success = True
break
print >>result, "Clearing popped frame"
instrumented_frame.clear()
if success and len(instrumented_frames) > 0:
continue
else:
print >>result, "Run out of frames to pop, exiting"
instrumented_frame = instrumented_frames.pop()
instrumented_frame.clear_return_breakpoint()
if len(instrumented_frames) == 0:
print >>result, "Detected return from the function under trace, exiting"
break
instrumented_frame.instrument_calls_and_jmps()
elif instrumented_frame.is_stopped_on_call(frame):
print >>result, "Stopped on call"
success = instrumented_frame.clear_calls_and_jmps_and_instrument_return(frame)
if not success:
break
thread.StepInstruction(False)
instrumented_frames.append(instrumented_frame)
instrumented_frame = None
frame = thread.GetFrameAtIndex(0)
print >>result, 'Entered new frame at: 0x%lx' % frame.GetPC()
elif instrumented_frame.is_stopped_on_jmp(frame, False):
print >>result, "Stopped on jmp"
thread.StepInstruction(False)
frame = thread.GetFrameAtIndex(0)
instrumented_frame.update_frame(frame)
instrumented_frame.instrument_calls_and_jmps()
else:
print >>result, "Failed to detect return, call or jmp. Error exit"
break

thread.StepInstruction(False)
instrumented_frames.append(instrumented_frame)
instrumented_frame = None
frame = thread.GetFrameAtIndex(0)
stop_address = frame.GetPC()
print >>result, 'Entered new frame at: 0x%lx' % stop_address

# TODO: clear instrumented frames. On errors there
# may be breakpoints left; what needs to be worked out
# is whether instrumented_frame is set, and whether
# it needs clearing
my_thread.exit()
wait_event.set()
my_thread.join()
Expand Down

0 comments on commit 009970c

Please sign in to comment.