Skip to content

Commit

Permalink
updated the generator
Browse files Browse the repository at this point in the history
  • Loading branch information
imdeepmind committed Mar 21, 2020
1 parent 61ff1ec commit 9a6e1c3
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,12 @@ def ont_hot(sequences, nexts, batch_size):
y = np.zeros((batch_size, 128), dtype=np.bool)

for i, sequence in enumerate(sequences):
for t, char in enumerate(sequence):
if char < 0 or char > 128:
char = 97
x[i, t, char] = 1

if nexts[i] < 0 or nexts[i] > 128:
y[i, 97] = 1
else:
y[i, nexts[i]] = 1

return x, y
return np.array(sequences), y

def train_generator(batch_size):
global train_couter
Expand All @@ -48,7 +43,12 @@ def train_generator(batch_size):
sequences.append(temp)
nexts.append(ord(next))

yield ont_hot(sequences, nexts, batch_size)
x,y = ont_hot(sequences, nexts, batch_size)

assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

return x, y

def validation_generator(batch_size):
global validation_counter
Expand All @@ -73,7 +73,12 @@ def validation_generator(batch_size):
sequences.append(temp)
nexts.append(ord(next))

yield ont_hot(sequences, nexts, batch_size)
x,y = ont_hot(sequences, nexts, batch_size)

assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

return x, y

def test_generator(batch_size):
global test_counter
Expand All @@ -98,4 +103,9 @@ def test_generator(batch_size):
sequences.append(temp)
nexts.append(ord(next))

yield ont_hot(sequences, nexts, batch_size)
x,y = ont_hot(sequences, nexts, batch_size)

assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

return x, y

0 comments on commit 9a6e1c3

Please sign in to comment.