Skip to content

Commit

Permalink
TfidfVectorizer -> CounterVectorizer + TfidfTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed May 22, 2017
1 parent 782488b commit e8a76f5
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions hw5/tfidf_linearSVC.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,31 @@ def main():
tags, texts, mlb = read_data(train_path)
test_texts = read_test(test_path)

### Tokenize
# vectorizer = TfidfVectorizer(stop_words='english')
vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), max_features=40000, sublinear_tf=True)
### tokenize
all_corpus = texts + test_texts
vectorizer.fit(all_corpus)
sequences = vectorizer.transform(texts)
test_data = vectorizer.transform(test_texts)
# vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), max_features=max_fearues_size, sublinear_tf=True)
# vectorizer.fit(all_corpus)
# sequences = vectorizer.transform(texts)
# test_data = vectorizer.transform(test_texts)
vectorizer = CountVectorizer(stop_words='english', ngram_range=(1, 3), max_features=max_features_size)
transformer = TfidfTransformer()
transformer.fit(vectorizer.fit_transform(all_corpus))
sequences = transformer.transform(vectorizer.transform(texts))
test_data = transformer.transform(vectorizer.transform(test_texts))

if is_valid:
(x_train, y_train),(x_valid, y_valid) = validate(sequences, tags, valid_size)
else:
x_train, y_train = sequences, test_data
x_train, y_train = sequences, tags

linear_svc = OneVsRestClassifier(LinearSVC(C=5e-4, class_weight='balanced'))

### cross validation
scores = cross_val_score(linear_svc, x_train, y_train, cv=8, scoring='f1_samples', n_jobs=-1)

print(scores, scores.mean(), scores.std())

### predict
linear_svc.fit(x_train, y_train)
# y_train_predict = linear_svc.predict(x_train)
# y_valid_predict = linear_svc.predict(x_valid)
# print(f1_score(y_valid, y_valid_predict, average='micro'))
predict = linear_svc.predict(test_data)
# print(mlb.classes_)

Expand Down Expand Up @@ -115,6 +117,6 @@ def main():
output_path = args.output
is_valid = args.valid
valid_size = -400
max_vocab = 60000
max_features_size = 40000

main()
main()

0 comments on commit e8a76f5

Please sign in to comment.