Skip to content

Commit

Permalink
add CustomDataFrameGroupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigosnader committed Apr 16, 2023
1 parent 15ef8dc commit 2585a52
Showing 1 changed file with 71 additions and 60 deletions.
131 changes: 71 additions & 60 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
warnings.simplefilter(action="ignore", category=FutureWarning)



def is_iterable(obj):
    """Report whether ``obj`` can be passed to the built-in ``iter()``.

    Args:
        obj: Any object.

    Returns:
        ``True`` if ``iter(obj)`` succeeds, ``False`` if it raises
        ``TypeError``. Note that strings count as iterable.
    """
    try:
        iter(obj)
    except TypeError:
        # ``iter`` signals a non-iterable argument with TypeError.
        return False
    else:
        return True


def make_xy(df, lookback, horizon, gap):
x_frames = [frame for frame in df.iloc[:-horizon].rolling(lookback) if len(frame) == lookback]

Expand All @@ -28,40 +37,49 @@ def make_xy(df, lookback, horizon, gap):
return x_frames, y_frames


def make_xy_by_timestamp(df, lookback, horizon, gap):
# def make_xy_by_timestamp(df, lookback, horizon):

# sourcery skip: identity-comprehension
# ! datetime gaps are not supported yet
x_frames = [frame for frame in df[df.index < (df.index[-1] - pd.Timedelta(horizon))].rolling(lookback)]
# # sourcery skip: identity-comprehension
# # ! datetime gaps are not supported yet
# x_frames = [frame for frame in df[df.index < (df.index[-1] - pd.Timedelta(horizon))].rolling(lookback)]

y_frames = [frame for frame in df[df.index > (df.index[0] + pd.Timedelta(lookback))].rolling(horizon)]
# y_frames = [frame for frame in df[df.index > (df.index[0] + pd.Timedelta(lookback))].rolling(horizon)]

x_frames = x_frames[:len(y_frames)]
y_frames = y_frames[:len(x_frames)]
# x_frames = x_frames[:len(y_frames)]
# y_frames = y_frames[:len(x_frames)]

print(len(x_frames), len(y_frames))
# print(len(x_frames), len(y_frames))

return x_frames, y_frames
# return x_frames, y_frames


def get_ids(x_frames, y_frames):
    """Generate one random UUID string per frame pair.

    Args:
        x_frames: Sequence of input frames.
        y_frames: Sequence of target frames; must have the same length as
            ``x_frames``.

    Returns:
        ``list[str]``: ``len(x_frames)`` freshly generated UUID4 strings.
    """
    num_frames = len(x_frames)
    # Both sequences are keyed with the same ids, so they must line up.
    assert num_frames == len(y_frames)

    ids = []
    for _ in range(num_frames):
        ids.append(str(uuid.uuid4()))
    return ids


def create_panels(df, lookback, horizon, gap):
if isinstance(lookback, str) and isinstance(horizon, str):
x_frames, y_frames = make_xy_by_timestamp(df, lookback, horizon, gap)
elif isinstance(lookback, int) and isinstance(horizon, int):
x_frames, y_frames = make_xy(df, lookback, horizon, gap)
def create_panels(df, lookback, horizon, gap=None):
# if isinstance(lookback, str) and isinstance(horizon, str):
# if gap:
# raise ValueError("Gap is not supported for datetime lookback and horizon.")
# x_frames, y_frames = make_xy_by_timestamp(df, lookback, horizon)
# elif isinstance(lookback, int) and isinstance(horizon, int):
gap = gap or 1
x_frames, y_frames = make_xy(df, lookback, horizon, gap)
ids = get_ids(x_frames, y_frames)

# ?
# timesteps_name = df.index.name or "timesteps"

return Panel(pd.concat(x_frames, keys=ids)), Panel(pd.concat(y_frames, keys=ids))
x_panel = pd.concat(x_frames, keys=ids)
y_panel = pd.concat(y_frames, keys=ids)

def dropna_match(x, y):
x_panel.index.names = ['ids', 'timesteps']
y_panel.index.names = ['ids', 'timesteps']

