-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
931cad2
commit 0f70042
Showing
23 changed files
with
18,367 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .model import BERT | ||
from .dataloader import AsmLMTokenizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
import pickle | ||
from glob import glob | ||
from tqdm import tqdm | ||
from multiprocessing import Manager, Pool | ||
from .AsmTokenizer import AsmLMTokenizer | ||
|
||
class AsmLMDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes pickled function records in parallel.

    Each ``*.pkl`` file under ``dataset_path`` holds a list of function
    records; every record is encoded with ``AsmLMTokenizer.encode_func`` and
    stored in a ``multiprocessing.Manager`` list so the worker processes of
    the loading pool can all append to it.
    """

    def __init__(self, dataset_path, vocab_paths, n_workers=48):
        """Build the tokenizer, then eagerly load/tokenize the whole dataset.

        Args:
            dataset_path: directory containing the ``*.pkl`` files.
            vocab_paths: passed through to ``AsmLMTokenizer``.
            n_workers: size of the loading process pool.
        """
        self.n_workers = n_workers
        print('Initializing tokenizer...')
        self.tokenizer = AsmLMTokenizer(vocab_paths)

        print('Loading dataset...')
        self.dataset_path = dataset_path
        # Manager-backed list: shared across the Pool's worker processes.
        self.datas = Manager().list()
        self.load_data()

    def __getitem__(self, idx):  # random access
        # Convert each encoded field to a tensor lazily, per sample.
        return {key: torch.tensor(value) for key, value in self.datas[idx].items()}

    def __len__(self):
        return len(self.datas)

    def worker_load_data(self, pkl_file):
        """Tokenize every function in one pickle and append to the shared list."""
        with open(pkl_file, 'rb') as f:
            funcs = pickle.load(f)
        # NOTE(review): when parallel, consider spilling to tmpfs (original TODO).
        encoded = [self.tokenizer.encode_func(func_info) for func_info in tqdm(funcs)]
        # One extend per file keeps round-trips to the Manager proxy low.
        self.datas.extend(encoded)

    def load_data(self):
        """Fan the pickle files out to a worker pool and wait for completion."""
        # glob already returns a list; no need for an append loop.
        input_list = glob('{}/*.pkl'.format(self.dataset_path))
        # Context manager guarantees the pool is terminated (the original
        # called close() but never join()).
        with Pool(processes=self.n_workers) as pool:
            pool.map(self.worker_load_data, input_list)

        print(len(self.datas))
class AsmLMPreGenDataset(torch.utils.data.Dataset):
    """Dataset over pre-generated (already tokenized) samples.

    Loading is not implemented yet (``load_data`` is a stub), so instances
    are empty until that TODO is filled in.
    """

    def __init__(self, dataset_path, vocab_paths):
        print('Initializing tokenizer...')
        self.tokenizer = AsmLMTokenizer(vocab_paths)

        print('Loading dataset...')
        self.dataset_path = dataset_path
        self.datas = []
        self.load_data()

    def __getitem__(self, idx):  # random access
        # BUG FIX: the original had `self.datas[idx]` without `return`,
        # so every lookup silently yielded None.
        return self.datas[idx]

    def __len__(self):
        return len(self.datas)

    def load_data(self):
        # TODO: load pre-generated data from self.dataset_path
        pass
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
|
||
class AsmLMVocab(object):
    """Vocabulary with five fixed special tokens followed by corpus tokens.

    Special-token ids are cached as attributes; ``word2id``/``id2word`` give
    O(1) lookups in both directions.
    """

    def __init__(self, vocab=None):
        """Args:
            vocab: optional list of extra tokens appended after the specials.
                BUG FIX: the original used a mutable default (``vocab=[]``),
                which is shared across calls and can be mutated in place.
        """
        extra = [] if vocab is None else list(vocab)
        # NOTE: [PAD] must be the first word to generate padding masks for attention.
        self.vocab = ['[PAD]', '[SEP]', '[CLS]', '[UNK]', '[MASK]'] + extra
        self.pad_id = 0
        self.sep_id = 1
        self.cls_id = 2
        self.unk_id = 3
        self.mask_id = 4
        self._rebuild_index()

    def _rebuild_index(self):
        """Rebuild the forward/reverse lookup tables from ``self.vocab``."""
        self.word2id = {token: idx for idx, token in enumerate(self.vocab)}
        self.id2word = {idx: token for idx, token in enumerate(self.vocab)}

    def __len__(self):
        return len(self.vocab)

    def __getitem__(self, idx):
        return self.vocab[idx]

    def get_id(self, token):
        """Return the token's id, or ``unk_id`` for out-of-vocabulary tokens."""
        # O(1) dict lookup; the original scanned the vocab list, O(n) per call.
        return self.word2id.get(token, self.unk_id)

    def save(self, vocab_path):
        """Write one token per line."""
        with open(vocab_path, 'w') as f:
            f.write('\n'.join(self.vocab))

    def load(self, vocab_path):
        """Load tokens from file and recompute special-token ids and indexes."""
        with open(vocab_path, 'r') as f:
            self.vocab = f.read().split('\n')
        self.pad_id = self.vocab.index('[PAD]')
        self.sep_id = self.vocab.index('[SEP]')
        self.cls_id = self.vocab.index('[CLS]')
        self.unk_id = self.vocab.index('[UNK]')
        self.mask_id = self.vocab.index('[MASK]')
        self._rebuild_index()
