-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
126 lines (97 loc) · 3.72 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from datasets import load_dataset
import pandas as pd
import glob
from sklearn.model_selection import train_test_split
class RawData():
    """Abstract base class for raw text datasets.

    Concrete subclasses implement one getter per split; each getter
    returns a list of strings.
    """
    def get_train_data(self) -> list[str]:
        # Must be overridden by a concrete subclass.
        raise NotImplementedError
    def get_validation_data(self) -> list[str]:
        raise NotImplementedError
    def get_test_data(self) -> list[str]:
        raise NotImplementedError
    def get_data(self):
        # Bundle all three splits into a single mapping.
        getters = {
            'train': self.get_train_data,
            'validation': self.get_validation_data,
            'test': self.get_test_data,
        }
        return {split: fetch() for split, fetch in getters.items()}
class TweetData(RawData):
    """Full tweet_eval sentiment corpus, using the upstream splits as-is."""
    def __init__(self):
        # Loads (and caches) the HF "tweet_eval" dataset, "sentiment" config.
        self.tweets = load_dataset("tweet_eval", "sentiment")
    def _texts(self, split):
        # Pull the raw text column for the requested split.
        return self.tweets[split]['text']
    def get_train_data(self):
        return self._texts('train')
    def get_validation_data(self):
        return self._texts('validation')
    def get_test_data(self):
        return self._texts('test')
class TweetDataMini(RawData):
    """Small variant of TweetData: only the first 1024 texts of each split."""

    # Per-split cap on the number of examples returned.
    _CAP = 1024

    def __init__(self):
        # Loads (and caches) the HF "tweet_eval" dataset, "sentiment" config.
        self.tweets = load_dataset("tweet_eval", "sentiment")
    def get_train_data(self):
        return self.tweets['train']['text'][:self._CAP]
    def get_validation_data(self):
        return self.tweets['validation']['text'][:self._CAP]
    def get_test_data(self):
        return self.tweets['test']['text'][:self._CAP]
class NewsData(RawData):
    """AG News corpus; 4% of the upstream train split is held out as validation."""
    def __init__(self):
        source = load_dataset('ag_news')
        # Carve a 4% validation set out of the upstream train split;
        # the upstream test split is kept untouched.
        held_out = source['train'].train_test_split(test_size=0.04)
        self.news = dict(
            train=held_out['train'],
            validation=held_out['test'],
            test=source['test'],
        )
    def get_train_data(self):
        return self.news['train']['text']
    def get_validation_data(self):
        return self.news['validation']['text']
    def get_test_data(self):
        return self.news['test']['text']
class CommentData(RawData):
    """Comment corpus; 4% of the upstream train split is held out as validation.

    NOTE(review): the dataset id 'ag_comments' does not look like a standard
    HF hub name — confirm it resolves (possibly a local/private dataset).
    """
    def __init__(self):
        source = load_dataset('ag_comments')
        # Carve a 4% validation set out of the upstream train split;
        # the upstream test split is kept untouched.
        held_out = source['train'].train_test_split(test_size=0.04)
        self.comments = dict(
            train=held_out['train'],
            validation=held_out['test'],
            test=source['test'],
        )
    def get_train_data(self):
        return self.comments['train']['text']
    def get_validation_data(self):
        return self.comments['validation']['text']
    def get_test_data(self):
        return self.comments['test']['text']
class RedditData(RawData):
    """Reddit comments loaded from local CSV files and split three ways.

    Data source: https://www.kaggle.com/datasets/mexwell/reddit-comment-and-thread
    (comment dumps from several Reddit topics). All CSVs are concatenated
    first, then the combined text is split into train/validation/test.

    Args:
        test_size: fraction of the FULL dataset used for test.
        val_size: fraction of the FULL dataset used for validation.
        limit_n: cap on the total number of texts kept before splitting.
        data_dir: directory containing the *.csv files (default: 'data').

    Raises:
        FileNotFoundError: if no CSV files are found under `data_dir`.
    """
    def __init__(self, test_size=0.2, val_size=0.1, limit_n=20000, data_dir='data'):
        files = glob.glob(f'{data_dir}/*.csv')
        if not files:
            # Without this guard, pd.concat([]) fails with an opaque ValueError.
            raise FileNotFoundError(f"no CSV files found under '{data_dir}'")
        combined_df = pd.concat([pd.read_csv(f) for f in files], ignore_index=True)
        # Column '0' holds the comment text; drop NaN/non-string entries,
        # then cap the total size before splitting.
        text = [x for x in combined_df['0'].tolist() if isinstance(x, str)]
        text = text[:limit_n]
        # First carve off the test set; then rescale val_size so that the
        # validation set is val_size of the FULL dataset, not of train_val.
        train_val, test = train_test_split(text, test_size=test_size, random_state=42)
        train, val = train_test_split(train_val, test_size=val_size / (1 - test_size), random_state=42)
        self.train_data = train
        self.validation_data = val
        self.test_data = test
    def get_train_data(self):
        return self.train_data
    def get_validation_data(self):
        return self.validation_data
    def get_test_data(self):
        return self.test_data
# Registry mapping a config-friendly name to its dataset CLASS (not an
# instance); callers instantiate the looked-up class themselves.
# Fix: the annotation previously said `dict[str, RawData]`, but the values
# here are the classes, so the correct value type is `type[RawData]`.
dataset_classes: dict[str, type[RawData]] = {
    'TweetData': TweetData,
    'TweetDataMini': TweetDataMini,
    'NewsData': NewsData,
    'CommentData': CommentData,
    'RedditData': RedditData
}