return Panel(x_panel), Panel(y_panel)

def match(x, y):
"""
Drop frames with NaN in panels and match ids.
Expand All @@ -74,10 +92,10 @@ def dropna_match(x, y):
"""

x_t = x.drop_empty_frames()
y_t = y.match_frames(x_t)
y_t = y.match(x_t)

y_t = y_t.drop_empty_frames()
x_t = x_t.match_frames(y_t)
x_t = x_t.match(y_t)

return x_t, y_t

Expand Down Expand Up @@ -130,6 +148,13 @@ def _constructor_expanddim(self):
def _constructor(self):
return _PanelSeries

class CustomDataFrameGroupBy(DataFrameGroupBy):
    """GroupBy whose ``[]`` operator selects rows of the grouped object.

    Integer keys are treated as positions into ``self.obj.ids`` and are
    translated to the corresponding ids; any other key is passed to
    ``self.obj.loc`` as a label.
    """

    # NOTE: integer keys are always interpreted as positions, so ids that are
    # themselves integers cannot be selected by label through this operator.
    def __getitem__(self, key):
        """Return the rows of ``self.obj`` selected by ``key``.

        Args:
            key: A single key or an iterable of keys. A string (or bytes) is
                treated as one key, not as an iterable of characters.

        Returns:
            The result of ``self.obj.loc[...]`` for the resolved keys.
        """
        # Strings are iterable, but a string key is a single label; without
        # this guard ``list(key)`` would split it into characters.
        if isinstance(key, (str, bytes)) or not is_iterable(key):
            key = [key]
        else:
            key = list(key)

        # ``np.integer`` covers positions produced by numpy (for example
        # ``np.linspace(..., dtype=int)``), which are not ``int`` instances.
        # The ``key and`` guard lets an empty selection fall through to
        # ``.loc`` instead of raising IndexError on ``key[0]``.
        if key and isinstance(key[0], (int, np.integer)):
            key = self.obj.ids[key]
        return self.obj.loc[key]

class Panel(pd.DataFrame):
"""
Expand All @@ -141,6 +166,7 @@ def __init__(self, *args, **kw):
if len(args) == 1 and isinstance(args[0], Panel):
args[0]._copy_attrs(self)


_attributes_ = "train_size,test_size,val_size"

def _copy_attrs(self, df):
Expand Down Expand Up @@ -187,12 +213,13 @@ def num_columns(self) -> int:
"""Returns the number of columns in the panel."""
return self.shape_[2]


@property
def frames(self) -> DataFrameGroupBy:
def frames(self) -> CustomDataFrameGroupBy:
"""
Returns panel's frames.
"""
return self.groupby(level=0, as_index=True)
return CustomDataFrameGroupBy(self, self.groupby(level=0, as_index=True).grouper)

@property
def timesteps(self) -> pd.Int64Index:
Expand Down Expand Up @@ -246,10 +273,10 @@ def shape_(self) -> tuple[int, int, int]:
"""
return (len(self.ids), int(self.shape[0] / len(self.ids)), self.shape[1])

