Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Restructuring] Wrong Value in Condition #93

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion decompiler/frontend/binaryninja/handlers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def lift_constant_pointer(self, pointer: mediumlevelil.MediumLevelILConstPtr, **
if symbol is not None and symbol.type in (SymbolType.ImportedFunctionSymbol, SymbolType.ExternalSymbol, SymbolType.FunctionSymbol):
return self._lift_symbol_pointer(address, symbol)

if not isinstance(pointer, mediumlevelil.MediumLevelILImport) and (symbol is None or symbol.type != SymbolType.DataSymbol) and (string := bv.get_string_at(address, partial=True) or bv.get_ascii_string_at(address, min_length=2)):
if (
not isinstance(pointer, mediumlevelil.MediumLevelILImport)
and (symbol is None or symbol.type != SymbolType.DataSymbol)
and (string := bv.get_string_at(address, partial=True) or bv.get_ascii_string_at(address, min_length=2))
):
return Constant(address, Pointer(Integer.char()), Constant(string.value, Integer.char()))

if (variable := bv.get_data_var_at(address)) is not None:
Expand Down
4 changes: 4 additions & 0 deletions decompiler/pipeline/controlflowanalysis/restructuring.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import logging
from typing import List, Optional

from decompiler.pipeline.commons.reaching_definitions import ReachingDefinitions
from decompiler.pipeline.controlflowanalysis.restructuring_commons.acyclic_restructuring import AcyclicRegionRestructurer
from decompiler.pipeline.controlflowanalysis.restructuring_commons.cyclic_restructuring import CyclicRegionStructurer
from decompiler.pipeline.controlflowanalysis.restructuring_commons.empty_basic_block_remover import EmptyBasicBlockRemover
from decompiler.pipeline.controlflowanalysis.restructuring_commons.side_effect_handling.side_effect_handler import SideEffectHandler
from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
Expand All @@ -33,6 +35,7 @@ def __init__(self, tcfg: Optional[TransitionCFG] = None, asforest: Optional[Abst
"""
self.t_cfg: TransitionCFG = tcfg
self.asforest: AbstractSyntaxForest = asforest
self._reaching_definitions: Optional[ReachingDefinitions] = None

def run(self, task: DecompilerTask):
"""
Expand All @@ -48,6 +51,7 @@ def run(self, task: DecompilerTask):
self.asforest.set_current_root(self.t_cfg.root.ast)
assert (roots := len(self.asforest.get_roots)) == 1, f"After the restructuring the forest should have one root, but it has {roots}!"
task._ast = AbstractSyntaxTree.from_asforest(self.asforest, self.asforest.current_root)
SideEffectHandler.resolve(task)
task._cfg = None

def restructure_cfg(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

from typing import Dict, Iterable, Optional, Type, Union

from decompiler.pipeline.controlflowanalysis.restructuring_commons.side_effect_handling.data_graph_visitor import (
ASTDataGraphVisitor,
SubtreeProperty,
)
from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, LoopNode, SwitchNode, TrueNode
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.graphs.basicblock import BasicBlock
from decompiler.structures.graphs.branches import BasicBlockEdge, FalseCase, SwitchCase, TrueCase, UnconditionalEdge
from decompiler.structures.graphs.cfg import ControlFlowGraph
from decompiler.structures.logic.logic_condition import LogicCondition
from decompiler.structures.pseudo import Assignment, Branch, Call, Condition, Constant, ListOperation, Variable
from networkx import MultiDiGraph


class DataNode(BasicBlock):
    """Basic block of the data graph that stands in for a single node of the abstract syntax tree."""

    def copy(self) -> DataNode:
        """Return a new DataNode with the same address and graph whose instructions are copies of the current ones."""
        copied_instructions = [instruction.copy() for instruction in self._instructions]
        return DataNode(self._address, copied_instructions, graph=self._graph)

    @staticmethod
    def _marker_call(keyword: str, argument) -> list:
        """Return a one-element instruction list holding the pseudo call `keyword(argument)` without a return value."""
        return [Assignment(ListOperation([]), Call(Variable(keyword), [argument]))]

    @classmethod
    def generate_node_from(
        cls, idx: int, ast_node: AbstractSyntaxTreeNode, condition_map: Dict[LogicCondition, Condition]
    ) -> Optional[DataNode]:
        """
        Create the data-graph node that represents the given AST node.

        :param idx: address (name) of the node that is created.
        :param ast_node: AST node to translate.
        :param condition_map: maps the symbols of logic conditions to the conditions they stand for.
        :return: a DataNode for code, switch and case nodes, a LogicNode for condition and loop nodes,
                 and None for every other kind of AST node.
        """
        if isinstance(ast_node, CodeNode):
            # The instruction list of the code node is handed over as is (no copy).
            return DataNode(idx, ast_node.instructions)
        if isinstance(ast_node, ConditionNode):
            return LogicNode(idx, ast_node, condition_map)
        if isinstance(ast_node, SwitchNode):
            return DataNode(idx, cls._marker_call("switch", ast_node.expression))
        if isinstance(ast_node, LoopNode):
            return LogicNode(idx, ast_node, condition_map)
        if isinstance(ast_node, CaseNode):
            label = ast_node.constant
            if not isinstance(label, Constant):
                label = Constant(label)
            return DataNode(idx, cls._marker_call("case", label))
        return None


class LogicNode(DataNode):
    """Data-graph node for the condition of a condition node or a loop node of the abstract syntax tree."""

    def __init__(
        self,
        name: int,
        ast_node: Union[LoopNode, ConditionNode],
        condition_map: Dict[LogicCondition, Condition],
        graph: Optional[ControlFlowGraph] = None,
    ):
        """
        Create a logic node whose instructions are one Branch per symbol of the AST node's condition.

        :param name: address (name) of the node.
        :param ast_node: loop or condition node whose condition is represented.
        :param condition_map: maps each symbol of the condition to the condition it stands for.
        :param graph: graph the node belongs to, if any.
        """
        super().__init__(name, [Branch(condition_map[symbol]) for symbol in ast_node.condition.get_symbols()], graph)
        self._logic_condition = ast_node.condition
        self._condition_map = condition_map
        # Keyword used by __str__: the loop type's value for loops, "if" for condition nodes.
        self._type = ast_node.loop_type.value if isinstance(ast_node, LoopNode) else "if"

    def __str__(self) -> str:
        """Return the node as `<type>(<condition>)`, where the condition is rendered using the condition map."""
        return f"{self._type}({self._logic_condition.rich_string_representation(self._condition_map)})"

    def __eq__(self, other: object) -> bool:
        """Two logic nodes are equal if and only if they have the same address."""
        return isinstance(other, LogicNode) and self._address == other._address

    def __hash__(self) -> int:
        """
        Logic nodes should hash the same even when they are in different graphs.

        Since addresses are supposed to be unique,
        they are used for hashing in order to identify the same node with different instructions as equal.
        """
        return hash(self._address)

    @property
    def logic_condition(self) -> LogicCondition:
        """Return the logic condition of the AST node this node was created from."""
        return self._logic_condition

    def copy(self) -> LogicNode:
        """
        Return a copy of the node with copied instructions.

        The constructor expects the originating AST node, which is not stored, so it cannot be used here
        (handing it the logic condition instead fails on `ast_node.condition`). Therefore the instance is
        created without running LogicNode.__init__, and only the base-class initializer is invoked.
        The logic condition and the condition map are shared with the original node.
        """
        duplicate = LogicNode.__new__(LogicNode)
        super(LogicNode, duplicate).__init__(
            self._address, [instruction.copy() for instruction in self._instructions], graph=self._graph
        )
        duplicate._logic_condition = self._logic_condition
        duplicate._condition_map = self._condition_map
        duplicate._type = self._type
        return duplicate


class DataGraph(ControlFlowGraph):
    """Control flow graph whose nodes are the data nodes generated from the nodes of an abstract syntax tree."""

    NODE = DataNode
    EDGE = BasicBlockEdge

    def __init__(self, graph: Optional[MultiDiGraph] = None, root: Optional[NODE] = None):
        """
        Set up the graph and an empty translation dictionary.

        - translation_dict maps the AST-nodes to the data-nodes of the graph
        """
        super().__init__(graph, root)
        self._translation_dict: Dict[AbstractSyntaxTreeNode, DataNode] = {}

    @classmethod
    def generate_from_ast(cls, ast: AbstractSyntaxTree) -> DataGraph:
        """Build a data graph for the given syntax tree: first all nodes, then the edges between them."""
        instance = cls()
        instance.generate_edges(ast, instance.generate_nodes(ast))
        return instance

    def generate_nodes(self, ast: AbstractSyntaxTree) -> Dict[AbstractSyntaxTreeNode, SubtreeProperty]:
        """
        Add a data node for every AST node that has one and return the subtree properties of all visited AST nodes.

        The AST is traversed in post order. Created nodes are numbered consecutively starting at 0.
        """
        visitor: ASTDataGraphVisitor = ASTDataGraphVisitor()
        next_index = 0
        for current in ast.post_order():
            created = DataNode.generate_node_from(next_index, current, ast.condition_map)
            if created is not None:
                self.add_node(created)
                self._translation_dict[current] = created
                # Only consume an index when a node was actually created.
                next_index += 1
            visitor.visit(current)
        return visitor.property_dict

    def generate_edges(self, ast: AbstractSyntaxTree, property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty]) -> None:
        """Add the edges of the data graph, handling sequence, condition, loop and switch nodes in this order."""
        self._connect_sequences(ast, property_dict)
        self._connect_conditions(ast, property_dict)
        self._connect_loops(ast, property_dict)
        self._connect_switches(ast, property_dict)

    def _connect_sequences(self, ast: AbstractSyntaxTree, property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty]) -> None:
        """Connect the last nodes of every child of a sequence node with the first node of the following child."""
        for sequence in ast.get_sequence_nodes_post_order():
            for predecessor, successor in zip(sequence.children[:-1], sequence.children[1:]):
                self._add_edges_between(property_dict[predecessor].last_nodes, {property_dict[successor].first_node})

    def _connect_conditions(self, ast: AbstractSyntaxTree, property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty]) -> None:
        """Connect every condition node with the first node of each branch via a true- resp. false-edge."""
        for condition in ast.get_condition_nodes_post_order():
            for branch in condition.children:
                if isinstance(branch, TrueNode):
                    branch_edge = TrueCase
                else:
                    branch_edge = FalseCase
                self._add_edges_between({condition}, {property_dict[branch.child].first_node}, branch_edge)

    def _connect_loops(self, ast: AbstractSyntaxTree, property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty]) -> None:
        """Connect the continue- and last-nodes of each loop body with the loop node, and the loop node with the body."""
        for loop in ast.get_loop_nodes_post_order():
            back_edge_sources = property_dict[loop.body].continue_nodes | property_dict[loop.body].last_nodes
            self._add_edges_between(back_edge_sources, {loop})
            self._add_edges_between({loop}, {property_dict[loop.body].first_node}, TrueCase)

    def _connect_switches(self, ast: AbstractSyntaxTree, property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty]) -> None:
        """Connect each switch node with its cases, fall-through cases with their successor, and cases with their body."""
        for switch in ast.get_switch_nodes_post_order():
            self._add_edges_between({switch}, switch.children, SwitchCase)
            for current_case, next_case in zip(switch.children[:-1], switch.children[1:]):
                # A case that ends with a break does not fall through to the next case.
                if not current_case.break_case:
                    self._add_edges_between(property_dict[current_case].last_nodes, {property_dict[next_case].first_node})
            # edges for case_nodes
            for case in switch.children:
                self._add_edges_between({case}, {property_dict[case.child].first_node})

    def _add_edges_between(
        self,
        sources: Iterable[AbstractSyntaxTreeNode],
        sinks: Iterable[AbstractSyntaxTreeNode],
        edge_type: Type[BasicBlockEdge] = UnconditionalEdge,
    ):
        """
        Add an edge from the data node of every source to the data node of every sink.

        - Switch-case edges are annotated with the constant of the sink, which must be a case node.
        - An unconditional edge leaving a condition or loop node is added as a false-edge instead.
        """
        for start in sources:
            for end in sinks:
                if edge_type == SwitchCase:
                    assert isinstance(end, CaseNode)
                    edge = SwitchCase(self._translation_dict[start], self._translation_dict[end], [end.constant])
                elif edge_type == UnconditionalEdge and isinstance(start, (ConditionNode, LoopNode)):
                    edge = FalseCase(self._translation_dict[start], self._translation_dict[end])
                else:
                    edge = edge_type(self._translation_dict[start], self._translation_dict[end])
                self.add_edge(edge)

    def get_logic_nodes(self) -> Iterable[LogicNode]:
        """Yield all logic nodes of the data-graph"""
        yield from (candidate for candidate in self.nodes if isinstance(candidate, LogicNode))
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Set

