Skip to content

Commit

Permalink
增加单词侵袭模块
Browse files Browse the repository at this point in the history
  • Loading branch information
moon-hotel committed Jun 29, 2021
1 parent 33ab6d9 commit c7dbdd2
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from torchtext.vocab import Vocab
from torch.utils.data import DataLoader
import torch
import re


def my_tokenizer(s):
s = s.replace(',', " ,").replace(".", " .").replace("?", " ?")
s = s.replace("\\", " ").replace("(", "( ").replace(")", " )")
s = s.replace(',', " ,").replace(".", " .").replace("?", " ?").replace("!", " !")
return s.split()


Expand Down Expand Up @@ -88,15 +88,34 @@ def data_process(self, filepath):
:param filepath: 数据集路径
:return:
"""

def clean_str(string):
string = re.sub("[^A-Za-z\-\?\!\.\,]", " ", string).lower()
string = string.replace("that's", "that is")
string = string.replace("isn't", "is not")
string = string.replace("don't", "do not")
string = string.replace("did't", "did not")
string = string.replace("won't", "will not")
string = string.replace("can't", "can not")
string = string.replace("you're", "you are")
string = string.replace("they're", "they are")
string = string.replace("you'll", "you will")
string = string.replace("we'll", "we will")
string = string.replace("what's", "what is")
string = string.replace("i'm", "i am")
string = string.replace("let's", "let us")
return string

raw_iter = iter(open(filepath, encoding="utf8"))
data = []
max_len = 0
for raw in raw_iter:
line = raw.rstrip("\n").split('","')
s, l = line[-1][:-1], line[0][1:]
s = clean_str(s)
tensor_ = torch.tensor([self.vocab[token] for token in
self.tokenizer(s)], dtype=torch.long)
l = torch.tensor(int(l), dtype=torch.long)
l = torch.tensor(int(l) - 1, dtype=torch.long)
max_len = max(max_len, tensor_.size(0))
data.append((tensor_, l))
return data, max_len
Expand Down Expand Up @@ -128,9 +147,16 @@ def generate_batch(self, data_batch):
if __name__ == '__main__':
path = "./data/ag_news_csv/test.csv"
data_loader = LoadSentenceClassificationDataset(train_file_path=path,
tokenizer=my_tokenizer)
train_iter, test_iter = data_loader.load_train_val_test_data(path, path)
for sample, label in train_iter:
print(sample.shape) # [seq_len,batch_size]
print(label.shape) # [batch_size]
break
tokenizer=my_tokenizer,
max_sen_len=None)
data, max_len = data_loader.data_process(path)

# train_iter, test_iter = data_loader.load_train_val_test_data(path, path)
# i = 0
# print(len(train_iter))
# for sample, label in train_iter:
# print(sample.shape) # [seq_len,batch_size]
# print(label.shape) # [batch_size]
# if i == 5:
# break
# i += 1

0 comments on commit c7dbdd2

Please sign in to comment.