From 3d6c92ae420e9b653d3573bd98bbb8fb55f306f3 Mon Sep 17 00:00:00 2001 From: orbxball Date: Fri, 2 Jun 2017 08:50:56 +0800 Subject: [PATCH] switch for best & normal CF --- hw6/Model.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/hw6/Model.py b/hw6/Model.py index 5337aef..e208469 100644 --- a/hw6/Model.py +++ b/hw6/Model.py @@ -5,25 +5,38 @@ from keras import backend as K from keras.regularizers import l2 -def build_cf_model(n_users, n_movies, dim): +def build_cf_model(n_users, n_movies, dim, isBest=False): u_input = Input(shape=(1,)) - u = Embedding(n_users, dim, embeddings_regularizer=l2(1e-5))(u_input) + if isBest: + u = Embedding(n_users, dim, embeddings_regularizer=l2(1e-5))(u_input) + else: + u = Embedding(n_users, dim)(u_input) u = Reshape((dim,))(u) u = Dropout(0.1)(u) m_input = Input(shape=(1,)) - m = Embedding(n_movies, dim, embeddings_regularizer=l2(1e-5))(m_input) + if isBest: + m = Embedding(n_movies, dim, embeddings_regularizer=l2(1e-5))(m_input) + else: + m = Embedding(n_movies, dim)(m_input) m = Reshape((dim,))(m) m = Dropout(0.1)(m) - u_bias = Embedding(n_users, 1, embeddings_regularizer=l2(1e-5))(u_input) + if isBest: + u_bias = Embedding(n_users, 1, embeddings_regularizer=l2(1e-5))(u_input) + else: + u_bias = Embedding(n_users, 1)(u_input) u_bias = Reshape((1,))(u_bias) - m_bias = Embedding(n_movies, 1, embeddings_regularizer=l2(1e-5))(m_input) + if isBest: + m_bias = Embedding(n_movies, 1, embeddings_regularizer=l2(1e-5))(m_input) + else: + m_bias = Embedding(n_movies, 1)(m_input) m_bias = Reshape((1,))(m_bias) out = dot([u, m], -1) out = add([out, u_bias, m_bias]) - out = Lambda(lambda x: x + K.constant(3.581712))(out) + if isBest: + out = Lambda(lambda x: x + K.constant(3.581712))(out) model = Model(inputs=[u_input, m_input], outputs=out) return model