def row_(self, n: list[int] | int = 0) -> Panel:
def nth(self, n: list[int] | int = 0) -> Panel:
# ? rename with get_nth_rows?
"""
Returns the nth row of each frame.
Returns the nth row of each of the panel's frames.
Args:
n (``list[int]`` or ``int``): Row index.
Expand All @@ -260,26 +287,10 @@ def row_(self, n: list[int] | int = 0) -> Panel:
if all(n < -1 or n >= self.num_timesteps for n in n):
raise ValueError("n must be -1 or between 0 and the number of timesteps")

new_panel = self.groupby(level=0, as_index=False).nth(n)
new_panel = self.frames.nth(n)
self._copy_attrs(new_panel)
return new_panel

def get_timesteps(self, n: list[int] | int = 0) -> Panel:
"""
Returns the first timestep of each frame in the panel.
Args:
n (``list[int]`` or ``int``): Timestep to return.
"""

if isinstance(n, int):
n = [n]

if self.index.nlevels == 1:
return self.frames.take(n).index.get_level_values(0)

return self.frames.take(n).index.get_level_values(1)

@property
def values_(self) -> np.ndarray:
"""
Expand All @@ -299,9 +310,9 @@ def values_(self) -> np.ndarray:
"""
return np.reshape(self.to_numpy(), self.shape_)

def flatten_(self) -> pd.DataFrame:
def smash(self) -> pd.DataFrame:
"""
Flatten the panel.
Returns a DataFrame with the panel's frames smashed into a single frame.
"""

new_timesteps = np.resize(
Expand All @@ -315,10 +326,10 @@ def flatten_(self) -> pd.DataFrame:
panel = (
self.set_index(new_index)
.reset_index()
.pivot(index="id", columns=self.index.names[1])
.pivot(index="ids", columns=self.index.names[1])
)

columns = [f"{col}-{index}" for col, index in panel.columns.to_flat_index()]
columns = [f"{col}_{index}" for col, index in panel.columns.to_flat_index()]

panel.columns = columns

Expand All @@ -343,10 +354,10 @@ def drop_ids(self, ids: list[int] | int, inplace: bool = False) -> Panel | None:

def find_empty_frames(self) -> pd.Int64Index:
"""
Find NaN values index.
Find frames with all missing values.
Returns:
``List``: List with index of NaN frames.
``List``: List with index of empty frames.
"""
na = self.isna().any(axis=1)
return (
Expand All @@ -367,7 +378,7 @@ def drop_empty_frames(self, inplace: bool = False) -> Panel | None:
"""
return self.drop_ids(self.find_empty_frames(), inplace=inplace)

def match_frames(self, other: Panel, inplace: bool = False) -> Panel | None:
def match(self, other: Panel, inplace: bool = False) -> Panel | None:
"""
Match panel with other panel.
Expand Down Expand Up @@ -519,7 +530,7 @@ def test(self, value: np.ndarray) -> None:
raise ValueError("No testing set was set.")
self[-self.test_size * self.num_timesteps :] = value.values

def head_(self, n: int = 5) -> Panel:
def head_(self, n: int = 2) -> Panel:
"""
Return the first n frames of the panel.
Expand All @@ -531,7 +542,7 @@ def head_(self, n: int = 5) -> Panel:
"""
return self[: n * self.shape_[1]]

def tail_(self, n: int = 5) -> Panel:
def tail_(self, n: int = 2) -> Panel:
"""
Return the last n frames of the panel.
Expand Down Expand Up @@ -627,34 +638,34 @@ def sample_(
elif how == "spaced":
if hasattr(self, "train_size"):
train_ids = np.linspace(
self.train.ids[0],
self.train.ids[-1],
0,
self.train.shape_[0],
train_samples,
dtype=int,
endpoint=True,
endpoint=False,
)
val_ids = np.linspace(
self.val.ids[0],
self.val.ids[-1],
0,
self.val.shape_[0],
val_samples,
dtype=int,
endpoint=True,
endpoint=False,
)
test_ids = np.linspace(
self.test.ids[0],
self.test.ids[-1],
0,
self.test.shape_[0],
test_samples,
dtype=int,
endpoint=True,
endpoint=False,
)
else:
train_ids = np.linspace(
self.ids[0], self.ids[-1], train_samples, dtype=int, endpoint=True
0, self.shape_[0], train_samples, dtype=int, endpoint=False
)
val_ids = []
test_ids = []

new_panel = self.loc[[*train_ids, *val_ids, *test_ids]]
new_panel = self.frames[[*train_ids, *val_ids, *test_ids]]

# Reset ids
if reset_ids:
Expand Down Expand Up @@ -735,7 +746,7 @@ def plot_(
``plot``: Result of plot function.
"""

panel = self.row_(n=0)
panel = self.nth(n=0)
panel = panel.reset_index(level=0, drop=True)

if max and self.num_frames > max:
Expand Down

0 comments on commit 2585a52

Please sign in to comment.