diff --git a/hw6/hw6.sh b/hw6/hw6.sh new file mode 100644 index 0000000..8a0c39d --- /dev/null +++ b/hw6/hw6.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +# $1: the data directory including test.csv, users.csv, movies.csv +# $2: prediction filename + +python3 test.py $1 $2 \ No newline at end of file diff --git a/hw6/test.py b/hw6/test.py index d81a0b1..abb3a47 100644 --- a/hw6/test.py +++ b/hw6/test.py @@ -57,7 +57,7 @@ def main(args): TEST_CSV = 'test.csv' USERS_CSV = 'users.csv' MOVIES_CSV = 'movies.csv' - MODEL_WEIGHTS_FILE = 'weights.h5' + MODEL_WEIGHTS_FILE = 'weights_cf_bias.h5' DATA_DIR = args.data_dir TEST_CSV = os.path.join(DATA_DIR, TEST_CSV)