diff --git a/multilabel_oversampling/multilabel_oversampling.py b/multilabel_oversampling/multilabel_oversampling.py index 5ff45cb..55b6aa6 100644 --- a/multilabel_oversampling/multilabel_oversampling.py +++ b/multilabel_oversampling/multilabel_oversampling.py @@ -92,16 +92,13 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]): print(f"No improvement after {self.number_of_tries} tries in iter {iter_}.") break res_bad.append(not_working) - #plt.plot(res_std) - #plt.show() - #df_new.sum().plot.bar() + self.df_new = df_new self.res_std = res_std self.res_bad = res_bad if (len(res_std) > 0) and self.plot: - plot_at = self.plot_all_tries(self.res_std, self.res_bad) + plot_at = self.plot_all_tries() plt.title("All tries per iteration with \n corresponding standard deviation") - plt.show() return df_new, plot_at return df_new @@ -112,13 +109,13 @@ def reset(self): self.res_std = None self.res_bad = None - @staticmethod - def plot_all_tries(res_std, res_bad): - y_max = max([x[1] for x in res_bad[0]]) * 1.1 - plt.plot(res_std) - plt.scatter(range(len(res_std)), res_std) + def plot_all_tries(self): + """Plot for all iterations the returned std and the best value""" + y_max = max([x[1] for x in self.res_bad[0]]) * 1.1 + plt.plot(self.res_std) + plt.scatter(range(len(self.res_std)), self.res_std) plt.ylim(0, y_max) - for i, row_std in enumerate(res_bad): + for i, row_std in enumerate(self.res_bad): for idx, (j, s) in enumerate(row_std): #plt.text(i + idx*0.02, s, f"{j}", fontsize=8) plt.scatter(i + idx*0.01, s) @@ -127,6 +124,9 @@ def plot_all_tries(res_std, res_bad): return plt def plot_results(self): + """Plot target distribution before and after upsampling. + Also plot the counts of each index-id. 
+ """ plt.subplot(2,2,1) self.plot_distr(self.df, "before") plt.subplot(2,2,2) @@ -138,12 +138,15 @@ def plot_results(self): return plt def plot_distr(self, df, when): + """Plot target distribution""" df[self.target_list].sum().plot.bar() plt.title(f"Label distribution \n{when} upsampling") return plt def plot_index_counts(self, df_new): - """TODO make better xticks alignment""" + """Plot upsampling counts for each index. + + TODO make better xticks alignment""" idxs = list(df_new.index) lens = len(set(idxs)) plt.hist(idxs, bins=lens, width=.1)#, edgecolor='k')