-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.f90
180 lines (148 loc) · 6.35 KB
/
main.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
!!!#############################################################################
!!! Code written by Ned Thaddeus Taylor
!!! Code part of the ARTEMIS group (Hepplestone research group)
!!! Think Hepplestone, think HRG
!!!#############################################################################
program mnist_test
  !! Train a small CNN (conv2d -> dropblock2d -> maxpool2d -> two dense
  !! layers) on the MNIST training set, write the network to file, then
  !! evaluate it on the MNIST test set.
  !! Run-time settings (data_dir, batch_size, num_epochs, restart, ...) come
  !! from the `inputs` module, populated by set_global_vars from the job file.
#ifdef _OPENMP
  use omp_lib
#endif
  use athena
  use constants_mnist, only: real12
  use read_mnist, only: read_mnist_db
  use inputs
  implicit none

  type(network_type) :: network

  !! data loading and preprocessing
  real(real12), allocatable, dimension(:,:,:,:) :: input_images, test_images
  integer, allocatable, dimension(:) :: labels, test_labels
  integer, allocatable, dimension(:,:) :: input_labels
  character(1024) :: train_file, test_file

  !! neural network size and shape variables
  integer, parameter :: num_classes = 10 ! Number of output classes
  integer :: image_size
  integer :: input_channels

  !! training loop variables
  integer :: num_samples, num_samples_test
  integer :: i, image_size_test


!!!-----------------------------------------------------------------------------
!!! initialise global variables
!!!-----------------------------------------------------------------------------
  call set_global_vars(param_file="example/mnist_drop/test_job.in")
#ifdef _OPENMP
  write(*,*) "number of threads:", num_threads
  call omp_set_num_threads(num_threads)
#endif


!!!-----------------------------------------------------------------------------
!!! read training dataset
!!!-----------------------------------------------------------------------------
  train_file = trim(data_dir)//'/MNIST_train.txt'
  call read_mnist_db(train_file,input_images, labels, &
       maxval(cv_kernel_size), image_size, padding_method)
  input_channels = size(input_images, 3)
  num_samples = size(input_images, 4)


!!!-----------------------------------------------------------------------------
!!! read testing dataset
!!!-----------------------------------------------------------------------------
  test_file = trim(data_dir)//'/MNIST_test.txt'
  call read_mnist_db(test_file,test_images, test_labels, &
       maxval(cv_kernel_size), image_size_test, padding_method)
  num_samples_test = size(test_images, 4)
  !! the network input shape is taken from the training images, so the test
  !! images must share it; fail clearly here rather than inside network%test
  if(image_size_test.ne.image_size .or. &
       size(test_images, 3).ne.input_channels)then
     error stop "ERROR: test images differ in shape from training images"
  end if


!!!-----------------------------------------------------------------------------
!!! initialise random seed
!!!-----------------------------------------------------------------------------
  call random_setup(seed, restart=.false.)


!!!-----------------------------------------------------------------------------
!!! shuffle dataset
!!!-----------------------------------------------------------------------------
  if(shuffle_dataset)then
     write(6,*) "Shuffling training dataset..."
     call shuffle(input_images, labels, 4, seed)
     write(6,*) "Training dataset shuffled"
     if(verbosity.eq.-1)then
        write(6,*) "Check fort.11 and fort.12 to ensure data shuffling &
             &executed properly"
        !! dump (up to) the first two batches; bounded by num_samples so a
        !! dataset smaller than two batches is not indexed out of range
        do i = 1, min(batch_size*2, num_samples)
           write(11,*) input_images(:,:,:,i)
        end do
        write(12,*) labels
     end if
  end if


!!!-----------------------------------------------------------------------------
!!! initialise convolutional and pooling layers
!!!-----------------------------------------------------------------------------
  if(restart)then
     write(*,*) "Reading network from file..."
     call network%read(file=input_file)
     write(*,*) "Reading finished"
  else
     write(6,*) "Initialising CNN..."
     call network%add(conv2d_layer_type( &
          input_shape = [image_size,image_size,input_channels], &
          num_filters = cv_num_filters, kernel_size = 3, stride = 1, &
          padding=padding_method, &
          calc_input_gradients = .false., &
          activation_function = "relu" &
          ))
     call network%add(dropblock2d_layer_type( &
          rate = 0.25, block_size = 5))
     call network%add(maxpool2d_layer_type(&
          pool_size = 2, stride = 2))
     call network%add(full_layer_type( &
          num_outputs = 100, &
          activation_function = "relu", &
          kernel_initialiser = "he_uniform", &
          bias_initialiser = "he_uniform" &
          ))
     !! output layer width must match the one-hot label encoding
     call network%add(full_layer_type( &
          num_outputs = num_classes,&
          activation_function = "softmax", &
          kernel_initialiser = "glorot_uniform", &
          bias_initialiser = "glorot_uniform" &
          ))
  end if
  call network%compile(optimiser=optimiser, &
       loss_method=loss_method, metrics=metric_dict, &
       batch_size = batch_size, verbose = verbosity)
  write(*,*) "NUMBER OF LAYERS",network%num_layers


!!!-----------------------------------------------------------------------------
!!! training loop
!!! ... loops over num_epoch number of epochs
!!! ... i.e. it trains on the same datapoints num_epoch times
!!!-----------------------------------------------------------------------------
  call build_one_hot_labels(labels, num_samples, input_labels)
  write(6,*) "Starting training..."
  call network%train(input_images, input_labels, num_epochs, batch_size, &
       plateau_threshold = plateau_threshold, &
       shuffle_batches = shuffle_dataset, &
       batch_print_step = batch_print_step, verbose = verbosity)


!!!-----------------------------------------------------------------------------
!!! print weights and biases of CNN to file
!!!-----------------------------------------------------------------------------
  write(*,*) "Writing network to file..."
  call network%print(file=output_file)
  write(*,*) "Writing finished"


!!!-----------------------------------------------------------------------------
!!! testing loop
!!!-----------------------------------------------------------------------------
  !! intent(out) on the allocatable dummy releases the training encoding
  call build_one_hot_labels(test_labels, num_samples_test, input_labels)
  write(*,*) "Starting testing..."
  call network%test(test_images,input_labels)
  write(*,*) "Testing finished"
  write(6,'("Overall accuracy=",F0.5)') network%accuracy
  write(6,'("Overall loss=",F0.5)') network%loss

contains

!!!-----------------------------------------------------------------------------
!!! one-hot encode integer class labels
!!!-----------------------------------------------------------------------------
  subroutine build_one_hot_labels(class_labels, n_samples, encoded)
    !! class_labels: per-sample class labels; the first n_samples are used
    !!               directly as row indices, so they must lie in
    !!               1..num_classes
    !! n_samples:    number of samples (columns) to encode
    !! encoded:      (num_classes, n_samples) array with a single 1 per
    !!               column; any previous allocation is released on entry
    integer, dimension(:), intent(in) :: class_labels
    integer, intent(in) :: n_samples
    integer, allocatable, dimension(:,:), intent(out) :: encoded
    integer :: j

    !! an out-of-range label (e.g. a 0-based label) would write outside
    !! `encoded`; reject it explicitly instead of corrupting memory
    if(any(class_labels(1:n_samples).lt.1 .or. &
         class_labels(1:n_samples).gt.num_classes))then
       error stop "ERROR: class label outside the range 1..num_classes"
    end if

    allocate(encoded(num_classes, n_samples))
    encoded = 0
    do j = 1, n_samples
       encoded(class_labels(j), j) = 1
    end do
  end subroutine build_one_hot_labels

end program mnist_test
!!!#############################################################################