Skip to content

Commit

Permalink
SocialBaseMF
Browse files Browse the repository at this point in the history
  • Loading branch information
ljy0ustc committed Jan 19, 2022
1 parent 20d3ae1 commit 5428bd9
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 0 deletions.
132 changes: 132 additions & 0 deletions exp3/src/SocialBaseMF/MF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import numpy as np
import cmath

class Matrix_Factorization(object):
    """Factorize a rating matrix R (M x N) as P (M x K) times Q^T (K x N).

    Only the non-zero entries of R are treated as observed ratings. P and Q
    are trained with per-sample stochastic gradient descent on the squared
    error, optionally with an L2 penalty weighted by ``beta``.

    Usage: call ``fit(R)`` to initialise the factors, then ``start()`` to
    train and obtain ``(R_hat, P, Q)``.
    """

    def __init__(self, K=10, alpha=0.1, beta=0.02, epoch=10, regularization=True, random_state=100):
        """Store the hyper-parameters; no data is touched until ``fit``.

        K: number of latent factors.
        alpha: initial learning rate (decayed after every epoch in ``start``).
        beta: weight of the regularization term in the gradient.
        epoch: maximum number of passes over the observed ratings.
        regularization: whether the ``beta`` term is added to the gradient.
        random_state: seed passed to ``np.random.seed`` in ``fit``.
        """
        self.R = None
        self.K = K
        self.P = None
        self.Q = None
        self.r_index = None
        self.r = None
        self.length = None
        self.init_alpha = alpha
        self.alpha = alpha
        self.beta = beta
        self.epoch = epoch
        self.regularization = regularization
        self.random_state = random_state

    def fit(self, R):
        """Attach the rating matrix R and randomly initialise P and Q."""
        np.random.seed(self.random_state)
        self.R = R
        M, N = self.R.shape
        # P is a random M x K matrix; Q is stored as N x K (i.e. already
        # transposed), so a prediction is the dot product of one row of each.
        self.P = np.random.rand(M, self.K)
        self.Q = np.random.rand(N, self.K)

        # r_index holds the (row, column) indices of the non-zero entries of R.
        self.r_index = self.R.nonzero()
        # r holds the values of those non-zero entries, in the same order.
        self.r = self.R[self.r_index[0], self.r_index[1]]
        self.length = len(self.r)

    def _comp_descent(self, index):
        """Compute the gradient for the ``index``-th observed rating.

        Returns ``(r_i, r_j, p_i, q_j, descent_p_i, descent_q_j)`` where
        ``r_i``/``r_j`` are the row/column of the rating, ``p_i``/``q_j`` the
        current factor rows and ``descent_*`` the gradients with respect to
        them.
        """
        r_i = self.r_index[0][index]
        r_j = self.r_index[1][index]

        p_i = self.P[r_i]
        q_j = self.Q[r_j]

        r_ij_hat = p_i.dot(q_j)
        e_ij = self.R[r_i, r_j] - r_ij_hat

        # Gradient of the squared error e_ij ** 2 with respect to p_i and q_j.
        descent_p_i = -2 * e_ij * q_j
        descent_q_j = -2 * e_ij * p_i

        if self.regularization:
            # Optional regularization term, weighted by beta.
            descent_p_i = descent_p_i + self.beta * p_i
            descent_q_j = descent_q_j + self.beta * q_j

        return r_i, r_j, p_i, q_j, descent_p_i, descent_q_j

    def _update(self, p_i, q_j, descent_p_i, descent_q_j):
        """Take one gradient step with the current learning rate.

        Returns new arrays; the inputs are not modified.
        """
        p_i_new = p_i - self.alpha * descent_p_i
        q_j_new = q_j - self.alpha * descent_q_j

        return p_i_new, q_j_new

    def _estimate_r_hat(self):
        """Return the predicted values for the observed entries of R only."""
        rows, cols = self.r_index[0], self.r_index[1]
        # Row-wise dot products over the observed pairs. This avoids building
        # the full dense M x N product just to index a few entries out of it.
        return np.sum(self.P[rows] * self.Q[cols], axis=1)

    def start(self):
        """Train P and Q with SGD and return ``(R_hat, P, Q)``.

        Runs at most ``self.epoch`` epochs and stops early once the summed
        squared error over the observed entries drops below 1. The error is
        printed after every epoch.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if self.P is None or self.Q is None or self.length is None:
            raise RuntimeError("fit() must be called before start()")

        epoch_num = 1
        while epoch_num <= self.epoch:
            for index in range(self.length):
                r_i, r_j, p_i, q_j, descent_p_i, descent_q_j = self._comp_descent(index)
                p_i_new, q_j_new = self._update(p_i, q_j, descent_p_i, descent_q_j)

                self.P[r_i] = p_i_new
                self.Q[r_j] = q_j_new

            # Exponential learning-rate decay per epoch, floored at 0.01.
            self.alpha = max(0.01, self.init_alpha * pow(0.85, epoch_num))

            e = self._estimate_r_hat() - self.r
            error = e.dot(e)
            print ('The error is %s=================Epoch:%s' %(error, epoch_num))
            epoch_num += 1
            if error < 1:
                break

        R_hat = self.P.dot(self.Q.T)
        return R_hat, self.P, self.Q

if __name__ == '__main__':
    # Script entry point: factorize the saved rating matrix, persist the
    # factors, and write a ranked item list per user to a text file.
    import os  # local import: only this entry point needs it

    # Hard-coded dataset dimensions.
    # NOTE(review): presumably these match the shape of Rating.npy - confirm.
    user_total = 23599
    item_total = 21602
    top_n = 100  # number of items written per user
    sep = '\t'
    comma = ','
    write_data_path = ".//result//"

    Rating = np.load('Rating.npy')

    aa = Matrix_Factorization(K=5)
    aa.fit(Rating)
    R_hat, P, Q = aa.start()

    np.save('R_hat', R_hat)
    np.save('P', P)
    np.save('Q', Q)

    # Per-row item indices sorted by ascending predicted score.
    RatingRank = np.argsort(R_hat)

    # Create the output directory up front so a missing folder does not
    # discard the result of the whole training run.
    os.makedirs(write_data_path, exist_ok=True)
    with open(write_data_path + "res.txt", "w") as f:
        for userID in range(user_total):
            # The last top_n columns of the ascending sort are the top_n
            # highest-scored items, written lowest-scored first.
            # NOTE(review): confirm the consumer expects best-last ordering.
            top_items = RatingRank[userID, item_total - top_n:item_total]
            f.write(str(userID) + sep + comma.join(map(str, top_items)) + '\n')
45 changes: 45 additions & 0 deletions exp3/src/SocialBaseMF/Social.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
if __name__ == '__main__':
    # Script entry point: build a binary user x item interaction matrix from
    # the rating file, extend it with items the followed users have, and
    # save it as Rating.npy.

    user_total = 23599
    item_total = 21602

    read_data_path = ".//data//"

    # One follow list per user id. Pre-sized so that indexing by user id does
    # not depend on the order or completeness of the rating file.
    followMap = [[] for _ in range(user_total)]
    Rating = np.zeros((user_total, item_total), dtype="int32")

    # Each line: user id, then whitespace-separated "musicID,rating" pairs.
    with open(read_data_path + "DoubanMusic.txt", "r") as f:
        for line in f:
            temp = line.split()
            if not temp:
                continue  # skip blank lines (e.g. a trailing newline)
            UserID = int(temp[0])
            for pair in temp[1:]:
                MusicID = int(pair.split(',')[0])
                # Only the presence of a rating is recorded; its value is
                # not used.
                Rating[UserID, MusicID] = 1

    # Each line: user id followed by the id of one user they follow.
    with open(read_data_path + "DoubanSocial.txt", "r") as f:
        for line in f:
            temp = line.split()
            if not temp:
                continue
            UserID = int(temp[0])
            followID = int(temp[1])
            followMap[UserID].append(followID)

    # For every user, mark each still-empty item that at least one followed
    # user has. Done one whole row at a time instead of item by item.
    # NOTE(review): rows are updated in place in increasing user-id order, so
    # a followed user with a smaller id contributes the items they inherited
    # as well. This matches the original element-wise loop - confirm that
    # this transitive effect is intended.
    for UserID, followIDs in enumerate(followMap):
        if not followIDs:
            continue
        followed_has = (Rating[followIDs] > 0).any(axis=0)
        Rating[UserID, followed_has & (Rating[UserID] == 0)] = 1

    np.save('Rating', Rating)

0 comments on commit 5428bd9

Please sign in to comment.