Skip to content

Commit

Permalink
added the generator
Browse files Browse the repository at this point in the history
  • Loading branch information
imdeepmind committed Mar 21, 2020
1 parent cc8e55a commit 6e8ef2e
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 121 deletions.
117 changes: 0 additions & 117 deletions generator.py
Original file line number Diff line number Diff line change
@@ -1,117 +0,0 @@
import sqlite3
import numpy as np

# Number of batches already requested by train_generator. It is multiplied by
# the batch size to form the SQL OFFSET. (The name is spelled "couter"; it is
# kept because train_generator refers to it by this exact name.)
train_couter = 0
# Number of batches already requested by validation_generator (same OFFSET role).
validation_counter = 0
# Number of batches already requested by test_generator (same OFFSET role).
test_counter = 0

def ont_hot(sequences, nexts, batch_size):
    """
    One-hot encode the next-character labels and convert the sequences
    into a NumPy array.

    Args:
        sequences: a batch of sequences, each a list of integer character codes
        nexts: the integer character code that follows each sequence
        batch_size: number of rows allocated in the label matrix

    Returns:
        A tuple (x, y). x is np.array(sequences). y is a boolean matrix of
        shape (batch_size, 128) with exactly one True per encoded row; rows
        beyond len(sequences) stay all False.

    Raises:
        IndexError: if nexts is shorter than sequences, or if there are more
            sequences than batch_size.
    """
    # np.bool_ is the scalar type itself; the bare np.bool alias was removed
    # in NumPy 1.24, so it is avoided here.
    y = np.zeros((batch_size, 128), dtype=np.bool_)

    for i, _sequence in enumerate(sequences):
        code = nexts[i]
        # Valid columns are 0..127. Anything outside that range (including
        # exactly 128, which used to slip through and raise IndexError) is
        # mapped to column 97, the code of 'a'.
        column = code if 0 <= code < 128 else 97
        y[i, column] = True

    return np.array(sequences), y

def train_generator(batch_size):
    """
    Endlessly yield training batches read from data/code.db.

    Every step selects `batch_size` rows with state = 'tr' from the
    code_sequences table, starting at offset batch_size * train_couter,
    converts the characters to integer codes with ord() and encodes the
    labels with ont_hot. When the offset runs past the last full batch the
    counter is reset and reading restarts from the first row, so a trailing
    partial batch is skipped.

    Args:
        batch_size: number of rows per batch

    Yields:
        A tuple (x, y) as returned by ont_hot, where x has shape
        (batch_size, 40) and y has shape (batch_size, 128).

    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than batch_size matching rows.
    """
    global train_couter

    # Placeholders keep the numbers out of the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'tr' LIMIT ? OFFSET ?")

    while True:
        conn = sqlite3.connect('data/code.db')
        try:
            rows = conn.execute(
                sql, (batch_size, batch_size * train_couter)).fetchall()

            if len(rows) < batch_size and train_couter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                train_couter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        train_couter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        x, y = ont_hot(sequences, nexts, batch_size)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y

def validation_generator(batch_size):
    """
    Endlessly yield validation batches read from data/code.db.

    Every step selects `batch_size` rows with state = 'va' from the
    code_sequences table, starting at offset
    batch_size * validation_counter, converts the characters to integer codes
    with ord() and encodes the labels with ont_hot. When the offset runs past
    the last full batch the counter is reset and reading restarts from the
    first row, so a trailing partial batch is skipped.

    Args:
        batch_size: number of rows per batch

    Yields:
        A tuple (x, y) as returned by ont_hot, where x has shape
        (batch_size, 40) and y has shape (batch_size, 128).

    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than batch_size matching rows.
    """
    global validation_counter

    # The state filter is 'va'; it previously read 'tr', which made this
    # generator serve training rows. Placeholders keep the numbers out of
    # the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'va' LIMIT ? OFFSET ?")

    while True:
        conn = sqlite3.connect('data/code.db')
        try:
            rows = conn.execute(
                sql, (batch_size, batch_size * validation_counter)).fetchall()

            if len(rows) < batch_size and validation_counter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                validation_counter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        validation_counter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        x, y = ont_hot(sequences, nexts, batch_size)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y

