Skip to content

Commit

Permalink
Minor changes: Plotting, add __version__
Browse files Browse the repository at this point in the history
  • Loading branch information
phiyodr committed Dec 23, 2022
1 parent 7a89da8 commit 94249fc
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ pip install git+https://github.com/phiyodr/multilabel-oversampling
```


## :construction_worker:
## :construction_worker: Future work

* [] Implement weighted sampling (so that samples which are already often in the new df are less often sampled)
* [ ] Implement weighted sampling (so that samples which are already often in the new df are less often sampled)

:sunflower:
2 changes: 2 additions & 0 deletions multilabel_oversampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
__version__ = "0.1.2"

from .multilabel_oversampling import *
27 changes: 20 additions & 7 deletions multilabel_oversampling/multilabel_oversampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import os
import collections
import math


Expand Down Expand Up @@ -137,14 +137,13 @@ def plot_results(self):
"""Plot target distribution before and after upsampling.
Also plot the counts of each index-id.
"""
plt.subplot(2,2,1)
plt.subplot(1,3,1)
self.plot_distr(self.df, "before")
plt.subplot(2,2,2)
plt.subplot(1,3,2)
self.plot_distr(self.df_new, "after")
plt.subplot(2,2,(3,4)) # MatplotlibDeprecationWarning
plt.subplot(1,3,3) # MatplotlibDeprecationWarning
self.plot_index_counts(self.df_new)
plt.tight_layout()
plt.show()
return plt

def plot_distr(self, df, when):
Expand All @@ -153,17 +152,31 @@ def plot_distr(self, df, when):
plt.title(f"Label distribution \n{when} upsampling")
return plt

def plot_index_counts(self, df_new):
def plot_individual_index_counts(self, df_new):
"""Plot upsampling counts for each index.
TODO make better xticks alignment"""
if df_new == None:
df_new == self.df_new
idxs = list(df_new.index)
lens = len(set(idxs))
plt.hist(idxs, bins=lens, width=.1)#, edgecolor='k')
xint = range(min(idxs), math.ceil(max(idxs))+1)
plt.xticks(xint)
plt.title("Draws per index\n in new df")
return plt


def plot_index_counts(self, df_new=None):
if df_new is None:
df_new = self.df_new
x = list(collections.Counter(list(df_new.index)).values())
plt.hist(x, bins=max(x)+1, rwidth=.9)
plt.title("Frequency of indexes in df")
plt.xlabel('Frequency in dataset')
plt.ylabel('Counts')
return plt


if __name__ == '__main__':
seed_everything(seed=42)
Expand All @@ -172,4 +185,4 @@ def plot_index_counts(self, df_new):
mlo = MultilabelOversampler(number_of_adds=100, plot=True)
df_new = mlo.fit(df)
print(mlo.df_new)
mlo.plot_results()
mlo.plot_results()
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import setuptools

print(setuptools.find_packages())
from multilabel_oversampling import __version__
print(__version__)

with open("README.md", "r") as file:
long_description = file.read()
Expand All @@ -10,7 +10,7 @@

setuptools.setup(
name="multilabel-oversampling",
version="0.1.1",
version=__version__,
author="Philipp J. Rösch",
author_email="[email protected]",
description="Multilabel Oversampling",
Expand Down

0 comments on commit 94249fc

Please sign in to comment.