Skip to content

Commit

Permalink
feat: add seed
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Aug 1, 2022
1 parent 5a8af63 commit 2694017
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ def sample_panel(
how: str = "spaced",
reset_ids: bool = False,
inplace: bool = False,
seed: int = 42,
) -> Optional[Panel]:
"""
Sample panel returning a subset of frames.
Expand All @@ -582,12 +583,15 @@ def sample_panel(
samples (int or float): Number or percentage of samples to return.
how (str): Sampling method, 'spaced' or 'random'
reset_ids (bool): If True, reset the index of the sampled panel.
inplace (bool): If True, perform operation in-place.
seed (int): Random seed.
Returns:
``Panel``: Result of sample function.
"""

# TODO fix if no set train split
# TODO add seed

train_samples, val_samples, test_samples = _validate_sample_panel(
samples=samples,
Expand All @@ -596,6 +600,9 @@ def sample_panel(
test_size=self.test_size,
)

# Set seed
np.random.seed(seed)

if how == "random":
train_ids = sorted(
np.random.choice(self.train.ids, train_samples, replace=False)
Expand Down Expand Up @@ -697,6 +704,8 @@ def plot(
Args:
add_annotation (bool): If True, plot the training, validation, and test annotation.
max (int): Maximum number of samples to plot.
use_timestep (bool): If True, plot the timestep instead of the sample index.
**kwargs: Additional arguments to pass to the plot function.
Returns:
Expand Down

0 comments on commit 2694017

Please sign in to comment.