def test_generator(batch_size):
    """
    Endlessly yield test batches read from data/code.db.

    Every step selects `batch_size` rows with state = 'te' from the
    code_sequences table, starting at offset batch_size * test_counter,
    converts the characters to integer codes with ord() and encodes the
    labels with ont_hot. When the offset runs past the last full batch the
    counter is reset and reading restarts from the first row, so a trailing
    partial batch is skipped.

    Args:
        batch_size: number of rows per batch

    Yields:
        A tuple (x, y) as returned by ont_hot, where x has shape
        (batch_size, 40) and y has shape (batch_size, 128).

    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than batch_size matching rows.
    """
    global test_counter

    # The state filter is 'te'; it previously read 'tr', which made this
    # generator serve training rows. Placeholders keep the numbers out of
    # the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'te' LIMIT ? OFFSET ?")

    while True:
        conn = sqlite3.connect('data/code.db')
        try:
            rows = conn.execute(
                sql, (batch_size, batch_size * test_counter)).fetchall()

            if len(rows) < batch_size and test_counter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                test_counter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        test_counter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        x, y = ont_hot(sequences, nexts, batch_size)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y
139 changes: 135 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
import sqlite3
import numpy as np

from os import listdir
from os.path import isfile, join
Expand All @@ -8,6 +9,9 @@
class Model:
__files = []
__transcations = []
__train_counter = 0
__validation_counter = 0
__test_counter = 0

