-
Notifications
You must be signed in to change notification settings - Fork 1
/
evo_single.py
83 lines (67 loc) · 2.55 KB
/
evo_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import cma
import numpy as np
import pickle
import os
import argparse
from logger import *
# Per-parameter transforms that conf2gene applies, position by position, to
# the values read from the init file before they seed CMA-ES.
# Add the corresponding map function here when a parameter is added.
map_func = [np.log] * 2 + [np.ceil]  # params 0-1: np.log, param 2: np.ceil
def conf2gene(confs, map_funcs):
    """Convert raw configuration values into CMA-ES genes.

    Each value in ``confs`` is passed through the callable at the same
    position in ``map_funcs`` (for example ``np.log``), so both sequences
    must have the same length.

    Args:
        confs: Sequence of numeric configuration values.
        map_funcs: Sequence of callables, one per value in ``confs``.

    Returns:
        list: ``map_funcs[i](confs[i])`` for every position ``i``.

    Raises:
        ValueError: If ``confs`` and ``map_funcs`` differ in length.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if len(confs) != len(map_funcs):
        raise ValueError(
            "got %d configuration values but %d map functions"
            % (len(confs), len(map_funcs))
        )
    return [func(conf) for conf, func in zip(confs, map_funcs)]
def get_arguments(argv=None):
    """Parse command-line options for one generation of the CMA-ES driver.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``checkpoint``, ``gene``, ``init``, ``scr``,
        ``pop``, ``n_pop`` and ``n_gen``.

    Exits (``SystemExit`` via ``parser.error``) when ``--n-gen`` is missing or
    negative, ``--pop`` is below 1, or ``--scr`` is missing for a generation
    after the first.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument("--checkpoint", default="checkpoints/es_G%s.pkl",
                        help="path to checkpoint files")
    parser.add_argument("--gene", default="generation_%s/genes/%s.gene",
                        help="path template for per-individual gene files "
                             "(generation and individual index are substituted)")
    parser.add_argument("--init", default="gene.init", help="init gene file")
    parser.add_argument("--scr", default="", help="path to score file")
    parser.add_argument("--pop", default=30, type=int, help="population")
    # NOTE(review): --n-pop is parsed but not read by evolution(); kept for CLI compatibility.
    parser.add_argument("--n-pop", type=int, help="current population index")
    # Required: evolution() formats it with %d and branches on it, so a
    # missing value previously crashed later with an opaque TypeError.
    parser.add_argument("--n-gen", type=int, required=True,
                        help="current generation index")
    args = parser.parse_args(argv)
    if args.n_gen < 0:
        parser.error("--n-gen must be >= 0")
    if args.pop < 1:
        parser.error("--pop must be >= 1")
    # Generations after the first read a score file; the "" default cannot be opened.
    if args.n_gen > 0 and not args.scr:
        parser.error("--scr is required when --n-gen > 0")
    return args
def _fill_path(template, *indices):
    """Substitute zero-padded (2-digit) indices into a ``%s`` path template.

    Only the last ``template.count("%s")`` indices are used. The default
    two-slot gene template ("generation_%s/genes/%s.gene") therefore receives
    (generation, index), while a one-slot template receives just the last
    index, as before.

    Raises:
        ValueError: If the template has no ``%s`` slot or more slots than indices.
    """
    padded = tuple(str(i).zfill(2) for i in indices)
    n_slots = template.count("%s")
    if not 1 <= n_slots <= len(padded):
        raise ValueError(
            "path template %r needs between 1 and %d '%%s' placeholders"
            % (template, len(padded))
        )
    return template % padded[len(padded) - n_slots:]


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _read_values(path):
    """Return one float per non-blank line of the text file at ``path``."""
    with open(path) as f:
        return [float(line.strip()) for line in f if line.strip()]


def _read_fitness(path):
    """Return CMA-ES fitness values from a tab-separated score file.

    The last tab-separated field of each non-blank line is negated because
    CMA-ES minimizes (presumably higher scores are better -- confirm).
    """
    with open(path) as f:
        return [-float(line.split("\t")[-1].strip()) for line in f if line.strip()]


def evolution(args):
    """Run one generation of a file-driven CMA-ES loop.

    Generation 0 builds a new strategy from the values in ``args.init``,
    mapped through ``conf2gene`` with ``map_func``. Every later generation:

    * loads the checkpoint of generation ``n_gen - 1``;
    * reads back the genes that generation wrote;
    * pairs them with the scores in ``args.scr`` (score line i <-> gene
      file i; assumed ordering, confirm against the evaluator);
    * calls ``es.tell``.

    In both cases a new population is then sampled with ``es.ask()``. The
    strategy is pickled to this generation's checkpoint path, and one gene
    file (one value per line) is written per individual.

    Args:
        args: Namespace with ``n_gen``, ``pop``, ``init``, ``scr``,
            ``checkpoint`` and ``gene`` (see ``get_arguments``).

    Raises:
        ValueError: If the score file does not hold exactly ``args.pop``
            scores, or a path template has an unusable number of ``%s`` slots.
    """
    logging.info("======================================================")
    logging.info("(Generation %d) Start generating genes...", args.n_gen)
    logging.info(args)
    logging.info("======================================================")
    if args.n_gen == 0:
        init_vec = conf2gene(_read_values(args.init), map_func)
        # NOTE(review): 'seed' is applied only when the strategy is created here;
        # confirm later generations are reproducible after unpickling.
        es = cma.CMAEvolutionStrategy(init_vec, 0.1, {
            'seed': 1,
            'popsize': args.pop,
        })
    else:
        # NOTE: pickle.load can execute arbitrary code; only load checkpoints this script wrote.
        with open(_fill_path(args.checkpoint, args.n_gen - 1), "rb") as es_file:
            es = pickle.load(es_file)
        # Tell the strategy about the population that was actually evaluated,
        # i.e. the genes written by the previous generation. Calling es.ask()
        # here instead would sample *different* points and pair them with
        # scores that belong to other candidates.
        evaluated = [
            np.array(_read_values(_fill_path(args.gene, args.n_gen - 1, idx)))
            for idx in range(args.pop)
        ]
        fitness = _read_fitness(args.scr)
        if len(fitness) != len(evaluated):
            raise ValueError(
                "score file %r has %d scores but the population size is %d"
                % (args.scr, len(fitness), len(evaluated))
            )
        es.tell(evaluated, fitness)
    X = es.ask()
    # Checkpoint is taken after ask(), so the next generation can tell() these genes.
    checkpoint_path = _fill_path(args.checkpoint, args.n_gen)
    _ensure_parent_dir(checkpoint_path)
    with open(checkpoint_path, "wb") as es_file:
        pickle.dump(es, es_file)
    for gene_idx, gene_vec in enumerate(X):
        gene_path = _fill_path(args.gene, args.n_gen, gene_idx)
        _ensure_parent_dir(gene_path)
        with open(gene_path, "w") as gene:
            # ``np.str`` was an alias of builtin ``str`` and was removed in NumPy 1.24.
            gene.write("\n".join(np.asarray(gene_vec).astype(str)))
if __name__ == "__main__":
    # Script entry point: parse the CLI, run one generation, then report
    # how many genes were requested.
    args = get_arguments()
    evolution(args)
    logging.info("Generated %s genes.", args.pop)