Skip to content

Commit

Permalink
Nicer plot
Browse files Browse the repository at this point in the history
  • Loading branch information
phiyodr committed Dec 21, 2022
1 parent 22803a6 commit c7f4985
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions multilabel_oversampling/multilabel_oversampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,13 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
print(f"No improvement after {self.number_of_tries} tries in iter {iter_}.")
break
res_bad.append(not_working)
#plt.plot(res_std)
#plt.show()
#df_new.sum().plot.bar()

self.df_new = df_new
self.res_std = res_std
self.res_bad = res_bad
if (len(res_std) > 0) and self.plot:
plot_at = self.plot_all_tries(self.res_std, self.res_bad)
plot_at = self.plot_all_tries()
plt.title("All tries per iteration with \n corresponding standard deviation")
plt.show()
return df_new, plot_at
return df_new

Expand All @@ -112,13 +109,13 @@ def reset(self):
self.res_std = None
self.res_bad = None

@staticmethod
def plot_all_tries(res_std, res_bad):
y_max = max([x[1] for x in res_bad[0]]) * 1.1
plt.plot(res_std)
plt.scatter(range(len(res_std)), res_std)
def plot_all_tries(self):
"""Plot for all iterations the returned std and the best value"""
y_max = max([x[1] for x in self.res_bad[0]]) * 1.1
plt.plot(self.res_std)
plt.scatter(range(len(self.res_std)), self.res_std)
plt.ylim(0, y_max)
for i, row_std in enumerate(res_bad):
for i, row_std in enumerate(self.res_bad):
for idx, (j, s) in enumerate(row_std):
#plt.text(i + idx*0.02, s, f"{j}", fontsize=8)
plt.scatter(i + idx*0.01, s)
Expand All @@ -127,6 +124,9 @@ def plot_all_tries(res_std, res_bad):
return plt

def plot_results(self):
"""Plot target distribution before and after upsampling.
Also plot the counts of each index-id.
"""
plt.subplot(2,2,1)
self.plot_distr(self.df, "before")
plt.subplot(2,2,2)
Expand All @@ -138,12 +138,15 @@ def plot_results(self):
return plt

def plot_distr(self, df, when):
"""Plot target distribtion"""
df[self.target_list].sum().plot.bar()
plt.title(f"Label distribution \n{when} upsampling")
return plt

def plot_index_counts(self, df_new):
"""TODO make better xticks alignment"""
"""Plot upsampling counts for each index.
TODO make better xticks alignment"""
idxs = list(df_new.index)
lens = len(set(idxs))
plt.hist(idxs, bins=lens, width=.1)#, edgecolor='k')
Expand Down

0 comments on commit c7f4985

Please sign in to comment.