def __is_file(self, path, file):
"""
Expand Down Expand Up @@ -258,6 +262,136 @@ def build_sequences(self, force):
# Generating sequence db
self.__build_sequence_db()

def __one_hot(self, sequences, nexts, batch_size=None):
    """
    One Hot Encoding for the labels and converts everything into numpy array
    Args:
        sequences: a batch of sequences (lists of integer character codes)
        nexts: a batch of next characters, as integer character codes
        batch_size: number of rows in the label matrix; defaults to
            self.BATCH_SIZE. It is accepted because the generators in this
            class pass the batch size as a third argument.
    Returns:
        A tuple (x, y). x is np.array(sequences). y is a boolean matrix of
        shape (batch_size, 128) with one True per encoded row.
    """
    if batch_size is None:
        batch_size = self.BATCH_SIZE

    # np.bool_ is the scalar type itself; the bare np.bool alias was removed
    # in NumPy 1.24, so it is avoided here.
    y = np.zeros((batch_size, 128), dtype=np.bool_)

    for i, _sequence in enumerate(sequences):
        code = nexts[i]
        # Valid columns are 0..127. Anything outside that range (including
        # exactly 128, which used to slip through and raise IndexError) is
        # mapped to column 97, the code of 'a'.
        column = code if 0 <= code < 128 else 97
        y[i, column] = True

    return np.array(sequences), y

def __train_generator(self):
    """
    Train Generator that generates a batch of data for the model

    Reads self.BATCH_SIZE rows with state = 'tr' per step from the
    code_sequences table of the sequence database, converts the characters
    to integer codes with ord() and one-hot encodes the labels. When the
    offset runs past the last full batch the counter is reset and reading
    restarts from the first row, so a trailing partial batch is skipped.
    Args:
    Yields:
        A tuple (x, y): x has shape (BATCH_SIZE, 40) and y has shape
        (BATCH_SIZE, 128).
    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than BATCH_SIZE matching rows.
    """
    # Placeholders keep the numbers out of the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'tr' LIMIT ? OFFSET ?")

    while True:
        batch_size = self.BATCH_SIZE

        conn = sqlite3.connect(self.DATA_FOLDER + "/" + self.SEQUENCE_DB)
        try:
            rows = conn.execute(
                sql, (batch_size, batch_size * self.__train_counter)).fetchall()

            if len(rows) < batch_size and self.__train_counter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                self.__train_counter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        self.__train_counter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        # __one_hot(self, sequences, nexts) sizes its output from
        # self.BATCH_SIZE; passing the batch size as an extra positional
        # argument raised TypeError.
        x, y = self.__one_hot(sequences, nexts)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y

def validation_generator(self):
    """
    Validation Generator that generates a batch of data for the model

    Reads self.BATCH_SIZE rows with state = 'va' per step from the
    code_sequences table of the sequence database, converts the characters
    to integer codes with ord() and one-hot encodes the labels. When the
    offset runs past the last full batch the counter is reset and reading
    restarts from the first row, so a trailing partial batch is skipped.
    Args:
    Yields:
        A tuple (x, y): x has shape (BATCH_SIZE, 40) and y has shape
        (BATCH_SIZE, 128).
    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than BATCH_SIZE matching rows.
    """
    # Placeholders keep the numbers out of the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'va' LIMIT ? OFFSET ?")

    while True:
        batch_size = self.BATCH_SIZE

        conn = sqlite3.connect(self.DATA_FOLDER + "/" + self.SEQUENCE_DB)
        try:
            rows = conn.execute(
                sql,
                (batch_size, batch_size * self.__validation_counter)).fetchall()

            if len(rows) < batch_size and self.__validation_counter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                self.__validation_counter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        self.__validation_counter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        # __one_hot(self, sequences, nexts) sizes its output from
        # self.BATCH_SIZE; passing the batch size as an extra positional
        # argument raised TypeError.
        x, y = self.__one_hot(sequences, nexts)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y

def test_generator(self):
    """
    Test Generator that generates a batch of data for the model

    Reads self.BATCH_SIZE rows with state = 'te' per step from the
    code_sequences table of the sequence database, converts the characters
    to integer codes with ord() and one-hot encodes the labels. When the
    offset runs past the last full batch the counter is reset and reading
    restarts from the first row, so a trailing partial batch is skipped.
    Args:
    Yields:
        A tuple (x, y): x has shape (BATCH_SIZE, 40) and y has shape
        (BATCH_SIZE, 128).
    Raises:
        AssertionError: if a batch does not have the expected shape, e.g. the
            table holds fewer than BATCH_SIZE matching rows.
    """
    # Placeholders keep the numbers out of the SQL text.
    sql = ("SELECT sequence, next FROM code_sequences "
           "WHERE state = 'te' LIMIT ? OFFSET ?")

    while True:
        batch_size = self.BATCH_SIZE

        conn = sqlite3.connect(self.DATA_FOLDER + "/" + self.SEQUENCE_DB)
        try:
            rows = conn.execute(
                sql, (batch_size, batch_size * self.__test_counter)).fetchall()

            if len(rows) < batch_size and self.__test_counter > 0:
                # Data exhausted: wrap around instead of failing the shape
                # assertion on every later call.
                self.__test_counter = 0
                rows = conn.execute(sql, (batch_size, 0)).fetchall()
        finally:
            # The connection is per-iteration; close it so handles do not
            # accumulate while the generator keeps running.
            conn.close()

        self.__test_counter += 1

        sequences = [[ord(char) for char in sequence] for sequence, _ in rows]
        nexts = [ord(next_char) for _, next_char in rows]

        # __one_hot(self, sequences, nexts) sizes its output from
        # self.BATCH_SIZE; passing the batch size as an extra positional
        # argument raised TypeError.
        x, y = self.__one_hot(sequences, nexts)

        assert x.shape == (batch_size, 40), "Invalid dimension for Input X"
        assert y.shape == (batch_size, 128), "Invalid dimension for Output Y"

        yield x, y

def __init__(self,
DATA_FOLDER='data',
CODE_FILE_LIST='code_list.txt',
Expand All @@ -268,7 +402,4 @@ def __init__(self,
self.DATA_FOLDER = DATA_FOLDER
self.CODE_FILE_LIST = CODE_FILE_LIST
self.SEQUENCE_DB = SEQUENCE_DB
self.SEQ_LENGTH = SEQ_LENGTH



self.SEQ_LENGTH = SEQ_LENGTH

0 comments on commit 6e8ef2e

Please sign in to comment.