Skip to content

Commit

Permalink
add checkpoint & csv log
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed Apr 24, 2017
1 parent 920a077 commit 4d69edd
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions hw3/train_cnn.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
#!/usr/bin/env python
import sys, os
import numpy as np
import keras
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Dropout, LeakyReLU
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import BatchNormalization
from keras import losses
from keras import optimizers
from keras.callbacks import CSVLogger, ModelCheckpoint

# Parameter
height = width = 48
num_classes = 7
input_shape = (height, width, 1)
batch_size = 128
epochs = 100
zoom_range = 0.05
model_name = 'pre5.h5'
epochs = 300
zoom_range = 0.2
model_name = 'pre6.h5'
isValid = 1

# Read the train data
Expand All @@ -30,33 +31,34 @@
Y = data[:, 0]
X /= 255
Y = Y.reshape(Y.shape[0], 1)
Y = keras.utils.to_categorical(Y, num_classes)
Y = to_categorical(Y, num_classes)

# Change data into CNN format
X = X.reshape(X.shape[0], height, width, 1)

# Split the data
if isValid:
valid_num = 3000
permu = np.random.permutation(X.shape[0])
X_train, Y_train = X[permu[:-valid_num]], Y[permu[:-valid_num]]
X_valid, Y_valid = X[permu[-valid_num:]], Y[permu[-valid_num:]]
X_train, Y_train = X[:-valid_num], Y[:-valid_num]
X_valid, Y_valid = X[-valid_num:], Y[-valid_num:]

else:
X_train, Y_train = X, Y

# Construct the model
model = Sequential()

model.add(Conv2D(64, (3, 3), padding='same', input_shape=input_shape))
model.add(LeakyReLU(alpha=0.03))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
model.add(Dropout(0.25))
model.add(Dropout(0.2))

model.add(Conv2D(128, (3, 3), padding='same'))
model.add(LeakyReLU(alpha=0.03))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
model.add(Dropout(0.3))
model.add(Dropout(0.25))

model.add(Conv2D(256, (3, 3), padding='same'))
model.add(LeakyReLU(alpha=0.03))
Expand All @@ -72,7 +74,7 @@

model.add(Flatten())

model.add(Dense(128, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(512, activation='relu'))
Expand All @@ -89,24 +91,33 @@
metrics=['accuracy'])

# Image PreProcessing
train_gen = ImageDataGenerator(rotation_range=10,
width_shift_range=0.05,
height_shift_range=0.05,
shear_range=0.05,
train_gen = ImageDataGenerator(rotation_range=25,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.1,
zoom_range=[1-zoom_range, 1+zoom_range],
horizontal_flip=True)
train_gen.fit(X_train)

# Callbacks
callbacks = []
modelcheckpoint = ModelCheckpoint('ckpt/weights.{epoch:03d}-{val_acc:.5f}.h5', monitor='val_acc', save_best_only=True)
callbacks.append(modelcheckpoint)
csv_logger = CSVLogger('cnn_log.csv', separator=',', append=False)
callbacks.append(csv_logger)

# Fit the model
if isValid:
model.fit_generator(train_gen.flow(X_train, Y_train, batch_size=batch_size),
steps_per_epoch=10*X_train.shape[0]//batch_size,
epochs=epochs,
callbacks=callbacks,
validation_data=(X_valid, Y_valid))
else:
model.fit_generator(train_gen.flow(X_train, Y_train, batch_size=batch_size),
steps_per_epoch=10*X_train.shape[0]//batch_size,
epochs=epochs)
epochs=epochs,
callbacks=callbacks)

# Save model
model.save(model_name)

0 comments on commit 4d69edd

Please sign in to comment.