Skip to content

Commit

Permalink
run some experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
summer1278 committed Dec 19, 2017
1 parent 279bd6d commit e047524
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
6 changes: 6 additions & 0 deletions memn2n/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def load_task(data_dir, task_id, only_supporting=False):
test_file = [f for f in files if s in f and 'test' in f][0]
train_data = get_stories(train_file, only_supporting)
test_data = get_stories(test_file, only_supporting)
# print len(train_data),len(test_data)
return train_data, test_data


Expand Down Expand Up @@ -97,6 +98,9 @@ def vectorize_data(data, word_idx, sentence_size, memory_size):
"""
S, Q, A = [], [], []
for story, query, answer in data:
# print story
# print query
# print answer
ss = []
for i, sentence in enumerate(story, 1):
ls = max(0, sentence_size - len(sentence))
Expand All @@ -117,10 +121,12 @@ def vectorize_data(data, word_idx, sentence_size, memory_size):

lq = max(0, sentence_size - len(query))
q = [word_idx[w] for w in query] + [0] * lq
# print q

y = np.zeros(len(word_idx) + 1) # 0 is reserved for nil word
for a in answer:
y[word_idx[a]] = 1

S.append(ss); Q.append(q); A.append(y)
print len(S),len(Q),len(A)
return np.array(S), np.array(Q), np.array(A)
3 changes: 3 additions & 0 deletions memn2n/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __init__(self, dataset_dir, task_id=1, memory_size=50, train=True):
if train:
story, query, answer = vectorize_data(train_data, self.word_idx,
self.sentence_size, self.memory_size)
print 'story',story.shape
print 'query[0]',torch.LongTensor(query)[0].shape
print 'answer',answer.shape
else:
story, query, answer = vectorize_data(test_data, self.word_idx,
self.sentence_size, self.memory_size)
Expand Down
5 changes: 4 additions & 1 deletion memn2n/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def position_encoding(sentence_size, embedding_dim):
encoding = 1.0 + 4.0 * encoding / embedding_dim / sentence_size
# Set the position encoding of time words to 1.0 (identity) so they are left unmodified
encoding[:, -1] = 1.0
print 'sent_size',sentence_size,'embed_dim',embedding_dim
print 'encoding -1:',np.transpose(encoding).shape
return np.transpose(encoding)

class AttrProxy(object):
Expand Down Expand Up @@ -58,9 +60,10 @@ def forward(self, story, query):
# print 'story size',story_size
u = list()
query_embed = self.C[0](query)
# print 'query_embed',query_embed.size()
# print 'query',query.size()
# weird way to perform reduce_dot
encoding = self.encoding.unsqueeze(0).expand_as(query_embed)
# print 'encoding',encoding.shape
u.append(torch.sum(query_embed*encoding, 1))

for hop in range(self.max_hops):
Expand Down
6 changes: 5 additions & 1 deletion memn2n/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def evaluate(self, data="test"):
story = story.cuda()
query = query.cuda()
answer = answer.cuda()

pred_prob = self.mem_n2n(story, query)[1]
pred = pred_prob.data.max(1)[1] # max func return (max, argmax)
correct += pred.eq(answer.data).cpu().sum()
Expand All @@ -85,7 +85,11 @@ def evaluate(self, data="test"):
def _train_single_epoch(self, epoch):
config = self.config
num_steps_per_epoch = len(self.train_loader)
print self.train_loader[0]
for step, (story, query, answer) in enumerate(self.train_loader):
print 'story',story.shape
print 'query',query.shape
print 'answer',answer.shape
story = Variable(story)
query = Variable(query)
answer = Variable(answer)
Expand Down

0 comments on commit e047524

Please sign in to comment.