Skip to content

Commit

Permalink
[MNT] enable benchmarks pass the mypy test
Browse files Browse the repository at this point in the history
  • Loading branch information
GeneLiuXe committed Jan 12, 2024
1 parent c4036ba commit cd92ac2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions learnware/tests/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import tempfile
import zipfile
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

Expand All @@ -21,7 +22,9 @@ class Benchmark:
train_y_paths: Optional[List[str]] = None
extra_info_path: Optional[str] = None

def get_test_data(self, user_ids: Union[int, List[int]]):
def get_test_data(
self, user_ids: Union[int, List[int]]
) -> Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]:
raw_user_ids = user_ids
if isinstance(user_ids, int):
user_ids = [user_ids]
Expand All @@ -41,7 +44,9 @@ def get_test_data(self, user_ids: Union[int, List[int]]):
else:
return ret

def get_train_data(self, user_ids: Union[int, List[int]]):
def get_train_data(
self, user_ids: Union[int, List[int]]
) -> Optional[Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]:
if self.train_X_paths is None or self.train_y_paths is None:
return None

Expand Down Expand Up @@ -98,7 +103,7 @@ def _check_cache_data_valid(self, benchmark_config: BenchmarkConfig, data_type:
else:
return False

def _download_data(self, download_path: str, save_path: str):
def _download_data(self, download_path: str, save_path: str) -> None:
"""Download data from backend
Parameters
Expand Down Expand Up @@ -128,7 +133,7 @@ def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) ->
"""
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data")
if not self._check_cache_data_valid(benchmark_config, data_type):
download_path = getattr(benchmark_config, f"{data_type}_data_path", None)
download_path = getattr(benchmark_config, f"{data_type}_data_path")
self._download_data(download_path, cache_folder)

X_paths, y_paths = [], []
Expand All @@ -142,10 +147,15 @@ def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) ->

return X_paths, y_paths

def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]):
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark:
if isinstance(benchmark_config, str):
benchmark_config = self.benchmark_configs[benchmark_config]

if not isinstance(benchmark_config, BenchmarkConfig):
raise ValueError(
"benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!"
)

# Load test data
test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test")

Expand Down
4 changes: 2 additions & 2 deletions learnware/tests/benchmarks/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Dict


@dataclass
Expand All @@ -12,4 +12,4 @@ class BenchmarkConfig:
extra_info_path: Optional[str] = None


benchmark_configs = {}
benchmark_configs: Dict[str, BenchmarkConfig] = {}

0 comments on commit cd92ac2

Please sign in to comment.