Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: brits imputation test device mismatch #11

Merged
merged 3 commits into from
Aug 22, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: locf mask and arange device mismatch
  • Loading branch information
MaciejSkrabski committed Aug 2, 2022
commit 97ae394df543a011d9875186535a5b4742ad974c
19 changes: 11 additions & 8 deletions pypots/imputation/locf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class LOCF(BaseImputer):
"""

def __init__(self, nan=0):
super().__init__('cpu')
super().__init__("cpu")
self.nan = nan

def fit(self, train_X, val_X=None):
warnings.warn(
'LOCF (Last Observed Carried Forward) imputation class has no parameter to train. '
'Please run func impute(X) directly.'
"LOCF (Last Observed Carried Forward) imputation class has no parameter to train. "
"Please run func impute(X) directly."
)

def locf_numpy(self, X):
Expand Down Expand Up @@ -86,7 +86,7 @@ def locf_torch(self, X):
trans_X = X.permute((0, 2, 1))
mask = torch.isnan(trans_X)
n_samples, n_steps, n_features = mask.shape
idx = torch.where(~mask, torch.arange(n_features), 0)
idx = torch.where(~mask, torch.arange(n_features, device=mask.device), 0)
idx = torch.cummax(idx, dim=2)

collector = []
Expand Down Expand Up @@ -116,8 +116,10 @@ def impute(self, X):
array-like,
Imputed time series.
"""
assert len(X.shape) == 3, f'Input X should have 3 dimensions [n_samples, n_steps, n_features], ' \
f'but the actual shape of X: {X.shape}'
assert len(X.shape) == 3, (
f"Input X should have 3 dimensions [n_samples, n_steps, n_features], "
f"but the actual shape of X: {X.shape}"
)
if isinstance(X, list):
X = np.asarray(X)

Expand All @@ -126,6 +128,7 @@ def impute(self, X):
elif isinstance(X, torch.Tensor):
X_imputed = self.locf_torch(X).detach().cpu().numpy()
else:
raise TypeError('X must be type of list/np.ndarray/torch.Tensor, '
f'but got {type(X)}')
raise TypeError(
"X must be type of list/np.ndarray/torch.Tensor, " f"but got {type(X)}"
)
return X_imputed