Skip to content

Commit

Permalink
best version 1.2: four model
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed May 23, 2017
1 parent c97f342 commit d80f7c8
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions hw5/best.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,30 @@ def main():
predict3[predict3 >= threshold] = 1


### RNN 2
model = load_model(model_name2, custom_objects={'fmeasure': fmeasure})
predict4 = model.predict(sequences_test)
predict4[predict4 < threshold] = 0
predict4[predict4 >= threshold] = 1


### Voting
pred = predict + predict2 + predict3
pred[pred < 1.5] = 0
pred[pred >= 1.5] = 1
pred = predict + predict2 + predict3 + predict4
check = predict3 + predict4
pred[pred < 2] = 0
pred[pred > 2] = 1
x = np.argwhere(pred == 2)
for i, j in x:
if check[i, j] == 2:
pred[i, j] = 0
else:
pred[i, j] = 1


# Test data
ensure_dir(output_path)
result = []
mlb_backup = mlb.inverse_transform(predict3)
mlb_backup = mlb.inverse_transform(predict4)
for i, categories in enumerate(mlb.inverse_transform(pred)):
ret = []
if len(categories) == 0:
Expand Down Expand Up @@ -169,7 +183,8 @@ def main():
vectorizer_name2 = os.path.join(base_dir, 'vec2')
linear_svc_name2 = os.path.join(base_dir, 'linSVC2')
tokenizer_name = os.path.join(base_dir, 'word_index')
model_name = os.path.join(base_dir, 'model-5.h5')
model_name = os.path.join(base_dir, 'model-6.h5')
model_name2 = os.path.join(base_dir, 'model-5.h5')
threshold = 0.4

main()

0 comments on commit d80f7c8

Please sign in to comment.