-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
119 lines (102 loc) · 3.38 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!(CUDA_VISIBLE_DEVICES=-1)
from param import *
from data_iterator import MyDataset, MyIterator
from model_utils import make_model
from my_decode import greedy_decode
from torchtext import data, datasets
import torch
import pandas as pd
import numpy as np
import os
# Special vocabulary tokens shared by the source and target fields.
INS_SPLIT = '<nop>'
BLANK_WORD = '<blank>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'


def tokenize(text):
    """Split an already space-separated token string into a token list."""
    return text.split(' ')


# torchtext fields: the target side additionally wraps each sequence
# in <s> ... </s> markers for the decoder.
SRC = data.Field(sequential=True, tokenize=tokenize,
                 pad_token=BLANK_WORD, lower=True)
TGT = data.Field(sequential=True, tokenize=tokenize,
                 init_token=BOS_WORD, eos_token=EOS_WORD,
                 pad_token=BLANK_WORD, lower=True)

# Load both splits, but build vocabularies from the training split only.
train = MyDataset(datafile=TRAIN_FILE, asm_field=SRC, ast_field=TGT)
test = MyDataset(datafile=TEST_FILE, asm_field=SRC, ast_field=TGT)
for vocab_field in (SRC, TGT):
    vocab_field.build_vocab(train)

# Frequently used vocabulary indices (padding and instruction separator).
src_pad_idx = SRC.vocab.stoi[BLANK_WORD]
tgt_pad_idx = TGT.vocab.stoi[BLANK_WORD]
split_idx = SRC.vocab.stoi[INS_SPLIT]
# Rebuild the network with the same hyper-parameters it was trained with,
# then restore the checkpoint weights on CPU.
print("Loading model...")
src_vocab_size = len(SRC.vocab)
tgt_vocab_size = len(TGT.vocab)
model = make_model(
    src_vocab_size,
    tgt_vocab_size,
    src_token_len=SRC_TOKEN_LEN,
    token=SRC_TOKEN,
    ins_pad=split_idx,
    pad_idx=src_pad_idx,
    N=LAYER_NUM,
    d_model=D_MODEL,
    h=H,
)
checkpoint = torch.load('model-1.pt', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)

# Single-pass iterator over the test split, keyed by instruction count.
test_iter = MyIterator(
    test,
    batch_size=BATCH_SIZE,
    repeat=False,
    sort_key=lambda example: example.src.count(INS_SPLIT),
    train=False,
)

# Column names for the CSV row appended per decoded example.
field = ["asm_length", "ast_length", "asm", "target", "translation"]
count = 0
# Greedily decode each test batch, print source/target/translation,
# append one row to translation.csv, then visualize and stop after the
# first batch (trailing `break`).
for i, batch in enumerate(test_iter):
    # batch.src arrives (seq_len, batch); transpose and keep only the
    # first example of the batch.
    src = batch.src.transpose(0, 1)[:1]
    shape = src.shape
    # Group every SRC_TOKEN consecutive positions into one instruction
    # slot; an instruction is "real" if any of its tokens is non-padding.
    tmp_src_mask = (src != src_pad_idx).unsqueeze(-2).reshape([shape[0], int(shape[1]/SRC_TOKEN), SRC_TOKEN]).sum(dim=-1)
    mask = (tmp_src_mask != 0).unsqueeze(-2)
    out = greedy_decode(model, src, mask,
                        max_len=MAX_LEN,
                        start_symbol=TGT.vocab.stoi["<s>"])
    print("Translation:", end="\t")
    trans = []
    # Skip position 0 (the <s> start symbol); stop at the first </s>.
    for j in range(1, out.size(1)):
        # print(out[0,i])
        sym = TGT.vocab.itos[out[0, j]]
        if sym == "</s>":
            trans.append(sym)
            print("</s>")
            break
        print(sym, end=" ")
        trans.append(sym)
    print()
    print("Target:", end="\t")
    target = []
    # batch.trg is (seq_len, batch); column 0 is the same example as src.
    # NOTE(review): unlike the translation loop, </s> itself is appended
    # but never printed here.
    for j in range(1, batch.trg.size(0)):
        sym = TGT.vocab.itos[batch.trg.data[j, 0]]
        if sym == "</s>":
            target.append(sym)
            break
        print(sym, end=" ")
        target.append(sym)
    print()
    print()
    # Recover the raw source tokens, dropping padding and <nop> separators.
    asm = []
    for index in src[0]:
        if index == src_pad_idx:
            break
        if index == split_idx:
            continue
        asm.append(SRC.vocab.itos[index])
    print(asm)
    # NOTE(review): the hard-coded 8 here presumably equals SRC_TOKEN
    # (tokens per instruction) — confirm against param.py.
    dt = [[int(len(asm)/8), len(target), ' '.join(asm), ' '.join(target), ' '.join(trans)]]
    # NOTE(review): `data` shadows the torchtext `data` module imported at
    # the top of the file; harmless here only because the fields were
    # already constructed before the loop.
    data = pd.DataFrame(columns=field, data=dt)
    # Append to translation.csv; write the header only when creating it.
    if not os.path.exists('translation.csv'):
        data.to_csv('translation.csv', mode='a+', encoding='utf-8', header=True)
    else:
        data.to_csv('translation.csv', mode='a', encoding='utf-8', header=False)
    # break
    count+=1
    # Rebuild per-instruction text (8 token slots each, <nop>-terminated)
    # for the attention visualization below.
    sent = []
    src_tmp = src[0].reshape([int(src[0].shape[0]/8), 8])
    for ins in src_tmp:
        tmp_sent = []
        for opcode in ins:
            opc = SRC.vocab.itos[opcode]
            if opc == '<nop>':
                break
            tmp_sent.append(opc)
        sent.append(''.join(tmp_sent))
    from visualization import draw, visualization
    visualization(model, trans, ' '.join(sent))
    break