data_utils.py (141 lines, 117 loc, 4.4 KB)
import codecs
import glob
import json
import random
import numpy as np


class Vocabulary(object):
    """Word vocabulary with token <-> id lookup and <S>/<UNK> special tokens."""

    def __init__(self):
        self._token_to_id = {}
        self._token_to_count = {}
        self._id_to_token = []
        self._num_tokens = 0
        self._s_id = None
        self._unk_id = None

    @property
    def num_tokens(self):
        return self._num_tokens

    @property
    def unk(self):
        return "<UNK>"

    @property
    def unk_id(self):
        return self._unk_id

    @property
    def s(self):
        return "<S>"

    @property
    def s_id(self):
        return self._s_id

    def add(self, token, count):
        """Assign the next free id to `token` and record its corpus count."""
        self._token_to_id[token] = self._num_tokens
        self._token_to_count[token] = count
        self._id_to_token.append(token)
        self._num_tokens += 1

    def finalize(self):
        """Cache the ids of the special <S> and <UNK> tokens once all adds are done."""
        self._s_id = self.get_id(self.s)
        self._unk_id = self.get_id(self.unk)

    def get_id(self, token):
        """Return the id of `token`, falling back to the <UNK> id."""
        return self._token_to_id.get(token, self.unk_id)

    def get_token(self, id_):
        return self._id_to_token[id_]

    @staticmethod
    def from_file(filename):
        """Load a vocabulary from a file with one "<word> <count>" pair per line."""
        vocab = Vocabulary()
        with codecs.open(filename, "r", "utf-8") as f:
            for line in f:
                word, count = line.strip().split()
                vocab.add(word, int(count))
        vocab.finalize()
        return vocab


class Dataset(object):
    """Streams (input, target) id batches for language-model training."""

    def __init__(self, vocab, file_pattern, deterministic=False):
        self._vocab = vocab
        self._file_pattern = file_pattern
        self._deterministic = deterministic

    def _parse_sentence(self, line):
        # Wrap the sentence in <S> markers and map each word to its id.
        s_id = self._vocab.s_id
        return [s_id] + [self._vocab.get_id(word) for word in line.strip().split()] + [s_id]

    def _parse_file(self, file_name):
        print("Processing file: %s" % file_name)
        with codecs.open(file_name, "r", "utf-8") as f:
            lines = [line.strip() for line in f]
            if not self._deterministic:
                random.shuffle(lines)
            print("Finished processing!")
            for line in lines:
                yield self._parse_sentence(line)

    def _sentence_stream(self, file_stream):
        for file_name in file_stream:
            for sentence in self._parse_file(file_name):
                yield sentence

    def _iterate(self, sentences, batch_size, num_steps):
        """Yield (x, y) batches of shape [batch_size, num_steps], where y is x
        shifted left by one token and each row is filled from its own stream of
        concatenated sentences."""
        streams = [None] * batch_size
        x = np.zeros([batch_size, num_steps], np.int32)
        y = np.zeros([batch_size, num_steps], np.int32)
        # w = np.zeros([batch_size, num_steps], np.uint8)
        do_iters = True
        while do_iters:
            x[:] = 0
            y[:] = 0
            # w[:] = 0
            for i in range(batch_size):
                tokens_filled = 0
                try:
                    while tokens_filled < num_steps:
                        # Refill this row's stream when it has no target token left.
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sentences)
                        num_tokens = min(len(streams[i]) - 1, num_steps - tokens_filled)
                        x[i, tokens_filled:tokens_filled + num_tokens] = streams[i][:num_tokens]
                        y[i, tokens_filled:tokens_filled + num_tokens] = streams[i][1:num_tokens + 1]
                        # w[i, tokens_filled:tokens_filled + num_tokens] = 1
                        streams[i] = streams[i][num_tokens:]
                        tokens_filled += num_tokens
                except StopIteration:
                    do_iters = False
            # if not np.any(w):
            #     return
            # yield x, y, w
            yield x, y

    def iterate_once(self, batch_size, num_steps):
        """Iterate over the matched files for a single epoch."""
        def file_stream():
            for file_name in glob.glob(self._file_pattern):
                yield file_name

        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value

    def iterate_forever(self, batch_size, num_steps):
        """Iterate over the matched files indefinitely, reshuffling files each epoch."""
        def file_stream():
            epoch_num = 0
            while True:
                file_patterns = glob.glob(self._file_pattern)
                if not self._deterministic:
                    random.shuffle(file_patterns)
                for file_name in file_patterns:
                    yield file_name
                print('Done epoch: %s' % epoch_num)
                epoch_num += 1

        for value in self._iterate(self._sentence_stream(file_stream()), batch_size, num_steps):
            yield value
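

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: load a vocabulary,
    # build a Dataset, and walk a few training batches. The paths "vocab.txt"
    # and "train/*.txt" are placeholder assumptions; substitute your own
    # vocabulary file (one "<word> <count>" pair per line) and shard pattern.
    vocab = Vocabulary.from_file("vocab.txt")
    dataset = Dataset(vocab, "train/*.txt")
    for step, (x, y) in enumerate(dataset.iterate_once(batch_size=2, num_steps=20)):
        print(x.shape, y.shape)  # both (2, 20); y is x shifted by one token
        if step >= 2:
            break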