Skip to content

Commit

Permalink
Merge pull request #86 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Make each model a single package
  • Loading branch information
WenjieDu committed May 5, 2023
2 parents b573978 + 53f0533 commit 73c3469
Show file tree
Hide file tree
Showing 38 changed files with 1,620 additions and 1,153 deletions.
10 changes: 8 additions & 2 deletions pypots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.0.11"

from pypots import classification
from pypots import clustering
from pypots import data
from pypots import forecasting
from pypots import imputation
from pypots import utils

__all__ = [
"data",
"imputation",
"classification",
"clustering",
"data",
"forecasting",
"imputation",
"utils",
"__version__",
]
12 changes: 12 additions & 0 deletions pypots/classification/brits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.brits.model import BRITS

__all__ = [
"BRITS",
]
48 changes: 48 additions & 0 deletions pypots/classification/brits/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Dataset class for model BRITS.
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from typing import Union

from pypots.imputation.brits.dataset import (
DatasetForBRITS as DatasetForBRITS_Imputation,
)


class DatasetForBRITS(DatasetForBRITS_Imputation):
"""Dataset class for BRITS.
Parameters
----------
data : dict or str,
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
return_labels : bool, default = True,
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
files, they already have both X and y saved. But we don't read labels from the file for validating and testing
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from torch.utils.data import DataLoader

from pypots.classification.base import BaseNNClassifier
from pypots.data import DatasetForBRITS
from pypots.imputation.brits import RITS as imputation_RITS, _BRITS as imputation_BRITS
from pypots.classification.brits.dataset import DatasetForBRITS
from pypots.imputation.brits.model import (
RITS as imputation_RITS,
_BRITS as imputation_BRITS,
)


class RITS(imputation_RITS):
Expand Down
12 changes: 12 additions & 0 deletions pypots/classification/grud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.grud.model import GRUD

__all__ = [
"GRUD",
]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from torch.utils.data import DataLoader

from pypots.classification.base import BaseNNClassifier
from pypots.data.dataset_for_grud import DatasetForGRUD
from pypots.imputation.brits import TemporalDecay
from pypots.classification.grud.dataset import DatasetForGRUD
from pypots.imputation.brits.module import TemporalDecay


class _GRUD(nn.Module):
Expand Down
12 changes: 12 additions & 0 deletions pypots/classification/raindrop/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.raindrop.model import Raindrop

__all__ = [
"Raindrop",
]
47 changes: 47 additions & 0 deletions pypots/classification/raindrop/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Dataset class for model Raindrop.
"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3


from typing import Union

from pypots.classification.grud.dataset import DatasetForGRUD


class DatasetForRaindrop(DatasetForGRUD):
"""Dataset class for model GRU-D.
Parameters
----------
data : dict or str,
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
return_labels : bool, default = True,
Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
during training of classification models, the Dataset class will return labels in __getitem__() for model input.
Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
files, they already have both X and y saved. But we don't read labels from the file for validating and testing
with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
distinction.
file_type : str, default = "h5py"
The type of the given file if train_set and val_set are path strings.
"""

def __init__(
self,
data: Union[dict, str],
return_labels: bool = True,
file_type: str = "h5py",
):
super().__init__(data, return_labels, file_type)
Loading

0 comments on commit 73c3469

Please sign in to comment.