
Commit

'Rework thoroughly to follow item recommendation, i.e. modify MHCN and the sampler'
Estelle-gqy committed Feb 23, 2023
2 parents c0efaf9 + f9f0155 commit 693eda3
Showing 12 changed files with 542,530 additions and 305,836 deletions.
315 changes: 166 additions & 149 deletions .idea/workspace.xml

Large diffs are not rendered by default.

57 changes: 36 additions & 21 deletions data/retweet_dataset.py
@@ -4,20 +4,20 @@
import os, math, tqdm, time
import re, jieba
from datetime import datetime
from collections import defaultdict

class Dataset(object):
def __init__(self, data, staff_df):
self.data = self.sort_by_time(data)
self.staff_df = staff_df # supplementary data source for building social relations
self.user = {}
self.examiner = defaultdict(dict)
self.title = {}
self.id2user = dict()
self.id2title = dict()
self.interaction_set = set()
self.initialize(self.data)
# self.flatten_record(self.data, self.records_save_path)
# self.dataset_split()
# self.get_interaction_dataset(self.interaction_save_path)


def initialize(self, data):
print('Initializing data......')
@@ -46,14 +46,29 @@ def initialize(self, data):

# extract the ids and names of the next-step workers
if isinstance(next_step_worker, str):
worker_list = next_step_worker.split(',')
temp = next_step_worker.split(',')
self.data.at[idx, '下一环节处理人'] = temp
worker_list = self.data.loc[idx, '下一环节处理人']
elif math.isnan(next_step_worker) or next_step_worker == '':
worker_list = ['null_token']
self.data.at[idx, '下一环节处理人'] = ["null_token"]
worker_list = self.data.loc[idx, '下一环节处理人']
for man in worker_list:
if man.strip() not in self.user:
self.user[man] = len(self.user)
self.id2user[self.user[man]] = man

for idx in tqdm.tqdm(range(self.data.shape[0])):
examiner = self.data.loc[idx, ][3]
# for each examiner, record the workers they have previously dispatched to: {examiner1: {worker1: 1, worker2: 1}, examiner2: {worker3: 1}}
if examiner not in self.examiner:
related_workers = self.data[self.data['审批人UID'] == examiner]['下一环节处理人'].to_list()
worker_list = set()
for rw in related_workers:
worker_list = worker_list.union(set(rw))

for man in worker_list:
self.examiner[self.user[examiner]][self.user[man]] = 1

def sort_by_time(self, records_data_df):
# !! sort by approval time: earlier records go to the training and validation sets, the most recent records to the test set
time_col = records_data_df["审批时间"].to_list()
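
For reference, a minimal standalone sketch (not part of this commit; toy names only) of the mapping the revised initialize() builds: each comma-separated '下一环节处理人' cell is split into a list, every name is assigned an integer id, and self.examiner maps an examiner's id to the ids of all workers that examiner has ever dispatched documents to.

import pandas as pd
from collections import defaultdict

toy = pd.DataFrame({
    '审批人UID': ['examiner1', 'examiner1', 'examiner2'],
    '下一环节处理人': ['workerA,workerB', 'workerB', 'workerC'],
})

user = {}                     # name -> integer id (analogous to self.user)
examiner = defaultdict(dict)  # examiner id -> {worker id: 1} (analogous to self.examiner)

for idx in range(toy.shape[0]):
    names = [n.strip() for n in toy.loc[idx, '下一环节处理人'].split(',')]
    for name in [toy.loc[idx, '审批人UID']] + names:
        user.setdefault(name, len(user))   # register examiner and workers; the real code registers them separately
    ex_id = user[toy.loc[idx, '审批人UID']]
    for name in names:
        examiner[ex_id][user[name]] = 1

print(dict(examiner))  # {0: {1: 1, 2: 1}, 3: {4: 1}}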
@@ -67,26 +82,26 @@ def flatten_record(self, records_save_path):
def flatten_record(self, records_save_path):
'''
Convert the processing records so that each examiner maps to a single next-step worker, 1 to N ---> 1 to 1
Write retweet_records.txt, one line per record: (user id, document id, document title, rating)
Write retweet_records.txt, one line per record: (examiner id, document id, document title, worker id, rating)
:param records_save_path
:return the processed dataframe
:return the processing-records txt file
'''
print(self.data.info())
print('Preprocessing records.....')
self.records_save_path = records_save_path

with open(self.records_save_path, 'w', encoding='utf-8') as fp1:
with open(self.records_save_path, 'w', encoding='utf-8') as fp:
for idx in tqdm.tqdm(range(self.data.shape[0])):
next_step_worker, one_title, examiner = self.data.loc[idx, ][0], self.data.loc[idx, ][1], self.data.loc[idx, ][3]
fp1.write(str(self.user[examiner]) + ' ' + str(self.title[one_title]) + ' ' + one_title + ' 1\n') # (user id, document id, document title, rating)
next_step_worker, one_title, examiner = self.data.loc[idx, ][0], self.data.loc[idx, ][1], \
self.data.loc[idx, ][3]

# find all workers the examiner has previously dispatched documents to; workers in this record's worker list get label 1, the rest get label 0
for worker in self.examiner[self.user[examiner]].keys():
if self.id2user[worker] in next_step_worker: # at this point the worker list has already been converted to a list in initialize()
fp.write(str(self.user[examiner]) + ' ' + str(self.title[one_title]) + ' ' + one_title + ' '+ str(worker) + ' 1\n') # (examiner id, document id, document title, worker id, rating)
else:
fp.write(str(self.user[examiner]) + ' ' + str(self.title[one_title]) + ' ' + one_title + ' '+ str(worker) + ' 0\n') # (examiner id, document id, document title, worker id, rating)

# extract the ids and names of the next-step workers
if isinstance(next_step_worker, str):
worker_list = next_step_worker.split(',')
elif math.isnan(next_step_worker) or next_step_worker == '':
worker_list = ['null_token']
for man in worker_list:
fp1.write(str(self.user[man]) + ' ' + str(self.title[one_title]) + ' ' + one_title + ' 1\n')

def get_interaction_dataset(self, interaction_save_path, records_save_path, train_frac = 0.8):
# return: each line contains (user1, user2), indicating that the two users are friends
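
As a hedged illustration (toy ids and a placeholder title, not the commit's data) of the line format the revised flatten_record() writes: every worker the examiner has previously dispatched to becomes one candidate line, labeled 1 if that worker appears in the record's actual next-step worker list and 0 otherwise.

examiner_id, title_id, title = 0, 7, 'doc_title'
candidates = {1: 'workerA', 2: 'workerB'}   # worker id -> name, drawn from self.examiner / self.id2user in the real code
actual_next_step = ['workerA']              # this record's real next-step workers

for worker_id, name in candidates.items():
    label = 1 if name in actual_next_step else 0
    print(f'{examiner_id} {title_id} {title} {worker_id} {label}')
# 0 7 doc_title 1 1
# 0 7 doc_title 2 0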
@@ -240,23 +255,23 @@ def main():
# load the raw spreadsheets
data = pd.read_excel('../dataset/retweet_prediction/retweet_records.xlsx')
staff = pd.read_excel('../dataset/retweet_prediction/staff.xlsx')
records_save_path = '../dataset/retweet_prediction/retweet_records.txt'
records_save_path = '../dataset/retweet_prediction/retweets.txt'
interaction_save_path = '../dataset/retweet_prediction/trust.txt'
train_save_path = "../dataset/retweet_prediction/train.txt"
ds = Dataset(data, staff)

# convert one-to-many records into one-to-one
flattened = True
if not flattened:
ds.flatten_record(records_save_path, interaction_save_path )
ds.flatten_record(records_save_path)

# split the dataset
split = True
split = False
if flattened and not split:
ds.dataset_split(records_save_path)

# build the relation network
get_interaction = False
get_interaction = True
if not get_interaction and split and flattened:
ds.get_interaction_dataset(interaction_save_path, train_save_path)
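
A hypothetical reader (not part of the commit) for the five-field record lines written above; because a document title can itself contain spaces, this sketch takes the two leading and two trailing whitespace-separated fields and joins whatever remains as the title.

def parse_record_line(line):
    parts = line.rstrip('\n').split(' ')
    examiner_id, title_id = int(parts[0]), int(parts[1])
    worker_id, rating = int(parts[-2]), int(parts[-1])
    title = ' '.join(parts[2:-2])
    return examiner_id, title_id, title, worker_id, rating

print(parse_record_line('0 7 doc_title 2 0'))  # (0, 7, 'doc_title', 2, 0)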

