-
Notifications
You must be signed in to change notification settings - Fork 380
/
data_loader.py
47 lines (39 loc) · 2.13 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
@author : Hyunwoong
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets.translation import Multi30k
class DataLoader:
    """Build Multi30k translation datasets, vocabularies and batch iterators.

    Typical call order is ``make_dataset`` -> ``build_vocab`` -> ``make_iter``.
    ``source`` and ``target`` stay ``None`` until ``make_dataset`` has run.
    """

    # Populated by make_dataset(); None beforehand.
    source: "Field | None" = None
    target: "Field | None" = None

    # (source extension, target extension) pairs that make_dataset() accepts.
    _SUPPORTED_EXTS = (('.de', '.en'), ('.en', '.de'))

    def __init__(self, ext, tokenize_en, tokenize_de, init_token, eos_token):
        """Store the configuration used later to build fields and datasets.

        Args:
            ext: pair of file extensions ``(source, target)``; must be
                ``('.de', '.en')`` or ``('.en', '.de')`` (checked in
                ``make_dataset``, not here).
            tokenize_en: tokenizer passed to the English ``Field``.
            tokenize_de: tokenizer passed to the German ``Field``.
            init_token: value forwarded as ``init_token`` to both fields.
            eos_token: value forwarded as ``eos_token`` to both fields.
        """
        self.ext = ext
        self.tokenize_en = tokenize_en
        self.tokenize_de = tokenize_de
        self.init_token = init_token
        self.eos_token = eos_token
        print('dataset initializing start')

    def _make_field(self, tokenize):
        """Return a ``Field`` using ``tokenize`` and the shared token settings."""
        return Field(tokenize=tokenize, init_token=self.init_token, eos_token=self.eos_token,
                     lower=True, batch_first=True)

    def make_dataset(self):
        """Create the source/target fields and load the Multi30k splits.

        Sets ``self.source`` and ``self.target`` according to ``self.ext``.

        Returns:
            The ``(train_data, valid_data, test_data)`` triple produced by
            ``Multi30k.splits``.

        Raises:
            ValueError: if ``self.ext`` is not one of the supported pairs.
                Without this check both fields would silently remain ``None``
                and the failure would only surface later in ``build_vocab``.
        """
        # Accept a list as well as a tuple: ['.de', '.en'] != ('.de', '.en').
        pair = tuple(self.ext) if isinstance(self.ext, (list, tuple)) else self.ext
        if pair not in self._SUPPORTED_EXTS:
            raise ValueError(
                f"unsupported ext {self.ext!r}; expected ('.de', '.en') or ('.en', '.de')"
            )

        tokenizers = {'.de': self.tokenize_de, '.en': self.tokenize_en}
        source_ext, target_ext = pair
        self.source = self._make_field(tokenizers[source_ext])
        self.target = self._make_field(tokenizers[target_ext])

        train_data, valid_data, test_data = Multi30k.splits(exts=pair, fields=(self.source, self.target))
        return train_data, valid_data, test_data

    def build_vocab(self, train_data, min_freq):
        """Build the source and target vocabularies from ``train_data``.

        ``make_dataset`` must have been called first so that ``self.source``
        and ``self.target`` are ``Field`` objects rather than ``None``.

        Args:
            train_data: dataset passed to each field's ``build_vocab``.
            min_freq: forwarded as ``min_freq`` to each field's ``build_vocab``.
        """
        self.source.build_vocab(train_data, min_freq=min_freq)
        self.target.build_vocab(train_data, min_freq=min_freq)

    def make_iter(self, train, validate, test, batch_size, device):
        """Wrap the three datasets in iterators via ``BucketIterator.splits``.

        Args:
            train: training dataset.
            validate: validation dataset.
            test: test dataset.
            batch_size: forwarded as ``batch_size``.
            device: forwarded as ``device``.

        Returns:
            The ``(train_iterator, valid_iterator, test_iterator)`` triple.
        """
        train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train, validate, test),
                                                                              batch_size=batch_size,
                                                                              device=device)
        print('dataset initializing done')
        return train_iterator, valid_iterator, test_iterator