Skip to content

Commit

Permalink
refac: add tutorial and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ibiscp committed Jul 30, 2022
1 parent 9e6f6e5 commit b595d93
Show file tree
Hide file tree
Showing 3 changed files with 36,654 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/wavy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def predict_proba(self, data: Panel = None, **kwargs):

if data is not None:
x = data.values_panel
index = pd.MultiIndex.from_arrays([data.ids, data.first_timestamp])
index = pd.MultiIndex.from_arrays([data.ids, data.get_timesteps(0)])
else:
x = np.concatenate([self.x_train, self.x_val, self.x_test], axis=0)
index = pd.MultiIndex.from_tuples(
Expand Down
16 changes: 11 additions & 5 deletions src/wavy/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,12 @@ def _copy_attrs(self, df):
def _constructor(self):
def f(*args, **kw):

index = [a for a in args[0].axes if isinstance(a, pd.MultiIndex)]
if index and len(index[0]) == self.num_timesteps:
return pd.DataFrame(*args, **kw)
try:
index = [a for a in args[0].axes if isinstance(a, pd.MultiIndex)]
if index and len(index[0]) == self.num_timesteps:
return pd.DataFrame(*args, **kw)
except:
pass

df = Panel(*args, **kw)

Expand Down Expand Up @@ -268,15 +271,18 @@ def row_panel(self, n: int = 0):

return self.groupby(level=0, as_index=False).nth(n)

def get_timestep(self, n: int = 0):
def get_timesteps(self, n: Union[list, int] = 0):
"""
Returns the first timestep of each frame in the panel.
Args:
n (int): Timestep to return.
"""

return self.frames.take(n).index.get_level_values(1)
if isinstance(n, int):
n = [n]

return self.frames.take(n).index.get_level_values(2)

@property
def values_panel(self):
Expand Down
Loading

0 comments on commit b595d93

Please sign in to comment.