Skip to content

Commit

Permalink
Add seed
Browse files Browse the repository at this point in the history
  • Loading branch information
phiyodr committed Dec 21, 2022
1 parent c7f4985 commit 8d69a9a
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions multilabel_oversampling/multilabel_oversampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ def create_fake_data(size=1, seed=1):

class MultilabelOversampler:

def __init__(self, number_of_adds=1000, number_of_tries=100, tqdm_disable=False, details=False, plot=True):
def __init__(self, number_of_adds=1000, number_of_tries=100, seed=1, tqdm_disable=False, details=False, plot=True):
"""
Args:
number_of_adds: Maximum number of new rows to add to df. Total number of iterations.
number_of_tries: Maximum number of draws from df within total number of iterations.
seed: Seed for row sampling in `fit` function
tqdm_disable: Disable the progress bar for each iteration.
details: Enable detailed feedback for each try
plot: Plot all tries (iteration vs. std) after process is finished.
Expand All @@ -50,9 +51,11 @@ def __init__(self, number_of_adds=1000, number_of_tries=100, tqdm_disable=False,
else:
self.number_of_tries = 1e6

self.seed = seed
self.tqdm_disable = tqdm_disable
self.details = details
self.plot = plot



def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
Expand All @@ -76,7 +79,7 @@ def fit(self, df, target_list=["y1", "y2", "y3", "y4"]):
# Take random row and add to df_new
not_working = []
for try_ in tqdm(range(self.number_of_tries), desc=f"Iter {iter_}", disable=True):
random_row = df.sample(n = 1)
random_row = df.sample(n = 1, random_state=self.seed)
df_interim = pd.concat((df_new, random_row))
new_std = df_interim[self.target_list].sum().std()
# If std improves add row, otherwise add to not_working list
Expand Down

0 comments on commit 8d69a9a

Please sign in to comment.