Skip to content

Commit

Permalink
refac: modify Panel definition
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Jul 24, 2022
1 parent 2d72c80 commit ed1b46e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/wavy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def build(self):
units=self.y.num_timesteps * self.y.num_columns,
activation=self.last_activation,
),
Reshape(self.y_train.shape[1:]),
Reshape((self.y.num_columns,)),
]

self.model = Sequential(layers)
Expand Down
23 changes: 21 additions & 2 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,30 @@ class Panel(pd.DataFrame):
# test_size = None
# val_size = None

_metadata = ["train_size", "test_size", "val_size"]
def __init__(self, *args, **kw):
super(Panel, self).__init__(*args, **kw)
if len(args) == 1 and isinstance(args[0], Panel):
args[0]._copy_attrs(self)

# _metadata = ["train_size", "test_size", "val_size"]
_attributes_ = "train_size,test_size,val_size"

def _copy_attrs(self, df):
for attr in self._attributes_.split(","):
df.__dict__[attr] = getattr(self, attr, None)

@property
def _constructor(self):
return Panel
def f(*args, **kw):
df = Panel(*args, **kw)
self._copy_attrs(df)
return df

return f

# @property
# def _constructor(self):
# return Panel

@property
def num_frames(self):
Expand Down

0 comments on commit ed1b46e

Please sign in to comment.