Skip to content

Commit

Permalink
split M-step out. verify labels
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaweiZhuang committed Apr 16, 2017
1 parent a9cbc2f commit dc865c5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
23 changes: 14 additions & 9 deletions Parallel_Algorithm/OpenMP/Kmean_seq.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,19 @@ int main() {
labels[i] = k_best;
dist_sum_new += dist_min;

// M-Step (half): set the cluster centers to the mean
cluster_sizes[k_best]++; // add one more points to this cluster
// As the total number of samples in each cluster is not known yet,
// here we are just calculating the sum, not the mean.
for (j=0; j<N_features; j++)
new_cluster_centers[k_best][j] += X[i][j];
}

} //end if E-Step and half M-Step
// M-Step first half: set the cluster centers to the mean
for (i = 0; i < N_samples; i++) {
k_best = labels[i];
cluster_sizes[k_best]++; // add one more points to this cluster
// As the total number of samples in each cluster is not known yet,
// here we are just calculating the sum, not the mean.
for (j=0; j<N_features; j++)
new_cluster_centers[k_best][j] += X[i][j];
}

// M-Step-continued: convert the sum to the mean
// M-Step second half: convert the sum to the mean
for (k=0; k<N_clusters; k++) {
for (j=0; j<N_features; j++) {

Expand All @@ -229,10 +232,12 @@ int main() {
//printf("Final inertia: %f, iteration: %d \n",dist_sum_new,i_iter);

// record the best results
if (dist_sum_new < inert_best)
if (dist_sum_new < inert_best) {
inert_best = dist_sum_new;
for (i = 0; i < N_samples; i++)
labels_best[i] = labels[i];
}

} //end of one repeated run
double iElaps2 = seconds() - iStart2;

Expand Down
10 changes: 10 additions & 0 deletions Parallel_Algorithm/python_reference/check_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import xarray as xr

# Compare two label variables, "Y_Py" and "Y_C", stored in the same NetCDF
# test file and report how many entries differ.
# NOTE(review): presumably "Y_Py" holds the labels from the Python reference
# and "Y_C" the labels written by the C implementation -- confirm against the
# code that writes this file.
# NOTE(review): an element-wise comparison only makes sense if both
# implementations number the clusters identically (e.g. they start from the
# same initial centers) -- confirm.
dirname = "../test_data/"
filename = "Blobs_smp20000_fea30_cls8.nc"

with xr.open_dataset(dirname+filename) as ds:
    # True wherever the two label arrays disagree for a sample.
    mismatch = (ds["Y_Py"].values != ds["Y_C"].values)

# `.values` yields plain arrays, so `mismatch` is still usable after the
# dataset has been closed by the `with` block.
print("total number of samples: ",mismatch.size)
print("inconsistent labels: ",mismatch.sum())

0 comments on commit dc865c5

Please sign in to comment.