-
Notifications
You must be signed in to change notification settings - Fork 1
/
toy_nmt.py
39 lines (28 loc) · 1.17 KB
/
toy_nmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import os
from logger import *
import argparse
# Script to generate fake (random) NMT scores.
def get_arguments(argv=None):
    """Parse the command-line options of the fake NMT score generator.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) means "use the real command line" (``sys.argv[1:]``),
            so the original zero-argument call behaves exactly as before.

    Returns:
        An ``argparse.Namespace`` with attributes:
            model_desc (str): printf-style path pattern, default ``""``.
            trg (str): output score file path, default ``""``.
            n_gen (int | None): generation index, ``None`` if omitted.
            n_model (int | None): model description index, ``None`` if omitted.
    """
    parser = argparse.ArgumentParser(description=None)
    # The main script formats this value with ``n_model`` ("%" operator),
    # so it is expected to contain a printf-style placeholder.
    parser.add_argument("--model-desc", default="",
                        help="path to model description files's folder")
    parser.add_argument("--trg", default="",
                        help="target file for output score file")
    parser.add_argument("--n-gen", type=int, help="current generation index")
    parser.add_argument("--n-model", type=int,
                        help="current model description file index")
    return parser.parse_args(argv)
# Number of fake score rows written to the target file.
n_data = 100
# One output line per row: a zero-padded row index followed by five numbers
# laid out like a BLEU summary line. The BP/ratio/length fields are fixed
# zeros because no real translation is scored.
template = "%s BLEU = %.5f, %.5f/%.5f/%.5f/%.5f (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"


def format_scores(scores):
    """Render one ``template`` line per row of *scores*.

    Args:
        scores: Iterable of rows, each holding exactly five numbers (one per
            ``%.5f`` field of ``template``); e.g. an ``(N, 5)`` NumPy array.

    Returns:
        A list of newline-terminated strings, each prefixed with the row's
        zero-padded index (``0000``, ``0001``, ...).
    """
    return [template % (str(idx).zfill(4), *row) for idx, row in enumerate(scores)]


if __name__ == "__main__":
    args = get_arguments()
    scores = np.random.rand(n_data, 5)
    # --model-desc is treated as a printf-style pattern indexed by --n-model.
    # Fall back to the raw value when it has no placeholder or --n-model was
    # omitted, instead of crashing on a purely informational log message.
    try:
        model_path = args.model_desc % args.n_model
    except (TypeError, ValueError):
        model_path = args.model_desc
    logging.info("loading file: %s" % model_path)
    # "w" is enough (nothing is read back); passing the whole list to
    # writelines avoids iterating a single string character by character.
    # Closing the file via ``with`` flushes it.
    with open(args.trg, "w", encoding="utf-8") as f:
        f.writelines(format_scores(scores))