Skip to content

Commit

Permalink
Fix png plots
Browse files Browse the repository at this point in the history
  • Loading branch information
phiyodr committed Dec 21, 2022
1 parent cb8d9fc commit 5bea95c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from multilabel_oversampling import multilabel_oversampling as mo

df = mo.create_fake_data(size=1, seed=3)
ml_oversampler = mo.MultilabelOversampler(number_of_adds=100, number_of_tries=100)
df_new = ml_oversampler.fit(df)
df_new, plot_at = ml_oversampler.fit(df)
#> Iteration: 20%|██████ | 20/100 [00:00<00:00, 111.68it/s]
#> No improvement after 100 tries in iter 20.
```
Expand Down
Binary file modified assets/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/plot_results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 7 additions & 3 deletions multilabel_oversampling/multilabel_oversampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
self.res_std = res_std
self.res_bad = res_bad
if (len(res_std) > 0) and self.plot:
self.plot_all_tries(self.res_std, self.res_bad)
plot_at = self.plot_all_tries(self.res_std, self.res_bad)
plt.title("All tries per iteration with \n corresponding standard deviation")
plt.show()
return df_new, plot_at
return df_new

def reset(self):
Expand All @@ -122,6 +124,7 @@ def plot_all_tries(res_std, res_bad):
plt.scatter(i + idx*0.01, s)
plt.xlabel('Iters')#, fontsize=18)
plt.ylabel('Std')#, fontsize=16)
return plt

def plot_results(self):
plt.subplot(2,2,1)
Expand All @@ -132,7 +135,8 @@ def plot_results(self):
self.plot_index_counts(self.df_new)
plt.tight_layout()
plt.show()

return plt

def plot_distr(self, df, when):
df[self.target_list].sum().plot.bar()
plt.title(f"Label distribution \n{when} upsampling")
Expand All @@ -151,6 +155,6 @@ def plot_index_counts(self, df_new):
if __name__ == '__main__':
df = create_fake_data(size=1, seed=3)
print(df)
mlo = MultilabelOversampling(number_of_adds=100)
mlo = MultilabelOversampler(number_of_adds=100)
df_new = mlo.fit(df)
mlo.plot_results()

0 comments on commit 5bea95c

Please sign in to comment.