Skip to content

Commit

Permalink
refac: improve sample_panel
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Jul 31, 2022
1 parent b595d93 commit 83322d5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 14 deletions.
74 changes: 60 additions & 14 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import random
import warnings
from itertools import chain
Expand All @@ -7,7 +8,7 @@
import pandas as pd

from wavy.plot import plot
from wavy.validations import _validate_training_split
from wavy.validations import _validate_sample_panel, _validate_training_split


def create_panels(df, lookback: int, horizon: int, gap: int = 0):
Expand Down Expand Up @@ -174,12 +175,10 @@ def _copy_attrs(self, df):
def _constructor(self):
def f(*args, **kw):

try:
with contextlib.suppress(Exception):
index = [a for a in args[0].axes if isinstance(a, pd.MultiIndex)]
if index and len(index[0]) == self.num_timesteps:
return pd.DataFrame(*args, **kw)
except:
pass

df = Panel(*args, **kw)

Expand Down Expand Up @@ -269,7 +268,9 @@ def row_panel(self, n: int = 0):
if n < -1 or n >= self.num_timesteps:
raise ValueError("n must be -1 or between 0 and the number of timesteps")

return self.groupby(level=0, as_index=False).nth(n)
new_panel = self.groupby(level=0, as_index=False).nth(n)
self._copy_attrs(new_panel)
return new_panel

def get_timesteps(self, n: Union[list, int] = 0):
"""
Expand Down Expand Up @@ -549,28 +550,73 @@ def sort_panel(
key=key,
)

def sample_panel(self, samples: int = 5, how: str = "spaced"):
def sample_panel(
self,
samples: Union[int, float] = 5,
how: str = "spaced",
reset_ids: bool = False,
):
"""
Sample panel returning a subset of frames.
Args:
samples (int): Number of samples to keep
samples (int or float): Number or percentage of samples to return.
how (str): Sampling method, 'spaced' or 'random'
reset_ids (bool): If True, reset the index of the sampled panel.
Returns:
``Panel``: Result of sample function.
"""

train_samples, val_samples, test_samples = _validate_sample_panel(
samples=samples,
train_size=self.train_size,
val_size=self.val_size,
test_size=self.test_size,
)

if how == "random":
warnings.warn("Random sampling can result in data leakage.")
indexes = np.random.choice(self.ids, samples, replace=False)
indexes = sorted(indexes)
return self.get_frame_by_ids(indexes)
train_ids = sorted(
np.random.choice(self.train.ids, train_samples, replace=False)
)
val_ids = sorted(np.random.choice(self.val.ids, val_samples, replace=False))
test_ids = sorted(
np.random.choice(self.test.ids, test_samples, replace=False)
)

elif how == "spaced":
indexes = np.linspace(
self.ids[0], self.ids[-1], samples, dtype=int, endpoint=False
train_ids = np.linspace(
self.train.ids[0],
self.train.ids[-1],
train_samples,
dtype=int,
endpoint=True,
)
return self.get_frame_by_ids(indexes)
val_ids = np.linspace(
self.val.ids[0],
self.val.ids[-1],
val_samples,
dtype=int,
endpoint=True,
)
test_ids = np.linspace(
self.test.ids[0],
self.test.ids[-1],
test_samples,
dtype=int,
endpoint=True,
)

new_panel = concat_panels(
[self.loc[train_ids], self.loc[val_ids], self.loc[test_ids]],
reset_ids=reset_ids,
)

new_panel.train_size = train_samples
new_panel.val_size = val_samples
new_panel.test_size = test_samples

return new_panel

def shuffle_panel(self, seed: int = None):
"""
Expand Down
35 changes: 35 additions & 0 deletions src/wavy/validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,38 @@ def _validate_training_split(n_samples, train_size, val_size, test_size):
)

return int(n_train), int(n_val), int(n_test)


def _validate_sample_panel(samples, train_size, val_size, test_size):
"""
Validation helper to check if the samples size is meaningful wrt to the
size of the data (n_samples)
"""

n_samples = train_size + val_size + test_size
samples_type = np.asarray(samples).dtype.kind

# Check samples size
if (
samples_type == "i"
and (samples >= n_samples or samples <= 0)
or samples_type == "f"
and (samples <= 0 or samples >= 1)
):
raise ValueError(
"samples={0} should be either positive and smaller"
" than the number of samples {1} or a float in the "
"(0, 1) range".format(samples, n_samples)
)

if samples is not None and samples_type not in ("i", "f"):
raise ValueError(f"Invalid value for samples: {samples}")

if samples_type == "i":
samples = samples / n_samples

train_samples = round(samples * train_size)
val_samples = round(samples * val_size)
test_samples = round(samples * test_size)

return train_samples, val_samples, test_samples

0 comments on commit 83322d5

Please sign in to comment.