Skip to content

Commit

Permalink
src
Browse files — browse the repository at this point in the history
  • Loading branch information
X-XG committed Dec 18, 2021
1 parent e343d82 commit 2ade7ca
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
28 changes: 21 additions & 7 deletions exp2/w2v_transE/tester.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
from pickle import FALSE, TRUE
import numpy as np
import codecs
import operator
Expand Down Expand Up @@ -56,12 +57,25 @@ def __init__(self,entity_dict,relation_dict,test_triple,train_triple,isFit = Tru
self.hits5 = 0
self.mean_rank = 0

def rank(self, output_path):
hits = 0
step = 1
f = open(output_path + 'tail_predict.txt', 'w')
def rank(self, output_path, re_continue, re_hits):
if re_continue:
f = open(output_path + 'tail_predict.txt', 'r')
step = len(f.readlines())
hits = re_hits
f.close()
f = open(output_path + 'tail_predict.txt', 'a')
past = step
else:
hits = 0
step = 0
f = open(output_path + 'tail_predict.txt', 'w')
past = 0


for triple in self.test_triple:
if past > 0:
past -= 1
continue
rank_tail_dict = {}

for entity in self.entity_dict.keys():
Expand Down Expand Up @@ -100,11 +114,11 @@ def rank(self, output_path):
first_hit = False

step += 1
if step % 10 == 0:
if step % 20 == 0:
print("step ", step, ", hits ",hits, ', rate: ',hits/step)
print()

self.hits5 = hits / (2*len(self.test_triple))
self.hits5 = hits / len(self.test_triple)


if __name__ == '__main__':
Expand All @@ -116,7 +130,7 @@ def rank(self, output_path):


test = Test(entity_dict,relation_dict,test_triple,train_triple,isFit=True)
test.rank('./output/')
test.rank('./output/',re_continue=False, re_hits=0)
print("entity hits@5: ", test.hits5)


19 changes: 11 additions & 8 deletions exp2/w2v_transE/transE.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,7 +41,7 @@ def distanceL1(h,r,t):
return np.sum(np.fabs(h+r-t))

class TransE:
def __init__(self,entity_set, relation_set, triple_list, ent_rel_sim_path, use_w2v = False, nbatch = 400,
def __init__(self,entity_set, relation_set, triple_list, ent_rel_sim_path, use_w2v = False,mu = 1, nbatch = 400,
embedding_dim=100, learning_rate=0.01, margin=1,L1=True):
self.embedding_dim = embedding_dim
self.learning_rate = learning_rate
Expand All @@ -50,6 +50,7 @@ def __init__(self,entity_set, relation_set, triple_list, ent_rel_sim_path, use_w
self.relation = relation_set
self.triple_list = triple_list
self.L1=L1
self.mu = mu

self.nbatch = nbatch
self.use_w2v = use_w2v
Expand Down Expand Up @@ -176,8 +177,10 @@ def update_embeddings(self, Tbatch):
t_corrupt = self.entity[corrupted_triple[1]]

if self.use_w2v:
w2v_correct = self.ent_rel_sim[int(triple[0])][int(triple[2])] * self.ent_rel_sim[int(triple[1])][int(triple[2])]
w2v_corrupt = self.ent_rel_sim[int(corrupted_triple[0])][int(triple[2])] * self.ent_rel_sim[int(corrupted_triple[1])][int(triple[2])]
w2v_correct = self.ent_rel_sim[int(triple[0])][int(triple[2])] \
* self.ent_rel_sim[int(triple[1])][int(triple[2])]
w2v_corrupt = self.ent_rel_sim[int(corrupted_triple[0])][int(triple[2])] \
* self.ent_rel_sim[int(corrupted_triple[1])][int(triple[2])]
w2v_revise = w2v_correct - w2v_corrupt

if self.L1:
Expand Down Expand Up @@ -240,7 +243,7 @@ def update_embeddings(self, Tbatch):
self.relation = copy_relation

def hinge_loss(self,dist_correct,dist_corrupt, w2v_revise):
return max(0,dist_correct-dist_corrupt+self.margin + w2v_revise)
return max(0,dist_correct-dist_corrupt+self.margin + self.mu*w2v_revise)


if __name__=='__main__':
Expand All @@ -249,7 +252,7 @@ def hinge_loss(self,dist_correct,dist_corrupt, w2v_revise):
print("load file...")
print("Complete load. entity : %d , relation : %d , triple : %d" % (len(entity_set),len(relation_set),len(triple_list)))

transE = TransE(entity_set, relation_set, triple_list, './output/ent_rel_sim.npy', use_w2v = False, nbatch=100, embedding_dim=200, learning_rate=0.01, margin=5,L1=True)
# transE.emb_initialize()
transE.reload('./temp/')
transE.train(epochs=0, temp_path='./temp/')
transE = TransE(entity_set, relation_set, triple_list, './output/ent_rel_sim.npy', use_w2v = True, mu=12,nbatch=100, embedding_dim=200, learning_rate=0.001, margin=4.0,L1=True)
transE.emb_initialize()
# transE.reload('./temp/')
transE.train(epochs=100, temp_path='./temp/')

0 comments on commit 2ade7ca

Please sign in to comment.