Skip to content

Commit

Permalink
save TfidfVectorizer model
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed May 22, 2017
1 parent 4fc9897 commit 4483031
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion hw5/tfidf_linearSVC.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,41 @@ def main():
### read training data & testing data
tags, texts, mlb = read_data(train_path)
test_texts = read_test(test_path)
all_corpus = texts + test_texts

### tokenize
all_corpus = texts + test_texts
vectorizer = CountVectorizer(stop_words='english', ngram_range=(1, 3), max_features=max_features_size)
transformer = TfidfTransformer()
transformer.fit(vectorizer.fit_transform(all_corpus))
sequences = transformer.transform(vectorizer.transform(texts))
test_data = transformer.transform(vectorizer.transform(test_texts))

vectorizer2 = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), max_features=max_features_size, sublinear_tf=True)
vectorizer2.fit(all_corpus)
sequences2 = vectorizer2.transform(texts)
test_data2 = vectorizer2.transform(test_texts)

if is_valid:
(x_train, y_train),(x_valid, y_valid) = validate(sequences, tags, valid_size)
(x_train2, y_train2),(x_valid2, y_valid2) = validate(sequences2, tags, valid_size)
else:
x_train, y_train = sequences, tags
x_train2, y_train2 = sequences2, tags

linear_svc = OneVsRestClassifier(LinearSVC(C=5e-4, class_weight='balanced'))
linear_svc2 = OneVsRestClassifier(LinearSVC(C=5e-4, class_weight='balanced'))

### cross validation
scores = cross_val_score(linear_svc, x_train, y_train, cv=8, scoring='f1_samples', n_jobs=-1)
print(scores, scores.mean(), scores.std())
scores2 = cross_val_score(linear_svc2, x_train2, y_train2, cv=8, scoring='f1_samples', n_jobs=-1)
print(scores2, scores2.mean(), scores2.std())

### predict
linear_svc.fit(x_train, y_train)
predict = linear_svc.predict(test_data)
linear_svc2.fit(x_train2, y_train2)
predict2 = linear_svc2.predict(test_data2)
# print(mlb.classes_)

### save vectorizer + transformer + linearSVC
Expand All @@ -93,6 +105,11 @@ def main():
with open(linear_svc_name, 'wb') as f:
pickle.dump(linear_svc, f)

with open(vectorizer_name2, 'wb') as f:
pickle.dump(vectorizer2, f)
with open(linear_svc_name2, 'wb') as f:
pickle.dump(linear_svc2, f)

# Test data
ensure_dir(output_path)
result = []
Expand Down Expand Up @@ -126,5 +143,7 @@ def main():
vectorizer_name = 'vec'
transformer_name = 'trans'
linear_svc_name = 'linSVC'
vectorizer_name2 = 'vec2'
linear_svc_name2 = 'linSVC2'

main()

0 comments on commit 4483031

Please sign in to comment.