from decompiler.structures.ast.ast_nodes import (
AbstractSyntaxTreeNode,
CaseNode,
CodeNode,
ConditionNode,
DoWhileLoopNode,
FalseNode,
LoopNode,
SeqNode,
SwitchNode,
TrueNode,
VirtualRootNode,
)
from decompiler.structures.visitors.interfaces import ASTVisitorInterface


@dataclass
class SubtreeProperty:
    """Summary of an AST subtree as computed by the ASTDataGraphVisitor."""

    # AST node the summary belongs to.
    root: AbstractSyntaxTreeNode
    # Node that is used as target for edges leading into the subtree.
    first_node: AbstractSyntaxTreeNode
    # Nodes that are used as sources for edges leaving the subtree.
    last_nodes: Set[AbstractSyntaxTreeNode] = field(default_factory=set)
    # Nodes of the subtree collected as ending with a continue.
    continue_nodes: Set[AbstractSyntaxTreeNode] = field(default_factory=set)
    # Nodes of the subtree collected as ending with a break.
    break_nodes: Set[AbstractSyntaxTreeNode] = field(default_factory=set)


class ASTDataGraphVisitor(ASTVisitorInterface):
    """
    Visitor that computes a SubtreeProperty for the visited AST nodes.

    Each visit method looks up the properties of the node's children,
    so the children of a node have to be visited before the node itself.
    """

    def __init__(self):
        """Start with an empty mapping from AST nodes to their subtree properties."""
        super().__init__()
        self._property_dict: Dict[AbstractSyntaxTreeNode, SubtreeProperty] = {}

    @property
    def property_dict(self) -> Dict[AbstractSyntaxTreeNode, SubtreeProperty]:
        """Return the subtree properties computed so far."""
        return self._property_dict

    def _lookup(self, node: AbstractSyntaxTreeNode) -> SubtreeProperty:
        """Return the already computed property of the given node."""
        return self._property_dict[node]

    def _store(self, node: AbstractSyntaxTreeNode, first_node: AbstractSyntaxTreeNode, last_nodes: Set, continue_nodes: Set, break_nodes: Set) -> None:
        """Save the property of the subtree rooted at the given node."""
        self._property_dict[node] = SubtreeProperty(node, first_node, last_nodes, continue_nodes, break_nodes)

    def visit_seq_node(self, node: SeqNode) -> None:
        """A sequence starts where its first child starts and ends where its last child ends."""
        entry = self._lookup(node.children[0]).first_node
        exits = self._lookup(node.children[-1]).last_nodes
        continues = set()
        for child in node.children:
            continues |= self._lookup(child).continue_nodes
        breaks = set()
        for child in node.children:
            breaks |= self._lookup(child).break_nodes
        self._store(node, entry, exits, continues, breaks)

    def visit_code_node(self, node: CodeNode) -> None:
        """A code node is its own first node; it is a last node unless it ends with a break."""
        # NOTE(review): a node ending with continue is still collected as last node - confirm that this is intended.
        exits = {node}
        if node.does_end_with_break:
            exits = set()
        continues = set()
        if node.does_end_with_continue:
            continues = {node}
        breaks = set()
        if node.does_end_with_break:
            breaks = {node}
        self._store(node, node, exits, continues, breaks)

    def visit_condition_node(self, node: ConditionNode) -> None:
        """A condition node is its own first node; without a false branch it is also one of its last nodes."""
        exits = set()
        if node.false_branch_child is None:
            exits.add(node)
        for branch in node.children:
            exits |= self._lookup(branch.child).last_nodes
        continues = set()
        for branch in node.children:
            continues |= self._lookup(branch.child).continue_nodes
        breaks = set()
        for branch in node.children:
            breaks |= self._lookup(branch.child).break_nodes
        self._store(node, node, exits, continues, breaks)

    def visit_loop_node(self, node: LoopNode) -> None:
        """A loop is left via the loop node itself (unless it is endless) and via the break nodes of its body."""
        if isinstance(node, DoWhileLoopNode):
            # The body of a do-while loop is entered before the loop condition.
            entry = self._lookup(node.body).first_node
        else:
            entry = node
        exits = set()
        if not node.is_endless_loop:
            exits.add(node)
        exits |= self._lookup(node.body).break_nodes
        # Continue and break nodes of the body are not passed on to the parent of the loop.
        self._store(node, entry, exits, set(), set())

    def visit_switch_node(self, node: SwitchNode) -> None:
        """A switch is left via the last nodes of its default case and of all cases that end with a break."""
        if node.default:
            default_exits = self._lookup(node.default).last_nodes
        else:
            default_exits = set()
        exits = set(default_exits)
        for case in node.cases:
            if case.break_case:
                exits |= self._lookup(case).last_nodes
        continues = set()
        for case in node.children:
            continues |= self._lookup(case).continue_nodes
        self._store(node, node, exits, continues, set())

    def visit_case_node(self, node: CaseNode) -> None:
        """A case node is its own first node and takes over the last and continue nodes of its child."""
        body_property = self._lookup(node.child)
        breaks = set()
        if node.break_case:
            breaks = {node}
        self._store(node, node, body_property.last_nodes, body_property.continue_nodes, breaks)

    def visit_true_node(self, node: TrueNode) -> None:
        """No property is computed for true nodes."""

    def visit_false_node(self, node: FalseNode) -> None:
        """No property is computed for false nodes."""

    def visit_root_node(self, node: VirtualRootNode) -> None:
        """No property is computed for the virtual root node."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import Optional

from decompiler.pipeline.commons.reaching_definitions import ReachingDefinitions
from decompiler.pipeline.controlflowanalysis.restructuring_commons.side_effect_handling.data_graph import DataGraph
from decompiler.structures.ast.syntaxtree import AbstractSyntaxTree
from decompiler.structures.graphs.cfg import ControlFlowGraph
from decompiler.task import DecompilerTask


class SideEffectHandler:
    """
    Work-in-progress handler for side effects on the conditions of a restructured abstract syntax tree.

    Currently, it only builds the data graph of the syntax tree and computes reaching definitions on it.
    """

    def __init__(self, ast: AbstractSyntaxTree, cfg: ControlFlowGraph, data_graph: DataGraph):
        """
        Store the structures the handler works on.

        :param ast: syntax tree whose conditions are checked.
        :param cfg: control flow graph of the task.
        :param data_graph: data graph generated from the syntax tree.
        """
        self._ast: AbstractSyntaxTree = ast
        self._cfg: ControlFlowGraph = cfg
        self._data_graph: DataGraph = data_graph

    @classmethod
    def resolve(cls, task: DecompilerTask) -> None:
        """Generate the data graph for the syntax tree of the given task and apply the side effect handling to it."""
        data_graph = DataGraph.generate_from_ast(task.syntax_tree)
        side_effect_handler = cls(task.syntax_tree, task.graph, data_graph)
        side_effect_handler.apply()

    def apply(self) -> None:
        """
        Compute the information needed to detect side effects.

        The detection itself is not implemented yet (see the sketch below). Both values computed here are
        only consumed by that sketch.
        """
        reaching_definitions = ReachingDefinitions(self._data_graph)
        all_defined_ssa_variables = {var.ssa_name for node in self._data_graph.nodes for var in node.definitions}
        # TODO: What about variables defined via phi-functions
        # TODO: How about added variables without SSA-values?
        # Sketch of the planned check:
        # for logic_node in self._data_graph.get_logic_nodes():
        #     definitions: Dict[Variable, List[Instruction]] = dict()
        #     for instruction in reaching_definitions.reach_in_block(logic_node):
        #         for definition in instruction.definitions:
        #             definitions[definition] = definitions.get(definition, list()) + [instruction]
        #     for symbol in logic_node.logic_condition.get_symbols():
        #         pseudo_condition = self._ast.condition_map[symbol]
        #         for used_variable in pseudo_condition.requirements:
        #             if used_variable not in all_defined_ssa_variables:
        #                 continue
        #             for def_instruction in definitions[used_variable]:
        #                 definition = [def_var for def_var in def_instruction.definitions if def_var == used_variable][0]
        #                 if used_variable.ssa_name != definition.ssa_name:
        #                     raise NotImplementedError("We have to handle side effects!")
2 changes: 1 addition & 1 deletion decompiler/util/decoration.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _format_label(node: BasicBlock) -> str:
if node.instructions is None:
instructions_left_aligned = " "
else:
instructions_left_aligned = "\n".join(map(str, node.instructions))
instructions_left_aligned = str(node)
return f"{node.name}.\n{instructions_left_aligned}"

def export_flowgraph(self) -> FlowGraph:
Expand Down
Loading