Skip to content

Commit

Permalink
rm average in testing
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed May 12, 2017
1 parent dc36c61 commit c3b76b6
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions hw4/dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,15 @@ def validate(model, sample_size):

def test(model, test_data):
ans = []
test_round_size = 10
total_avg = 0.0
for i in test_data.keys():
print('Predicting {}...'.format(i))
for j in range(test_round_size):
sample_test_data = sampling(test_data[i].shape[0], test_data[i])
distances, indices = NN(test_data[i], sample_test_data)

avg_d = np.mean(distances[:,1])
total_avg += avg_d
print('Round: {}, avg_d: {}'.format(j, avg_d))
total_avg /= test_round_size
predicted = get_index(total_avg, model) + 1
sample_test_data = sampling(test_data[i].shape[0], test_data[i])
distances, indices = NN(test_data[i], sample_test_data)

avg_d = np.mean(distances[:,1])

predicted = get_index(avg_d, model) + 1

ans.append(np.log(predicted))
return ans
Expand Down

0 comments on commit c3b76b6

Please sign in to comment.