Skip to content

Commit

Permalink
Merge branch 'table_benchmark_update' of https://github.com/Learnware…
Browse files Browse the repository at this point in the history
…-LAMDA/Learnware into table_benchmark_update
  • Loading branch information
liuht-0807 committed Jan 14, 2024
2 parents b456b64 + 04f98ff commit dd93185
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/dataset_table_workflow/homo.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def labeled_homo_table_example(self, skip_test=False):
for idx in range(self.benchmark.user_num):
test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
test_x, test_y = test_x.values, test_y.values

train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
train_x, train_y = train_x.values, train_y.values
train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)
Expand Down Expand Up @@ -155,7 +155,7 @@ def labeled_homo_table_example(self, skip_test=False):
test_info["method_name"] = method_name
test_info.update(method_configs[method_name])
self.test_method(test_info, recorders, loss_func=loss_func_rmse)

for method, recorder in recorders.items():
recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

Expand Down

0 comments on commit dd93185

Please sign in to comment.