|
||
# if __name__ == '__main__': | ||
# import os | ||
# dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
# vocab = AsmLMVocab() | ||
# vocab.load(os.path.join(dir_path, 'vocab.txt')) | ||
# print(vocab) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .AsmLMDataset import AsmLMDataset, AsmLMPreGenDataset | ||
from .AsmTokenizer import AsmLMTokenizer | ||
from .AsmVocab import AsmLMVocab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
[PAD] | ||
[SEP] | ||
[CLS] | ||
[UNK] | ||
[MASK] | ||
MOD_AF,MOD_CF,MOD_SF,MOD_ZF | ||
MOD_AF,MOD_CF,MOD_SF,MOD_ZF,MOD_PF | ||
MOD_AF,MOD_CF,MOD_SF,MOD_ZF,MOD_PF,MOD_OF | ||
MOD_AF,MOD_CF,MOD_SF,MOD_ZF,MOD_PF,MOD_OF,MOD_TF,MOD_IF,MOD_DF,MOD_NT,MOD_RF | ||
MOD_AF,MOD_CF,MOD_SF,MOD_ZF,MOD_PF,MOD_OF,TEST_CF | ||
MOD_AF,MOD_SF,MOD_ZF,MOD_OF | ||
MOD_AF,MOD_SF,MOD_ZF,MOD_PF,MOD_OF | ||
MOD_AF,MOD_ZF,MOD_PF | ||
MOD_CF | ||
MOD_CF,MOD_OF,UNDEF_SF,UNDEF_ZF,UNDEF_PF,UNDEF_AF | ||
MOD_CF,MOD_SF,MOD_OF,UNDEF_ZF,UNDEF_PF,UNDEF_AF | ||
MOD_CF,MOD_SF,MOD_ZF,MOD_PF,MOD_OF,UNDEF_AF | ||
MOD_CF,MOD_SF,MOD_ZF,MOD_PF,UNDEF_OF,UNDEF_AF | ||
MOD_CF,MOD_SF,MOD_ZF,RESET_OF,UNDEF_PF,UNDEF_AF | ||
MOD_CF,MOD_SF,PRIOR_SF,PRIOR_PF | ||
MOD_CF,MOD_SF,RESET_OF,UNDEF_PF,UNDEF_AF | ||
MOD_CF,MOD_ZF,MOD_PF,RESET_OF,RESET_SF,RESET_AF | ||
MOD_CF,MOD_ZF,RESET_SF,RESET_AF,RESET_PF | ||
MOD_CF,MOD_ZF,UNDEF_OF,UNDEF_SF,UNDEF_PF,UNDEF_AF | ||
MOD_CF,PRIOR_SF,PRIOR_AF,PRIOR_PF | ||
MOD_CF,RESET_OF,RESET_SF,RESET_AF,RESET_PF | ||
MOD_CF,TEST_CF | ||
MOD_CF,UNDEF_OF | ||
MOD_CF,UNDEF_OF,UNDEF_SF,UNDEF_PF,UNDEF_AF | ||
MOD_OF | ||
MOD_OF,PRIOR_SF,PRIOR_AF,PRIOR_PF | ||
MOD_SF,MOD_ZF,MOD_PF,RESET_OF,RESET_CF,UNDEF_AF | ||
MOD_SF,MOD_ZF,RESET_OF,RESET_CF,UNDEF_PF,UNDEF_AF | ||
MOD_TF,MOD_IF,MOD_NT,MOD_RF | ||
MOD_ZF | ||
MOD_ZF,RESET_OF,RESET_CF,RESET_SF,RESET_AF,RESET_PF | ||
MOD_ZF,UNDEF_OF,UNDEF_SF,UNDEF_PF,UNDEF_AF | ||
MOD_ZF,UNDEF_OF,UNDEF_SF,UNDEF_PF,UNDEF_AF,UNDEF_CF | ||
NULL | ||
PRIOR_SF,PRIOR_ZF,PRIOR_AF,PRIOR_PF | ||
RESET_CF | ||
RESET_DF | ||
RESET_IF | ||
SET_CF | ||
SET_DF | ||
SET_IF | ||
TEST_CF | ||
TEST_DF | ||
TEST_OF | ||
TEST_OF,TEST_SF | ||
TEST_OF,TEST_SF,TEST_ZF | ||
TEST_PF | ||
TEST_SF | ||
TEST_ZF | ||
TEST_ZF,TEST_CF | ||
UNDEF_OF,UNDEF_SF,UNDEF_ZF,UNDEF_PF,UNDEF_AF,UNDEF_CF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
''' | ||
tokenize & save tokenized insns | ||
build vocab | ||
''' | ||
import os | ||
import re | ||
import pickle | ||
from glob import glob | ||
from tqdm import tqdm | ||
from concurrent.futures import ProcessPoolExecutor | ||
import numpy as np | ||
import argparse | ||
from AsmVocab import AsmLMVocab | ||
|
||
# Directory containing this script; merged vocab files are written next to it.
dir_path = os.path.dirname(os.path.realpath(__file__))

