-
-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #86 from WenjieDu/dev
Make each model a single package
- Loading branch information
Showing
38 changed files
with
1,620 additions
and
1,153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GLP-v3 | ||
|
||
from pypots.classification.brits.model import BRITS | ||
|
||
__all__ = [ | ||
"BRITS", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
Dataset class for model BRITS. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3
|
||
from typing import Union | ||
|
||
from pypots.imputation.brits.dataset import ( | ||
DatasetForBRITS as DatasetForBRITS_Imputation, | ||
) | ||
|
||
|
||
class DatasetForBRITS(DatasetForBRITS_Imputation):
    """Dataset wrapper feeding data to the BRITS classifier.

    This class reuses the imputation-side BRITS dataset unchanged; it only
    exists so the classification package has its own dataset entry point.

    Parameters
    ----------
    data : dict or str,
        The dataset for model input, should be a dictionary including keys as 'X' and 'y',
        or a path string locating a data file.
        If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
        which is time-series data for input, can contain missing values, and y should be array-like of shape
        [n_samples], which is classification labels of X.
        If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
        key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

    return_labels : bool, default = True,
        Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
        during training of classification models, the Dataset class will return labels in __getitem__() for model input.
        Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
        need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
        files, they already have both X and y saved. But we don't read labels from the file for validating and testing
        with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
        distinction.

    file_type : str, default = "h5py"
        The type of the given file if train_set and val_set are path strings.
    """

    def __init__(self, data: Union[dict, str], return_labels: bool = True, file_type: str = "h5py"):
        # No extra behavior here — delegate everything to the imputation dataset.
        super().__init__(data, return_labels, file_type)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GLP-v3 | ||
|
||
from pypots.classification.grud.model import GRUD | ||
|
||
__all__ = [ | ||
"GRUD", | ||
] |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GLP-v3 | ||
|
||
from pypots.classification.raindrop.model import Raindrop | ||
|
||
__all__ = [ | ||
"Raindrop", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Dataset class for model Raindrop. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3
|
||
|
||
from typing import Union | ||
|
||
from pypots.classification.grud.dataset import DatasetForGRUD | ||
|
||
|
||
class DatasetForRaindrop(DatasetForGRUD):
    """Dataset class for model Raindrop.

    Raindrop consumes the same sample layout as GRU-D, so this class simply
    inherits the GRU-D dataset unchanged.

    Parameters
    ----------
    data : dict or str,
        The dataset for model input, should be a dictionary including keys as 'X' and 'y',
        or a path string locating a data file.
        If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
        which is time-series data for input, can contain missing values, and y should be array-like of shape
        [n_samples], which is classification labels of X.
        If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
        key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

    return_labels : bool, default = True,
        Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
        during training of classification models, the Dataset class will return labels in __getitem__() for model input.
        Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
        need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
        files, they already have both X and y saved. But we don't read labels from the file for validating and testing
        with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
        distinction.

    file_type : str, default = "h5py"
        The type of the given file if train_set and val_set are path strings.
    """

    def __init__(
        self,
        data: Union[dict, str],
        return_labels: bool = True,
        file_type: str = "h5py",
    ):
        # Pure delegation: all loading/label handling lives in DatasetForGRUD.
        super().__init__(data, return_labels, file_type)
Oops, something went wrong.