Skip to content

Commit

Permalink
refac: fix error with apply_panel
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Jul 27, 2022
1 parent ed1b46e commit ecde88c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def create_panels(df, lookback: int, horizon: int, gap: int = 0):


def reset_ids(x, y, inplace=False):
# TODO check inplaces
"""
Reset ids of a panel.
Expand Down Expand Up @@ -161,19 +162,11 @@ def set_training_split(


class Panel(pd.DataFrame):
# def __init__(self, *args, **kwargs):
# super(Panel, self).__init__(*args, **kwargs)

# train_size = None
# test_size = None
# val_size = None

def __init__(self, *args, **kw):
super(Panel, self).__init__(*args, **kw)
if len(args) == 1 and isinstance(args[0], Panel):
args[0]._copy_attrs(self)

# _metadata = ["train_size", "test_size", "val_size"]
_attributes_ = "train_size,test_size,val_size"

def _copy_attrs(self, df):
Expand All @@ -189,10 +182,6 @@ def f(*args, **kw):

return f

# @property
# def _constructor(self):
# return Panel

@property
def num_frames(self):
"""Returns the number of frames in the panel."""
Expand Down Expand Up @@ -221,7 +210,14 @@ def __getattr__(self, name):
name = name.replace("_panel", "")

def wrapper(*args, **kwargs):
panel = self.groupby(level=0).apply(name, *args, **kwargs)

if name != 'apply':
panel = self.groupby(level=0).apply(name, *args, **kwargs)
else:
args = list(args)
new_name = args.pop(0)
args = tuple(args)
panel = self.groupby(level=0).apply(new_name, *args, **kwargs)

# Update ids
ids = panel.index.get_level_values(0)
Expand All @@ -240,6 +236,9 @@ def wrapper(*args, **kwargs):

@property
def ids(self):
"""
Returns the ids of the panel.
"""
return self.index.get_level_values(0).drop_duplicates()

@ids.setter
Expand Down Expand Up @@ -267,7 +266,6 @@ def reset_ids(self, inplace=False):
Args:
inplace (bool): Whether to reset ids inplace.
"""
# self.ids = np.arange(self.num_frames)
new_ids = np.repeat(np.arange(self.num_frames), self.num_timesteps)
new_index = pd.MultiIndex.from_arrays(
[new_ids, self.index.get_level_values(1)],
Expand All @@ -280,6 +278,8 @@ def reset_ids(self, inplace=False):
def shape_panel(self):
return (len(self.ids), int(self.shape[0] / len(self.ids)), self.shape[1])

# TODO add loc, iloc

def row_panel(self, n: int = 0):
"""
Returns the nth row of each frame.
Expand Down Expand Up @@ -636,6 +636,7 @@ def plot(self, add_annotation=True, max=10_000, use_timestep=False, **kwargs):
``plot``: Result of plot function.
"""

# TODO Fix annotation when size is higher than 10_000
if max and self.num_frames > max:
return plot(
self.sample_panel(max, how="spaced"),
Expand Down

0 comments on commit ecde88c

Please sign in to comment.