# CLI: scan IDA-produced pickles and build per-field vocabularies.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, help='input dir to pkl generated by IDA')
parser.add_argument('-o', '--output_dir', type=str, nargs='?',
                    help='Output dir', default='./preprocess-outputs')
parser.add_argument('-n', '--workers', type=int, nargs='?',
                    help='Max Workers', default=48)
args = parser.parse_args()
|
||
|
||
|
||
def normalize_insn(asm):
    """Split one tab-separated ``opcode<TAB>operands`` line into tokens.

    Spacing around +, -, *, : is collapsed and each hex literal is replaced
    by the placeholder 'const', so equivalent instructions share tokens.

    Returns:
        (opcode, opnd_strs): the opcode string and a list of normalized
        operand strings (empty when the instruction has no operands).
    """
    opcode, op_str = asm.split('\t')

    # Collapse spaced operators inside operand expressions.
    for spaced, compact in ((' + ', '+'), (' - ', '-'), (' * ', '*'), (' : ', ':')):
        op_str = op_str.replace(spaced, compact)

    # Abstract away concrete immediates / addresses.
    op_str = re.sub('0x[0-9a-f]+', 'const', op_str)

    opnd_strs = op_str.split(', ') if op_str else []
    return opcode, opnd_strs
|
||
|
||
def gen_vocab_per_file(pkl_file, output_dir):
    """Scan one IDA-produced pickle and emit its partial vocabulary files.

    For every instruction the normalized tokens, itype, eflags string and the
    per-operand fields (type, reg id, read regs, written regs) are collected
    into sets, then each set is written space-joined to
    ``{output_dir}/{basename}.{vocab_type}``.
    """
    # One set per vocab type; dict order fixes the output-file order.
    vocabs = {
        'token_vocab': set(),
        'itype_vocab': set(),
        'opnd_type_vocab': set(),
        'reg_id_vocab': set(),
        'reg_r_vocab': set(),
        'reg_w_vocab': set(),
        'eflags_vocab': set(),
    }

    with open(pkl_file, 'rb') as f:
        funcs = pickle.load(f)

    for func_info in tqdm(funcs):
        for insn_addr, disasm, itype, op_info, eflags in func_info:
            if not disasm:
                continue

            opcode, opnd_strs = normalize_insn(disasm)

            # Opcode words plus comma-separated operand words, as one token list.
            insn_tokens = opcode.split() + ' , '.join(opnd_strs).split()

            vocabs['token_vocab'].update(insn_tokens)
            vocabs['itype_vocab'].add(str(itype))
            vocabs['eflags_vocab'].add(str(eflags))

            # op_info is parallel to opnd_strs; record each operand's fields.
            for i in range(len(opnd_strs)):
                vocabs['opnd_type_vocab'].add(str(op_info[i][0]))
                vocabs['reg_id_vocab'].add(str(op_info[i][1]))
                vocabs['reg_r_vocab'].add(str(op_info[i][2]))
                vocabs['reg_w_vocab'].add(str(op_info[i][3]))

    # os.path.basename instead of the fragile split(os.path.sep)[-1] idiom;
    # one write loop replaces seven copy-pasted with-blocks.
    out_name = os.path.basename(pkl_file)
    for vocab_type, words in vocabs.items():
        with open(f'{output_dir}/{out_name}.{vocab_type}', 'w') as f:
            f.write(' '.join(words))
