-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist.py
143 lines (136 loc) · 4.11 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# -*- coding: utf-8 -*-#
"""
File : mnist.py
Description : Load the MNIST image/label files and train and evaluate a 784-30-10 BP network on them.
Author : 赵金朋
Modify Time : 2019/5/27 15:07
"""
#-*- coding:utf-8 -*-
from bp import *
from datetime import datetime
from os import *
# Base class for data loaders: holds the data file path and sample count.
class Loader:
    def __init__(self, path, count):
        '''
        Initialize the loader.

        path: path of the data file (plain or gzip-compressed)
        count: number of samples in the file; subclasses read this many
        '''
        self.path = path
        self.count = count

    def get_file_content(self):
        '''
        Return the raw bytes of the data file.

        If the file is gzip-compressed (detected by the 0x1f 0x8b magic
        bytes, as with the ``*.gz`` paths this module loads), the
        decompressed bytes are returned instead. Callers can therefore
        always index straight into the uncompressed data.
        '''
        import gzip
        import io

        # io.open is the builtin open(). The bare name `open` is shadowed by
        # `from os import *` at module level, and os.open() expects integer
        # flags, so it would raise TypeError on the 'rb' mode string.
        with io.open(self.path, 'rb') as f:
            content = f.read()
        if content[:2] == b'\x1f\x8b':
            content = gzip.decompress(content)
        return content
# Image data loader: turns the raw image file into flat pixel vectors.
class ImageLoader(Loader):
    def get_picture(self, content, index):
        '''
        Internal helper: extract image number `index` from the raw file
        bytes as a 28x28 nested list (28 rows of 28 pixel ints).

        The first 16 bytes of the file are skipped (presumably the IDX
        header: magic number, image count, row count, column count).
        '''
        offset = 16 + index * 28 * 28
        return [
            [int(content[offset + row * 28 + col]) for col in range(28)]
            for row in range(28)
        ]

    def get_one_sample(self, picture):
        '''
        Internal helper: flatten a 28x28 picture row by row into a single
        784-element input vector.
        '''
        return [picture[row][col] for row in range(28) for col in range(28)]

    def load(self):
        '''
        Read the image file and return one flat input vector per sample,
        for the first `self.count` samples.
        '''
        raw = self.get_file_content()
        return [
            self.get_one_sample(self.get_picture(raw, k))
            for k in range(self.count)
        ]
# Label data loader: turns the raw label file into 10-element target vectors.
class LabelLoader(Loader):
    def load(self):
        '''
        Read the label file and return one 10-element target vector per
        sample, for the first `self.count` samples.

        The first 8 bytes are skipped (presumably the IDX header: magic
        number and item count); each following byte is one label.
        '''
        raw = self.get_file_content()
        return [self.norm(raw[8 + k]) for k in range(self.count)]

    def norm(self, label):
        '''
        Internal helper: convert a label value into a 10-element target
        vector that holds 0.9 at position int(label) and 0.1 everywhere else.
        '''
        target = int(label)
        return [0.9 if digit == target else 0.1 for digit in range(10)]
def get_training_data_set():
    '''
    Load the training split.

    Returns a tuple (images, labels): 60000 flattened image vectors and the
    matching 10-element label vectors. Both are read from the relative paths
    below, i.e. from the current working directory.
    '''
    images = ImageLoader('train-images.idx3-ubyte.gz', 60000).load()
    labels = LabelLoader('train-labels.idx1-ubyte.gz', 60000).load()
    return images, labels
def get_test_data_set():
    '''
    Load the test split.

    Returns a tuple (images, labels): 10000 flattened image vectors and the
    matching 10-element label vectors. Both are read from the relative paths
    below, i.e. from the current working directory.
    '''
    images = ImageLoader('t10k-images.idx3-ubyte.gz', 10000).load()
    labels = LabelLoader('t10k-labels.idx1-ubyte.gz', 10000).load()
    return images, labels
def get_result(vec):
    '''
    Return the index of the largest value in vec (argmax).

    Ties resolve to the first maximal index, and an empty vec yields 0.
    The comparison is seeded from vec itself rather than from 0, so the
    result is also correct when every entry is <= 0.
    '''
    best_index = 0
    for index, value in enumerate(vec):
        if value > vec[best_index]:
            best_index = index
    return best_index
def evaluate(network, test_data_set, test_labels):
    '''
    Return the fraction of test samples that the network misclassifies.

    A sample counts as a mistake when the argmax of its label vector differs
    from the argmax of network.predict(sample). Raises ZeroDivisionError
    when test_data_set is empty.
    '''
    total = len(test_data_set)
    mistakes = sum(
        1
        for k in range(total)
        if get_result(test_labels[k])
        != get_result(network.predict(test_data_set[k]))
    )
    return float(mistakes) / float(total)
def train_and_evaluate():
    '''
    Train a 784-30-10 network on the MNIST training split with early stopping.

    Each loop iteration calls network.train once with learning rate 0.3 and
    prints a timestamped progress line. Every 10 epochs the error ratio on
    the test split is measured and printed. Training stops as soon as that
    ratio is higher than the one from the previous check.
    '''
    last_error_ratio = 1.0
    epoch = 0
    train_data_set, train_labels = get_training_data_set()
    test_data_set, test_labels = get_test_data_set()
    network = Network([784, 30, 10])
    while True:
        epoch += 1
        # The meaning of train()'s arguments comes from bp.Network (not shown
        # here); presumably (labels, samples, learning rate, epochs) -- TODO confirm.
        network.train(train_labels, train_data_set, 0.3, 1)
        # `now` is not defined in this file; use the imported datetime class.
        print('%s epoch %d finished' % (datetime.now(), epoch))
        if epoch % 10 == 0:
            error_ratio = evaluate(network, test_data_set, test_labels)
            print('%s after epoch %d, error ratio is %f'
                  % (datetime.now(), epoch, error_ratio))
            # Early stopping: quit once the test error starts rising.
            if error_ratio > last_error_ratio:
                break
            last_error_ratio = error_ratio
if __name__ == '__main__':
    # Script entry point: load the data, train with early stopping,
    # and print the test error periodically.
    train_and_evaluate()