Skip to content

Commit

Permalink
Merge branch 'table_benchmark_update' of https://github.com/Learnware…
Browse files Browse the repository at this point in the history
…-LAMDA/Learnware into table_benchmark_update
  • Loading branch information
liuht-0807 committed Jan 14, 2024
2 parents a020fcf + 80e58a7 commit 4258ded
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions examples/dataset_table_workflow/homo.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def labeled_homo_table_example(self, skip_test=False):
recorders = {method: Recorder() for method in methods}
user = self.benchmark.name

if not skip_test:
for idx in range(self.benchmark.user_num):
test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
test_x, test_y = test_x.values, test_y.values

train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
train_x, train_y = train_x.values, train_y.values
train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)

if not skip_test:
for idx in range(self.benchmark.user_num):
test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
Expand All @@ -121,21 +130,39 @@ def labeled_homo_table_example(self, skip_test=False):
train_x, train_y = train_x.values, train_y.values
train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)

user_stat_spec = generate_stat_spec(type="table", X=test_x)
user_info = BaseUserInfo(
semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}
)
logger.info(f"Searching Market for user: {user}_{idx}")
user_stat_spec = generate_stat_spec(type="table", X=test_x)
user_info = BaseUserInfo(
semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}
)
logger.info(f"Searching Market for user: {user}_{idx}")

search_result = self.market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
search_result = self.market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()

logger.info(f"search result of user {user}_{idx}:")
logger.info(
f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
)
logger.info(f"search result of user {user}_{idx}:")
logger.info(
f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
)

if len(multiple_result) > 0:
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
mixture_learnware_list = multiple_result[0].learnwares
else:
mixture_learnware_list = [single_result[0].learnware]
if len(multiple_result) > 0:
mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
Expand All @@ -150,6 +177,13 @@ def labeled_homo_table_example(self, skip_test=False):
"homo_single_aug": {"single_learnware": [single_result[0].learnware]},
"homo_ensemble_pruning": common_config
}
test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y}
common_config = {"learnwares": mixture_learnware_list}
method_configs = {
"user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
"homo_single_aug": {"single_learnware": [single_result[0].learnware]},
"homo_ensemble_pruning": common_config
}

for method_name in methods:
logger.info(f"Testing method {method_name}")
Expand All @@ -158,7 +192,17 @@ def labeled_homo_table_example(self, skip_test=False):
test_info.update(method_configs[method_name])
self.test_method(test_info, recorders, loss_func=loss_func_rmse)

for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
for method_name in methods:
logger.info(f"Testing method {method_name}")
test_info["method_name"] = method_name
test_info["force"] = method_name in methods_to_retest
test_info.update(method_configs[method_name])
self.test_method(test_info, recorders, loss_func=loss_func_rmse)

for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list)
plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list)

0 comments on commit 4258ded

Please sign in to comment.