|
||
|
||
def merge_vocab(vocab_type):
    """Union all per-file partial vocabs of one type into a sorted AsmLMVocab.

    Reads every ``{args.output_dir}/*.{vocab_type}`` file, prints the merged
    size, and saves the result as ``{dir_path}/{vocab_type}.txt``.
    """
    merged = set()
    pattern = '{}/*.{}'.format(args.output_dir, vocab_type)
    for vocab_file in tqdm(glob(pattern)):
        with open(vocab_file, 'r') as fh:
            merged |= set(fh.read().split())

    print(len(merged))
    final = AsmLMVocab(sorted(merged))
    final.save(f'{dir_path}/{vocab_type}.txt')
|
||
|
||
def gen_vocab(dataset_path):
    """Build all vocabularies from the pickles under ``dataset_path``.

    Phase 1 fans ``gen_vocab_per_file`` out over a process pool; phase 2
    merges the per-file partial vocabs into one file per vocab type.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # BUG FIX: the original discarded the futures, so any exception in a
        # worker was silently swallowed and merging ran over partial output.
        futures = [executor.submit(gen_vocab_per_file, pkl_file, args.output_dir)
                   for pkl_file in tqdm(glob('{}/*.pkl'.format(dataset_path)))]
        for future in futures:
            future.result()  # re-raise worker exceptions here

    for vocab_type in ('token_vocab', 'itype_vocab', 'opnd_type_vocab',
                       'reg_id_vocab', 'reg_r_vocab', 'reg_w_vocab',
                       'eflags_vocab'):
        merge_vocab(vocab_type)
|
||
|
||
|
||
if __name__ == '__main__':
    import shutil

    gen_vocab(args.input_dir)
    # Clean up the intermediate per-file vocab outputs once the merged vocab
    # files are written. shutil.rmtree replaces the original
    # os.system(f'rm -rf {args.output_dir}'), which interpolated an unquoted
    # path into a shell command (breaks on spaces, shell-injection-prone).
    shutil.rmtree(args.output_dir, ignore_errors=True)
Oops, something went wrong.