-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
382574f
commit 144f50f
Showing
9 changed files
with
209 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,63 @@ | ||
#!/usr/bin/python3 | ||
# this script implements evaluation metrics | ||
|
||
import sys | ||
import tensorflow as tf | ||
from keras import backend as K | ||
|
||
|
||
def strict_accuracy(act, pred): | ||
def strict_accuracy_K(act, pred): | ||
""" | ||
Computes accuracy for each batch without factoring in padding symbols | ||
Keras metric that computes the accuracy of tagged sentences for each batch | ||
Predictions with a categorical vector of [1 0 0 ... 0] are not factored in | ||
Inputs: | ||
- act: array of actual categorical vectors | ||
- pred: array of predicted categorical vectors | ||
Outputs: | ||
- accuracy score | ||
""" | ||
# values of actual classes | ||
# numerical values of the actual classes | ||
act_argm = K.argmax(act, axis=-1) | ||
# values of predicted classes | ||
# numerical values of the predicted classes | ||
pred_argm = K.argmax(pred, axis=-1) | ||
# determines where the tags are incorrect (1) or not (0) | ||
# determines where the classes are incorrect or not | ||
incorrect = K.cast(K.not_equal(act_argm, pred_argm), dtype='float32') | ||
# determines where the tags are correct (1) or not (0) | ||
# determines where the classes are correct or not | ||
correct = K.cast(K.equal(act_argm, pred_argm), dtype='float32') | ||
# determines where the tag is a padding tag (1) or not (0) | ||
# determines where the classes are ignored or not | ||
padding = K.cast(K.equal(act_argm, 0), dtype='float32') | ||
# subtract padding from correct predictions and check equality to 1 | ||
corr_preds = K.sum(K.cast(K.equal(correct - padding, 1), dtype='float32')) | ||
incorr_preds = K.sum(K.cast(K.equal(incorrect - padding, 1), dtype='float32')) | ||
total = corr_preds + incorr_preds | ||
total_preds = corr_preds + incorr_preds | ||
# actual accuracy without padding | ||
accuracy = corr_preds / total | ||
accuracy = corr_preds / total_preds | ||
return accuracy | ||
|
||
|
||
def strict_accuracy_N(act, pred, ignore_class=0): | ||
""" | ||
Computes the accuracy of an array of tagged sentences | ||
Actual values which match `ignore_class` are not factored in | ||
Inputs: | ||
- act: array of actual numerical vectors | ||
- pred: array of predicted numerical vectors | ||
- ignore_class: numerical value to be ignored | ||
Outputs: | ||
- accuracy score | ||
""" | ||
# number of correct predictions | ||
corr_preds = 0 | ||
# number of predictions | ||
total_preds = 0 | ||
# compute values via iterating over sentences | ||
for sent in zip(act, pred): | ||
act_classes = sent[0] | ||
pred_classes = sent[1] | ||
for t in range(len(act_classes)): | ||
if act_classes[t] != ignore_class: | ||
total_preds += 1 | ||
if pred_classes[t] == act_classes[t]: | ||
corr_preds += 1 | ||
# actual accuracy without padding | ||
return corr_preds / total_preds | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
|
||
import sys | ||
import numpy as np | ||
from copy import deepcopy | ||
from collections import OrderedDict | ||
from models.nn import get_model | ||
from models.metrics import strict_accuracy_N | ||
|
||
|
||
def grid_search_params(base_args, cv_samples, X, y, padding_y, num_tags, max_slen, num_words, wemb_dim, wemb_matrix, max_wlen, num_chars, cemb_dim, cemb_matrix): | ||
args = deepcopy(base_args) | ||
# define model parameters possible values | ||
grid_params = OrderedDict() | ||
#grid_params['epochs'] = [20, 30, 40] | ||
#grid_params['batch_size'] = [512, 1024, 2048] | ||
#grid_params['optimizer'] = ['rmsprop', 'adam', 'nadam'] | ||
#grid_params['dropout'] = [0.1, 0.2, 0.3] | ||
#grid_params['model_size'] = [200, 300, 400] | ||
#grid_params['num_layers'] = [1, 2, 3] | ||
|
||
grid_params['epochs'] = [2, 3] | ||
grid_params['batch_size'] = [512] | ||
grid_params['optimizer'] = ['adam'] | ||
grid_params['dropout'] = [0.1] | ||
grid_params['model_size'] = [200] | ||
grid_params['num_layers'] = [1] | ||
|
||
# define parameter combinations | ||
grid_space = np.array(np.meshgrid(grid_params['epochs'], | ||
grid_params['batch_size'], | ||
grid_params['optimizer'], | ||
grid_params['dropout'], | ||
grid_params['model_size'], | ||
grid_params['num_layers'])).T.reshape(-1,len(grid_params)) | ||
|
||
print([grid_params[x] for x in grid_params.keys()]) | ||
print(grid_space) | ||
|
||
# perform 3-fold cross validation for each parameter combinations | ||
print('[INFO] Grid-search will optimize the following hyper-parameters:', list(grid_params.keys())) | ||
cv_samples = 3 | ||
|
||
X_cv_train = X | ||
if args.use_words and args.use_chars: | ||
block_size = len(X[0]) // cv_samples | ||
X_cv_dev = [[],[]] | ||
else: | ||
block_size = len(X) // cv_samples | ||
X_cv_dev = [] | ||
y_cv_dev = [] | ||
y_cv_train = y | ||
|
||
best_acc = 0 | ||
best_params = None | ||
|
||
# test each parameter combination | ||
for i in range(grid_space.shape[0]): | ||
|
||
current_acc = 0 | ||
cell = grid_space[i] | ||
args.optimizer = cell[2] | ||
args.dropout = float(cell[3]) | ||
args.model_size = int(cell[4]) | ||
args.num_layers = int(cell[5]) | ||
|
||
model = get_model(args, num_tags, | ||
max_slen, num_words, wemb_dim, wemb_matrix, | ||
max_wlen, num_chars, cemb_dim, cemb_matrix) | ||
|
||
print('[INFO] ' + str(cv_samples) + '-fold cross-validation using the hyper-parameter set', cell) | ||
for _ in range(cv_samples): | ||
# rotate the block used for testing | ||
if args.use_words and args.use_chars: | ||
if len(X_cv_dev[0]): | ||
X_cv_train = [np.append(X_cv_train[0], X_cv_dev[0], axis=0), np.append(X_cv_train[1], X_cv_dev[1], axis=0)] | ||
X_cv_dev = [X_cv_train[0][:block_size], X_cv_train[1][:block_size]] | ||
X_cv_train = [X_cv_train[0][block_size:], X_cv_train[1][block_size:]] | ||
|
||
else: | ||
if len(X_cv_dev): | ||
X_cv_train = np.append(X_cv_train, X_cv_dev, axis=0) | ||
X_cv_dev = X_cv_train[:block_size] | ||
X_cv_train = X_cv_train[block_size:] | ||
|
||
if len(y_cv_dev): | ||
y_cv_train = np.append(y_cv_train, y_cv_dev, axis=0) | ||
y_cv_dev = y_cv_train[:block_size] | ||
y_cv_train = y_cv_train[block_size:] | ||
|
||
# fit the model | ||
history = model.fit(X_cv_train, np.array(y_cv_train), batch_size = int(cell[1]), epochs=int(cell[0]), validation_split=0.0, verbose=0) | ||
|
||
# obtain accuracy on validation | ||
|
||
# predictions on the test set | ||
p_cv_dev = model.predict(X_cv_dev, verbose=0) | ||
p_cv_dev = np.argmax(p_cv_dev, axis=-1) + 1 | ||
true_cv_dev = np.argmax(y_cv_dev, axis=-1) + 1 | ||
current_acc += strict_accuracy_N(true_cv_dev, p_cv_dev, 1) | ||
|
||
# average | ||
current_acc = current_acc / cv_samples | ||
print('[INFO] The accuracy with the given hyper-parameters is', current_acc) | ||
if current_acc > best_acc: | ||
best_acc = current_acc | ||
best_params = cell | ||
|
||
print('[INFO] The best set of hyper-parameters found is', best_params) | ||
|
||
# plug in args | ||
args.epochs = int(best_params[0]) | ||
args.batch_size = int(best_params[1]) | ||
args.optimizer = best_params[2] | ||
args.dropout = float(best_params[3]) | ||
args.model_size = int(best_params[4]) | ||
args.num_layers = int(best_params[5]) | ||
return args | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.