Skip to content

Commit

Permalink
remove random color given
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed Jun 7, 2017
1 parent a160668 commit d22171c
Showing 1 changed file with 4 additions and 24 deletions.
28 changes: 4 additions & 24 deletions hw6/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,17 @@
classes = ["Adventure|Western|Comedy", "Thriller|Horror|Mystery", "Crime|Film-Noir", "Sci-Fi|Fantasy", "Drama|Musical", "War|Documentary", "Children's|Animation", "Action|Romance"]

def parse_args():
parser = argparse.ArgumentParser(description='HW6: Matrix Factorization')
parser = argparse.ArgumentParser(description='HW6: drawing graph')
parser.add_argument('data_dir', type=str)
parser.add_argument('output', type=str)
return parser.parse_args()

def draw(mapping, filename):
print('Drawing...')
fig = plt.figure(figsize=(10, 10), dpi=200)
length = len(mapping.keys())
# X, Y, C = [], [], []
cm = plt.cm.get_cmap("jet", length)
for i, key in enumerate(mapping.keys()):
vis_x = mapping[key][:, 0]
vis_y = mapping[key][:, 1]
# color = [i] * vis_x.shape[0]
# X.append(vis_x)
# Y.append(vis_y)
# C.append(color)
np.random.seed(i*5)
plt.scatter(vis_x, vis_y, c=list(np.random.rand(3,)), marker='.', label=key)
# plt.scatter(X, Y, c=C, cmap=cm, marker='.', label=key)
plt.scatter(vis_x, vis_y, marker='.', label=key)
plt.xticks([])
plt.yticks([])
plt.legend(scatterpoints=1,
Expand All @@ -38,10 +29,7 @@ def draw(mapping, filename):
plt.tight_layout()
# plt.show()
fig.savefig(filename)

def predict_rating(trained_model, userid, movieid):
return rate(trained_model, userid - 1, movieid - 1)

print('Done drawing!')

def ensure_dir(file_path):
directory = os.path.dirname(file_path)
Expand Down Expand Up @@ -90,14 +78,6 @@ def main(args):
new_genres_map[c] = np.concatenate((new_genres_map[c], genres_map[g]), axis=0)
# print(new_genres_map[c].shape)
draw(new_genres_map, 'graph.png')
sys.exit(-1)

recommendations = pd.read_csv(TEST_CSV, usecols=['TestDataID'])
recommendations['Rating'] = test_data.apply(lambda x: predict_rating(trained_model, x['UserID'], x['MovieID']), axis=1)
# print(recommendations)

ensure_dir(args.output)
recommendations.to_csv(args.output, index=False, columns=['TestDataID', 'Rating'])


if __name__ == '__main__':
Expand Down

0 comments on commit d22171c

Please sign in to comment.