Skip to content

Commit

Permalink
refac: bug fix and tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Jul 30, 2022
1 parent 9c3ce88 commit 9e6f6e5
Show file tree
Hide file tree
Showing 5 changed files with 3,075 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
data
notebooks
future
tutorials
tutorials_old

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 3 additions & 3 deletions src/wavy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ def __init__(

def set_arrays(self):
"""Set the arrays."""
self.x_train = self.y.train.shift(self.shift).fillna(self.fillna).values
self.x_val = self.y.val.shift(self.shift).fillna(self.fillna).values
self.x_test = self.y.test.shift(self.shift).fillna(self.fillna).values
self.x_train = self.y.train.shift_panel(self.shift).fillna(self.fillna).values
self.x_val = self.y.val.shift_panel(self.shift).fillna(self.fillna).values
self.x_test = self.y.test.shift_panel(self.shift).fillna(self.fillna).values

self.y_train = self.y.train.values
self.y_val = self.y.val.values
Expand Down
4 changes: 2 additions & 2 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def _copy_attrs(self, df):
def _constructor(self):
def f(*args, **kw):

index = [a for a in args[0].axes if isinstance(a, pd.MultiIndex)][0]
if len(index) == self.num_timesteps:
index = [a for a in args[0].axes if isinstance(a, pd.MultiIndex)]
if index and len(index[0]) == self.num_timesteps:
return pd.DataFrame(*args, **kw)

df = Panel(*args, **kw)
Expand Down
6 changes: 1 addition & 5 deletions src/wavy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ def add_annotation(self, panel, color="gray", opacity=1):
# BUG: Seems to break if using "ggplot2"
# ! Won't take effect until next trace is added (no axis was added)

ymax = max(
panel.train.max().max() if panel.train_size else 0,
panel.val.max().max() if panel.val_size else 0,
panel.test.max().max() if panel.test_size else 0,
)
ymax = panel.max().max() if panel.train_size else 0

if panel.train_size:
xtrain_min = panel.train.index[0]
Expand Down
Loading

0 comments on commit 9e6f6e5

Please sign in to comment.