Skip to content

Commit

Permalink
switch for best & normal CF
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed Jun 2, 2017
1 parent ef81619 commit 3d6c92a
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions hw6/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,38 @@
from keras import backend as K
from keras.regularizers import l2

def build_cf_model(n_users, n_movies, dim):
def build_cf_model(n_users, n_movies, dim, isBest=False):
u_input = Input(shape=(1,))
u = Embedding(n_users, dim, embeddings_regularizer=l2(1e-5))(u_input)
if isBest:
u = Embedding(n_users, dim, embeddings_regularizer=l2(1e-5))(u_input)
else:
u = Embedding(n_users, dim)(u_input)
u = Reshape((dim,))(u)
u = Dropout(0.1)(u)

m_input = Input(shape=(1,))
m = Embedding(n_movies, dim, embeddings_regularizer=l2(1e-5))(m_input)
if isBest:
m = Embedding(n_movies, dim, embeddings_regularizer=l2(1e-5))(m_input)
else:
m = Embedding(n_movies, dim)(m_input)
m = Reshape((dim,))(m)
m = Dropout(0.1)(m)

u_bias = Embedding(n_users, 1, embeddings_regularizer=l2(1e-5))(u_input)
if isBest:
u_bias = Embedding(n_users, 1, embeddings_regularizer=l2(1e-5))(u_input)
else:
u_bias = Embedding(n_users, 1)(u_input)
u_bias = Reshape((1,))(u_bias)
m_bias = Embedding(n_movies, 1, embeddings_regularizer=l2(1e-5))(m_input)
if isBest:
m_bias = Embedding(n_movies, 1, embeddings_regularizer=l2(1e-5))(m_input)
else:
m_bias = Embedding(n_movies, 1)(m_input)
m_bias = Reshape((1,))(m_bias)

out = dot([u, m], -1)
out = add([out, u_bias, m_bias])
out = Lambda(lambda x: x + K.constant(3.581712))(out)
if isBest:
out = Lambda(lambda x: x + K.constant(3.581712))(out)

model = Model(inputs=[u_input, m_input], outputs=out)
return model
Expand Down

0 comments on commit 3d6c92a

Please sign in to comment.