Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make each model a single package #86

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
refactor: make each model a single package to standardize the whole l…
…ibrary for easier management;
  • Loading branch information
WenjieDu committed May 5, 2023
commit 53f053339c6cc1c819c1e26ac191bc463e212d25
10 changes: 8 additions & 2 deletions pypots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
__version__ = "0.0.11"

from pypots import classification
from pypots import clustering
from pypots import data
from pypots import forecasting
from pypots import imputation
from pypots import utils

# Public API of the top-level `pypots` package.
# One entry per sub-package, kept alphabetical; each name appears exactly once.
__all__ = [
    "classification",
    "clustering",
    "data",
    "forecasting",
    "imputation",
    "utils",
    "__version__",
]
12 changes: 12 additions & 0 deletions pypots/classification/brits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""

"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.brits.model import BRITS

__all__ = [
"BRITS",
]
48 changes: 48 additions & 0 deletions pypots/classification/brits/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Dataset class for model BRITS.
"""

# Created by Wenjie Du <[email protected]>
# License: GPL-v3

from typing import Union

from pypots.imputation.brits.dataset import (
DatasetForBRITS as DatasetForBRITS_Imputation,
)


class DatasetForBRITS(DatasetForBRITS_Imputation):
    """Dataset class for the classification flavour of BRITS.

    The sample preparation is exactly the same as for the imputation BRITS,
    so this class is a thin subclass of the imputation dataset; it exists so
    the classification package has its own local dataset module.

    Parameters
    ----------
    data : dict or str,
        The dataset for model input. Either a dictionary containing the keys
        'X' and 'y', or a path string locating a data file.
        As a dict: X is array-like of shape [n_samples, sequence length
        (time steps), n_features] — time-series input that may contain
        missing values — and y is array-like of shape [n_samples], the
        classification labels of X.
        As a path string: it should point to a data file (e.g. an h5 file)
        storing key-value pairs like a dict, and it must include the keys
        'X' and 'y'.

    return_labels : bool, default = True,
        Whether __getitem__() should return labels when they exist in the
        given data. For example, set True while training classification
        models so labels are fed to the model. Set False to exclude labels
        from __getitem__() output. This switch exists because the same
        Dataset class serves training/validating/testing: big datasets in
        h5 files store both X and y, but _fetch_data_from_file() (used for
        all three stages) does not read labels for validating/testing, so
        the stages need to be distinguished.

    file_type : str, default = "h5py"
        The type of the given file if train_set and val_set are path strings.
    """

    def __init__(
        self,
        data: Union[dict, str],
        return_labels: bool = True,
        file_type: str = "h5py",
    ):
        super().__init__(data, return_labels, file_type)
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from torch.utils.data import DataLoader

from pypots.classification.base import BaseNNClassifier
from pypots.data import DatasetForBRITS
from pypots.imputation.brits import RITS as imputation_RITS, _BRITS as imputation_BRITS
from pypots.classification.brits.dataset import DatasetForBRITS
from pypots.imputation.brits.model import (
RITS as imputation_RITS,
_BRITS as imputation_BRITS,
)


class RITS(imputation_RITS):
Expand Down
12 changes: 12 additions & 0 deletions pypots/classification/grud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""

"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.grud.model import GRUD

__all__ = [
"GRUD",
]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from torch.utils.data import DataLoader

from pypots.classification.base import BaseNNClassifier
from pypots.data.dataset_for_grud import DatasetForGRUD
from pypots.imputation.brits import TemporalDecay
from pypots.classification.grud.dataset import DatasetForGRUD
from pypots.imputation.brits.module import TemporalDecay


class _GRUD(nn.Module):
Expand Down
12 changes: 12 additions & 0 deletions pypots/classification/raindrop/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""

"""

# Created by Wenjie Du <[email protected]>
# License: GLP-v3

from pypots.classification.raindrop.model import Raindrop

__all__ = [
"Raindrop",
]
47 changes: 47 additions & 0 deletions pypots/classification/raindrop/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Dataset class for model Raindrop.
"""

# Created by Wenjie Du <[email protected]>
# License: GPL-v3


from typing import Union

from pypots.classification.grud.dataset import DatasetForGRUD


class DatasetForRaindrop(DatasetForGRUD):
    """Dataset class for model Raindrop.

    Raindrop consumes samples prepared exactly like GRU-D's, so this class
    inherits everything from DatasetForGRUD; it exists to give the raindrop
    package its own local dataset module.

    Parameters
    ----------
    data : dict or str,
        The dataset for model input, should be a dictionary including keys as 'X' and 'y',
        or a path string locating a data file.
        If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
        which is time-series data for input, can contain missing values, and y should be array-like of shape
        [n_samples], which is classification labels of X.
        If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
        key-value pairs like a dict, and it has to include keys as 'X' and 'y'.

    return_labels : bool, default = True,
        Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
        during training of classification models, the Dataset class will return labels in __getitem__() for model input.
        Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
        need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
        files, they already have both X and y saved. But we don't read labels from the file for validating and testing
        with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
        distinction.

    file_type : str, default = "h5py"
        The type of the given file if train_set and val_set are path strings.
    """

    def __init__(
        self,
        data: Union[dict, str],
        return_labels: bool = True,
        file_type: str = "h5py",
    ):
        super().__init__(data, return_labels, file_type)
Loading
Loading