Commit

openke_transE
ljy0ustc committed Dec 14, 2021
1 parent 7c1cd0c commit e6e0f0b
Showing 38 changed files with 2,519 additions and 0 deletions.
185 changes: 185 additions & 0 deletions exp2/openke_transE/DataProcessing.py
@@ -0,0 +1,185 @@
import json

# Build the entity-name -> id map; `entities` collects the output lines for
# entity2id.txt, with a placeholder count at index 0 that is filled in at the end.
trans2res_idmap = {}
entity_dict = {}
entity_id = 0
entities = ['0']
with open(".//benchmarks//ljy//entity_with_text.txt", "r") as f_read:
    for line in f_read.readlines():
        entity = line.split()
        entity_dict[entity[0]] = entity_id
        trans2res_idmap[entity_id] = entity[0]
        entities.append(str(entity[0]))
        entities.append("\t")
        entities.append(str(entity_id))
        entities.append("\n")
        entity_id += 1



relation_dict = {}
relation_id = 0
relations = ['0']
with open(".//benchmarks//ljy//relation_with_text.txt", "r") as f_read:
    for line in f_read.readlines():
        relation = line.split()
        relation_dict[relation[0]] = relation_id
        relations.append(str(relation[0]))
        relations.append("\t")
        relations.append(str(relation_id))
        relations.append("\n")
        relation_id += 1



# Convert train.txt, read as (head, relation, tail), into OpenKE-style
# train2id.txt lines of "head_id tail_id relation_id", registering any entity
# or relation not already seen in the *_with_text files along the way.
l = ['0']
f_write = open(".//benchmarks//ljy//train2id.txt", 'w')
train_id = 0
with open(".//benchmarks//ljy//train.txt", "r") as f_read:
    for line in f_read.readlines():
        line = line.split()
        if line[0] in entity_dict:
            l.append(str(entity_dict[line[0]]))
        else:
            entity_dict[line[0]] = entity_id
            trans2res_idmap[entity_id] = line[0]
            entities.append(str(line[0]))
            entities.append("\t")
            entities.append(str(entity_id))
            entities.append("\n")
            l.append(str(entity_dict[line[0]]))
            entity_id += 1
        l.append(' ')
        if line[2] in entity_dict:
            l.append(str(entity_dict[line[2]]))
        else:
            entity_dict[line[2]] = entity_id
            trans2res_idmap[entity_id] = line[2]
            entities.append(str(line[2]))
            entities.append("\t")
            entities.append(str(entity_id))
            entities.append("\n")
            l.append(str(entity_dict[line[2]]))
            entity_id += 1
        l.append(' ')
        if line[1] in relation_dict:
            l.append(str(relation_dict[line[1]]))
        else:
            relation_dict[line[1]] = relation_id
            relations.append(str(line[1]))
            relations.append("\t")
            relations.append(str(relation_id))
            relations.append("\n")
            l.append(str(relation_dict[line[1]]))
            relation_id += 1
        l.append('\n')
        train_id += 1
l[0] = str(train_id) + '\n'
f_write.writelines(l)
f_write.close()

f_write = open(".//benchmarks//ljy//valid2id.txt", 'w')
l = ['0']
dev_id = 0
with open(".//benchmarks//ljy//dev.txt", "r") as f_read:
for line in f_read.readlines():
line = line.split()
if line[0] in entity_dict:
l.append(str(entity_dict[line[0]]))
else:
entity_dict[line[0]] = entity_id
trans2res_idmap[entity_id]=line[0]
entities.append(str(line[0]))
entities.append("\t")
entities.append(str(entity_id))
entities.append("\n")
l.append(str(entity_dict[line[0]]))
entity_id +=1
l.append(' ')
if line[2] in entity_dict:
l.append(str(entity_dict[line[2]]))
else:
entity_dict[line[2]] = entity_id
trans2res_idmap[entity_id]=line[2]
entities.append(str(line[2]))
entities.append("\t")
entities.append(str(entity_id))
entities.append("\n")
l.append(str(entity_dict[line[2]]))
entity_id +=1
l.append(' ')
if line[1] in relation_dict:
l.append(str(relation_dict[line[1]]))
else:
relation_dict[line[1]] =relation_id
relations.append(str(line[1]))
relations.append("\t")
relations.append(str(relation_id))
relations.append("\n")
l.append(str(relation_dict[line[1]]))
relation_id +=1
l.append('\n')
dev_id +=1
l[0] = str(dev_id) + '\n'
f_write.writelines(l)
f_write.close()

f_write = open(".//benchmarks//ljy//test2id.txt", 'w')
l = ['0']
test_id = 0
with open(".//benchmarks//ljy//test.txt", "r") as f_read:
for line in f_read.readlines():
line = line.split()
if line[0] in entity_dict:
l.append(str(entity_dict[line[0]]))
else:
entity_dict[line[0]] = entity_id
trans2res_idmap[entity_id]=line[0]
entities.append(str(line[0]))
entities.append("\t")
entities.append(str(entity_id))
entities.append("\n")
l.append(str(entity_dict[line[0]]))
entity_id +=1
l.append(' ')
if line[1] in entity_dict:
l.append(str(entity_dict[line[1]]))
else:
entity_dict[line[1]] = entity_id
trans2res_idmap[entity_id]=line[1]
entities.append(str(line[1]))
entities.append("\t")
entities.append(str(entity_id))
entities.append("\n")
l.append(str(entity_dict[line[1]]))
entity_id +=1
l.append(' ')
if line[2] in relation_dict:
l.append(str(relation_dict[line[2]]))
else:
relation_dict[line[2]] =relation_id
relations.append(str(line[2]))
relations.append("\t")
relations.append(str(relation_id))
relations.append("\n")
l.append(str(relation_dict[line[2]]))
relation_id +=1
l.append('\n')
test_id +=1
l[0] = str(test_id) + '\n'
f_write.writelines(l)
f_write.close()

f_write = open(".//benchmarks//ljy//entity2id.txt", 'w')
entities[0] = str(entity_id) +'\n'
f_write.writelines(entities)
f_write.close()

f_write = open(".//benchmarks//ljy//relation2id.txt", 'w')
relations[0] = str(relation_id) +'\n'
f_write.writelines(relations)
f_write.close()

jsonMap = json.dumps(trans2res_idmap)
fileObject = open('.//result//map.json','w')
fileObject.write(jsonMap)
fileObject.close()
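
The generated *2id.txt files follow the count-then-records layout that OpenKE's data loaders read: a count on the first line, then one whitespace- or tab-separated record per line. The snippet below is a minimal sanity check, not part of this commit; the helper name load_count_and_rows and the paths are illustrative.

# Sanity-check the files written by DataProcessing.py.
def load_count_and_rows(path):
    with open(path, "r") as f:
        lines = f.read().splitlines()
    count = int(lines[0])  # first line holds the record count
    rows = [l.split() for l in lines[1:] if l.strip()]
    assert count == len(rows), f"{path}: header says {count}, found {len(rows)}"
    return rows

entity_rows = load_count_and_rows(".//benchmarks//ljy//entity2id.txt")
relation_rows = load_count_and_rows(".//benchmarks//ljy//relation2id.txt")
entity_ids = {row[1] for row in entity_rows}
relation_ids = {row[1] for row in relation_rows}

# Every id referenced by a triple file must exist in the id maps.
for split in ("train2id.txt", "valid2id.txt", "test2id.txt"):
    for h, t, r in load_count_and_rows(".//benchmarks//ljy//" + split):
        assert h in entity_ids and t in entity_ids and r in relation_ids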
137 changes: 137 additions & 0 deletions exp2/openke_transE/base/Base.cpp
@@ -0,0 +1,137 @@
#include "Setting.h"
#include "Random.h"
#include "Reader.h"
#include "Corrupt.h"
#include "Test.h"
#include "Valid.h"
#include <cstdlib>
#include <pthread.h>

extern "C"
void setInPath(char *path);

extern "C"
void setOutPath(char *path);

extern "C"
void setWorkThreads(INT threads);

extern "C"
void setBern(INT con);

extern "C"
INT getWorkThreads();

extern "C"
INT getEntityTotal();

extern "C"
INT getRelationTotal();

extern "C"
INT getTripleTotal();

extern "C"
INT getTrainTotal();

extern "C"
INT getTestTotal();

extern "C"
INT getValidTotal();

extern "C"
void randReset();

extern "C"
void importTrainFiles();

struct Parameter {
    INT id;
    INT *batch_h;
    INT *batch_t;
    INT *batch_r;
    REAL *batch_y;
    INT batchSize;
    INT negRate;
    INT negRelRate;
};

void* getBatch(void* con) {
    Parameter *para = (Parameter *)(con);
    INT id = para -> id;
    INT *batch_h = para -> batch_h;
    INT *batch_t = para -> batch_t;
    INT *batch_r = para -> batch_r;
    REAL *batch_y = para -> batch_y;
    INT batchSize = para -> batchSize;
    INT negRate = para -> negRate;
    INT negRelRate = para -> negRelRate;
    INT lef, rig;
    // Split the batch range [lef, rig) across the worker threads.
    if (batchSize % workThreads == 0) {
        lef = id * (batchSize / workThreads);
        rig = (id + 1) * (batchSize / workThreads);
    } else {
        lef = id * (batchSize / workThreads + 1);
        rig = (id + 1) * (batchSize / workThreads + 1);
        if (rig > batchSize) rig = batchSize;
    }
    REAL prob = 500;  // 50/50 head/tail corruption unless bern sampling is on
    for (INT batch = lef; batch < rig; batch++) {
        // Sample a positive training triple.
        INT i = rand_max(id, trainTotal);
        batch_h[batch] = trainList[i].h;
        batch_t[batch] = trainList[i].t;
        batch_r[batch] = trainList[i].r;
        batch_y[batch] = 1;
        INT last = batchSize;
        // Negative triples: corrupt either the head or the tail entity.
        for (INT times = 0; times < negRate; times++) {
            if (bernFlag)
                prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]);
            if (randd(id) % 1000 < prob) {
                batch_h[batch + last] = trainList[i].h;
                batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r);
                batch_r[batch + last] = trainList[i].r;
            } else {
                batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r);
                batch_t[batch + last] = trainList[i].t;
                batch_r[batch + last] = trainList[i].r;
            }
            batch_y[batch + last] = -1;
            last += batchSize;
        }
        // Optionally also corrupt the relation.
        for (INT times = 0; times < negRelRate; times++) {
            batch_h[batch + last] = trainList[i].h;
            batch_t[batch + last] = trainList[i].t;
            batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t);
            batch_y[batch + last] = -1;
            last += batchSize;
        }
    }
    pthread_exit(NULL);
}

extern "C"
void sampling(INT *batch_h, INT *batch_t, INT *batch_r, REAL *batch_y, INT batchSize, INT negRate = 1, INT negRelRate = 0) {
pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
for (INT threads = 0; threads < workThreads; threads++) {
para[threads].id = threads;
para[threads].batch_h = batch_h;
para[threads].batch_t = batch_t;
para[threads].batch_r = batch_r;
para[threads].batch_y = batch_y;
para[threads].batchSize = batchSize;
para[threads].negRate = negRate;
para[threads].negRelRate = negRelRate;
pthread_create(&pt[threads], NULL, getBatch, (void*)(para+threads));
}
for (INT threads = 0; threads < workThreads; threads++)
pthread_join(pt[threads], NULL);
free(pt);
free(para);
}

int main() {
    importTrainFiles();
    return 0;
}
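
In OpenKE, these extern "C" entry points are compiled into a shared library and called from Python through ctypes. The sketch below is an illustration under stated assumptions rather than code from this commit: the library path ./release/Base.so, the thread count, and the mapping of INT/REAL to int64/float32 (their typedefs live in Setting.h, not shown here) all have to match the actual build.

import ctypes
import numpy as np

# Load the compiled shared library (the path is an assumption; this commit
# does not include a build script).
lib = ctypes.cdll.LoadLibrary("./release/Base.so")

# Declare the signatures we rely on, assuming INT = int64 and REAL = float32.
lib.sampling.argtypes = [ctypes.c_void_p] * 4 + [ctypes.c_int64] * 3
lib.setWorkThreads.argtypes = [ctypes.c_int64]
lib.setBern.argtypes = [ctypes.c_int64]
lib.getTrainTotal.restype = ctypes.c_int64

in_path = b"./benchmarks/ljy/"
lib.setInPath(ctypes.create_string_buffer(in_path, len(in_path) * 2))
lib.setWorkThreads(4)
lib.setBern(0)
lib.randReset()
lib.importTrainFiles()
print("train triples:", lib.getTrainTotal())

# getBatch() writes one positive triple per slot plus negRate (+ negRelRate)
# corrupted copies, each block offset by batchSize, so the buffers must hold
# batchSize * (1 + negRate + negRelRate) elements.
batch_size, neg_rate, neg_rel_rate = 1024, 1, 0
seq_size = batch_size * (1 + neg_rate + neg_rel_rate)
batch_h = np.zeros(seq_size, dtype=np.int64)
batch_t = np.zeros(seq_size, dtype=np.int64)
batch_r = np.zeros(seq_size, dtype=np.int64)
batch_y = np.zeros(seq_size, dtype=np.float32)

def addr(a):
    # Raw pointer to the numpy buffer, handed to the C side as INT* / REAL*.
    return ctypes.c_void_p(a.__array_interface__["data"][0])

lib.sampling(addr(batch_h), addr(batch_t), addr(batch_r), addr(batch_y),
             batch_size, neg_rate, neg_rel_rate)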