Skip to content

Commit

Permalink
complete train
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZhuXing committed Mar 4, 2021
1 parent 1b21653 commit 7881fb1
Showing 1 changed file with 48 additions and 8 deletions.
56 changes: 48 additions & 8 deletions source/04_generalized_model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ def build_model(img_width, img_height):

# data generator definition
class DataGenerator(keras.utils.Sequence):
def __init__(self, img_data, label_data, img_real, label_real, batch_size = 32, shuffle = True):
def __init__(self, img_data, label_data, img_width, img_height, batch_size = 32, shuffle = True):
'Initialization'
self.img_data = img_data
self.label_data = label_data
self.img_real = img_real
self.label_real = label_real
self.img_width = img_width
self.img_height = img_height
self.batch_size = batch_size
self.shuffle = shuffle
self.on_epoch_end()
Expand All @@ -194,22 +194,27 @@ def __getitem__(self, index):
index = int(np.floor(index / 2))
img1_batch = self.img_data[index * self.batch_size : (index + 1) * self.batch_size]
label1_batch = self.label_data[index * self.batch_size : (index + 1) * self.batch_size]
img2_batch = np.empty((self.batch_size, 128, 128, 1), dtype = np.float32)
img2_batch = np.empty((self.batch_size, self.img_width, self.img_height, 1), dtype = np.float32)
label2_batch = np.zeros((self.batch_size, 1), dtype = np.float32)

for i, idx in enumerate(label1_batch):
if random.random() > 0.5:
# put matched image
img2_batch[i] = self.img_real[idx]
while True:
match_idx = random.choice(range(len(self.label_data)))
if (self.label_data[match_idx] == idx):
break

img2_batch[i] = self.img_data[match_idx]
label2_batch[i] = 1.
else:
# put unmatched image
while True:
unmatch_idx = random.choice(list(self.label_real))
if (unmatch_idx != idx):
unmatch_idx = random.choice(range(len(self.label_data)))
if (self.label_data[unmatch_idx] != idx):
break

img2_batch[i] = self.img_real[unmatch_idx]
img2_batch[i] = self.img_data[unmatch_idx]
label2_batch[i] = 0.

index = real_idx
Expand All @@ -228,6 +233,9 @@ def main(args):
image_width = args.image_width
image_height = args.image_height
dataset_path = args.dataset_path
check_point = args.check_point
save_model = args.save_model
save_model_path = args.save_model_path

# check parameter
if ((image_width > IMAGE_WIDTH_LIMIT) or (image_height > IMAGE_HEIGHT_LIMIT)):
Expand All @@ -244,6 +252,7 @@ def main(args):
print('\tImage Width: ', image_width)
print('\tImage Height: ', image_height)
print('\tDataset Path: ', dataset_path)
print('\tCheckpoint Path: ', check_point)

# initialize data preparer
train_data_pre = Prepare_Data(image_width, image_height, dataset_path)
Expand All @@ -254,6 +263,31 @@ def main(args):
print(img_train.shape, label_train.shape)
print(img_val.shape, label_val.shape)

# prepare data generator
train_gen = DataGenerator(img_train, label_train, image_width, image_height, shuffle = True)
val_gen = DataGenerator(img_val, label_val, image_width, image_height, shuffle = True)

# configure checkpoint data
checkpoint_path = check_point + 'cp_' + str(image_width) + '_' + str(image_height) + '.ckpt'

# create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_path, save_weights_only = True, verbose = 1)

# prepare model
model = build_model(image_width, image_height)
if (os.path.exists(checkpoint_path + '.index')):
print('continue training')
model.load_weights(checkpoint_path)

# check only save
if (save_model == 1):
model_path = save_model_path + 'fp' + str(image_width) + '_' + str(image_height) + '.h5'
model.save(model_path)
return

# training model
model.fit(train_gen, epochs = 100, validation_data = val_gen, callbacks = [cp_callback])


# argument parser
def parse_arguments(argv):
Expand All @@ -266,6 +300,12 @@ def parse_arguments(argv):
help = 'Process image height', default = 128)
parser.add_argument('--dataset_path', type = str,
help = 'Path to fingerprint image dataset', default = '../../dataset/original/')
parser.add_argument('--check_point', type = str,
help = 'Path to model checkpoint', default = '../../model/checkpoint/')
parser.add_argument('--save_model', type = int,
help = 'Only save model from checkpoint', default = 0)
parser.add_argument('--save_model_path', type = str,
help = 'Path to model', default = '../../model/result/')

return parser.parse_args(argv)

Expand Down

0 comments on commit 7881fb1

Please sign in to comment.