Skip to content

Commit

Permalink
dropout=0.1; default dimension=15; epochs=1000
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed Jun 1, 2017
1 parent 5000ed0 commit ef81619
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions hw6/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ def build_cf_model(n_users, n_movies, dim):
u_input = Input(shape=(1,))
u = Embedding(n_users, dim, embeddings_regularizer=l2(1e-5))(u_input)
u = Reshape((dim,))(u)
u = Dropout(0.1)(u)

m_input = Input(shape=(1,))
m = Embedding(n_movies, dim, embeddings_regularizer=l2(1e-5))(m_input)
m = Reshape((dim,))(m)
m = Dropout(0.1)(m)

u_bias = Embedding(n_users, 1, embeddings_regularizer=l2(1e-5))(u_input)
u_bias = Reshape((1,))(u_bias)
Expand Down
4 changes: 2 additions & 2 deletions hw6/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='HW6: Matrix Factorization')
parser.add_argument('train', type=str)
parser.add_argument('test', type=str)
parser.add_argument('--dim', type=int, default=120)
parser.add_argument('--dim', type=int, default=15)
return parser.parse_args()


Expand Down Expand Up @@ -49,7 +49,7 @@ def main(args):

callbacks = [EarlyStopping('val_rmse', patience=2),
ModelCheckpoint(MODEL_WEIGHTS_FILE, save_best_only=True)]
history = model.fit([Users, Movies], Ratings, epochs=100, batch_size=256, validation_split=.1, verbose=1, callbacks=callbacks)
history = model.fit([Users, Movies], Ratings, epochs=1000, batch_size=256, validation_split=.1, verbose=1, callbacks=callbacks)


if __name__ == '__main__':
Expand Down

0 comments on commit ef81619

Please sign in to comment.