Skip to content

Commit

Permalink
Fix df.sample bug
Browse files Browse the repository at this point in the history
  • Loading branch information
phiyodr committed Dec 22, 2022
1 parent 8d69a9a commit 7a89da8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,9 @@ print(df_new)
pip install git+https://github.com/phiyodr/multilabel-oversampling
```

:sunflower:

## :construction_worker:

* [ ] Implement weighted sampling (so that samples that already appear often in the new df are sampled less often)

:sunflower:
29 changes: 19 additions & 10 deletions multilabel_oversampling/multilabel_oversampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def seed_everything(seed=1):
random.seed(seed)
np.random.seed(seed)

def create_fake_data(size=1, seed=1):
seed_everything(seed)
def create_fake_data(size=1):
#seed_everything(seed)
y1 = np.concatenate((np.ones(16*size), np.zeros(4*size))).astype(int)
y2 = np.concatenate((np.ones(12*size), np.zeros(8*size))).astype(int)
y3 = shuffle(np.concatenate((np.ones(4*size), np.zeros(16*size)))).astype(int)
Expand All @@ -30,7 +30,7 @@ def create_fake_data(size=1, seed=1):

class MultilabelOversampler:

def __init__(self, number_of_adds=1000, number_of_tries=100, seed=1, tqdm_disable=False, details=False, plot=True):
def __init__(self, number_of_adds=1000, number_of_tries=100, tqdm_disable=False, details=False, plot=False):
"""
Expand All @@ -51,7 +51,7 @@ def __init__(self, number_of_adds=1000, number_of_tries=100, seed=1, tqdm_disabl
else:
self.number_of_tries = 1e6

self.seed = seed
#self.seed = seed
self.tqdm_disable = tqdm_disable
self.details = details
self.plot = plot
Expand All @@ -72,14 +72,14 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
res_std = []
res_bad = []


print("Start the upsampling process.")
for iter_ in tqdm(range(self.number_of_adds),desc="Iteration", disable=self.tqdm_disable):
current_std = df_new[self.target_list].sum().std()

# Take random row and add to df_new
not_working = []
for try_ in tqdm(range(self.number_of_tries), desc=f"Iter {iter_}", disable=True):
random_row = df.sample(n = 1, random_state=self.seed)
random_row = df.sample(n = 1)
df_interim = pd.concat((df_new, random_row))
new_std = df_interim[self.target_list].sum().std()
# If std improves add row, otherwise add to not_working list
Expand All @@ -92,7 +92,11 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
else:
not_working.append((random_row.index[0], new_std))
if (try_+1) == self.number_of_tries:
print(f"No improvement after {self.number_of_tries} tries in iter {iter_}.")
print(f"Iter {iter_}: No improvement after {self.number_of_tries} tries.")
print(f"Sampling done.\n")
print(f"Dataset size original: {df.shape[0]}; Upsampled dataset size: {df_new.shape[0]}")
print(f"Original target distribution: {dict(zip(target_list, df[target_list].sum()))}")
print(f"Upsampled target distribution: {dict(zip(target_list, df_new[target_list].sum()))}")
break
res_bad.append(not_working)

Expand All @@ -114,7 +118,10 @@ def reset(self):

def plot_all_tries(self):
"""Plot for all iterations the returned std and the best value"""
y_max = max([x[1] for x in self.res_bad[0]]) * 1.1
try:
y_max = max([x[1] for x in self.res_bad[0]]) * 1.1
except:
y_max = self.res_std[0] * 1.025
plt.plot(self.res_std)
plt.scatter(range(len(self.res_std)), self.res_std)
plt.ylim(0, y_max)
Expand Down Expand Up @@ -159,8 +166,10 @@ def plot_index_counts(self, df_new):
return plt

if __name__ == '__main__':
df = create_fake_data(size=1, seed=3)
seed_everything(seed=42)
df = create_fake_data(size=1)
print(df)
mlo = MultilabelOversampler(number_of_adds=100)
mlo = MultilabelOversampler(number_of_adds=100, plot=True)
df_new = mlo.fit(df)
print(mlo.df_new)
mlo.plot_results()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name="multilabel-oversampling",
version="0.1.0",
version="0.1.1",
author="Philipp J. Rösch",
author_email="[email protected]",
description="Multilabel Oversampling",
Expand Down

0 comments on commit 7a89da8

Please sign in to comment.