Commit 476c0e2 (1 parent: bae46ec)
Showing 19 changed files with 353 additions and 353 deletions.
This file was deleted.
File renamed without changes. (4 files)
@@ -0,0 +1,216 @@
import numpy as np
import torch
from tqdm import tqdm

from get_data import *
import os
import random

from learnware.specification import RKMEImageSpecification
from learnware.reuse.averaging import AveragingReuser
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
from learnware.learnware import Learnware
import time

from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market.easy import database_ops
import learnware.specification as specification
from learnware.logger import get_module_logger

from shutil import copyfile, rmtree
import zipfile
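
# Experiment configuration: dataset, market size (number of uploaders/users),
# and working directories for raw data, processed splits, models, and learnware zips.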
logger = get_module_logger("image_test", level="INFO")
origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "cifar10"
n_uploaders = 30
n_users = 20
n_classes = 10
data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)
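
# Semantic specification attached to each uploaded learnware
# (the "Name" and "Description" fields are filled in per uploader below).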
semantic_specs = [
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {"Values": ["Classification"], "Type": "Class"},
        "Library": {"Values": ["Pytorch"], "Type": "Class"},
        "Scenario": {"Values": ["Business"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
        "Output": {"Dimension": 10},
    }
]
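
# Semantic specification used on the user side when querying the market.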
user_semantic = {
    "Data": {"Values": ["Tabular"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Pytorch"], "Type": "Class"},
    "Scenario": {"Values": ["Business"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}
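
# Split the raw dataset into per-uploader training sets and per-user test sets.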
def prepare_data():
    if dataset == "cifar10":
        X_train, y_train, X_test, y_test = get_cifar10(data_root)
    elif dataset == "mnist":
        X_train, y_train, X_test, y_test = get_mnist(data_root)
    else:
        return
    generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
    generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
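
# Train one model per uploader on its local data and save the checkpoints.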
def prepare_model():
    dataloader = ImageDataLoader(data_save_root, train=True)
    for i in range(n_uploaders):
        logger.info("Train on uploader: %d" % (i))
        X, y = dataloader.get_idx_data(i)
        model = train(X, y, out_classes=n_classes)
        model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        torch.save(model.state_dict(), model_save_path)
        logger.info("Model saved to '%s'" % (model_save_path))
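
# Package one uploader's model into a learnware zip: generate an RKME image
# specification from a subsample of its data, then bundle the spec, model
# weights, learnware.yaml, __init__.py, and model.py.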
def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
    os.makedirs(save_root, exist_ok=True)
    tmp_spec_path = os.path.join(save_root, "rkme.json")
    tmp_model_path = os.path.join(save_root, "conv_model.pth")
    tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
    tmp_init_path = os.path.join(save_root, "__init__.py")
    tmp_model_file_path = os.path.join(save_root, "model.py")
    model_file_path = "./example_files/model.py"

    # Computing the specification from the whole dataset is too costly,
    # so generate it from a random subsample.
    X = np.load(data_path)
    indices = np.random.choice(len(X), size=min(2000, len(X)), replace=False)  # cap at the available sample count
    X_sampled = X[indices]

    st = time.time()
    user_spec = RKMEImageSpecification(cuda_idx=0)
    user_spec.generate_stat_spec_from_data(X=X_sampled)
    ed = time.time()
    logger.info("Stat spec generated in %.3f s" % (ed - st))
    user_spec.save(tmp_spec_path)
    copyfile(model_path, tmp_model_path)
    copyfile(yaml_path, tmp_yaml_path)
    copyfile(init_file_path, tmp_init_path)
    copyfile(model_file_path, tmp_model_file_path)
    zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
    with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
        zip_obj.write(tmp_spec_path, "rkme.json")
        zip_obj.write(tmp_model_path, "conv_model.pth")
        zip_obj.write(tmp_yaml_path, "learnware.yaml")
        zip_obj.write(tmp_init_path, "__init__.py")
        zip_obj.write(tmp_model_file_path, "model.py")
    rmtree(save_root)
    logger.info("New learnware saved to %s" % (zip_file_name))
    return zip_file_name
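
# Build the learnware market: create a zip for every uploader and add it
# together with its semantic specification.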
def prepare_market():
    image_market = instantiate_learnware_market(market_id="cifar10", name="easy", rebuild=True)
    try:
        rmtree(learnware_pool_dir)
    except FileNotFoundError:
        pass
    os.makedirs(learnware_pool_dir, exist_ok=True)
    for i in tqdm(range(n_uploaders), total=n_uploaders, desc="Preparing..."):
        data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
        model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        init_file_path = "./example_files/example_init.py"
        yaml_file_path = "./example_files/example_yaml.yaml"
        new_learnware_path = prepare_learnware(
            data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
        )
        semantic_spec = semantic_specs[0]
        semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
        semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
        image_market.add_learnware(new_learnware_path, semantic_spec)

    logger.info("Total Item: %d" % (len(image_market)))
    curr_inds = image_market._get_ids()
    logger.info("Available ids: " + str(curr_inds))
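
# For each user, generate an RKME specification from the user's data, search the
# market, evaluate the top-ranked learnwares, and compare single-learnware
# selection against averaging-ensemble reuse.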
def test_search(gamma=0.1, load_market=True):
    if load_market:
        image_market = instantiate_learnware_market(market_id="cifar10", name="easy")
    else:
        prepare_market()
        image_market = instantiate_learnware_market(market_id="cifar10", name="easy")
    logger.info("Number of items in the market: %d" % len(image_market))

    select_list = []
    avg_list = []
    improve_list = []
    job_selector_score_list = []
    ensemble_score_list = []
    for i in tqdm(range(n_users), total=n_users, desc="Searching..."):
        user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
        user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
        user_data = np.load(user_data_path)
        user_label = np.load(user_label_path)
        user_stat_spec = RKMEImageSpecification(cuda_idx=0)
        user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec})
        logger.info("Searching Market for user: %d" % i)
        search_result = image_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        acc_list = []
        for idx, single_item in enumerate(single_result[:5]):
            pred_y = single_item.learnware.predict(user_data)
            acc = eval_prediction(pred_y, user_label)
            acc_list.append(acc)
            logger.info(
                "Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f"
                % (idx, single_item.score, single_item.learnware.id, acc)
            )

        # test reuse (job selector)
        # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
        # reuse_predict = reuse_baseline.predict(user_data=user_data)
        # reuse_score = eval_prediction(reuse_predict, user_label)
        # job_selector_score_list.append(reuse_score)
        # print(f"mixture reuse loss: {reuse_score}")

        # test reuse (ensemble)
        single_learnware_list = [single_item.learnware for single_item in single_result]
        reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob")
        ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
        ensemble_score = eval_prediction(ensemble_predict_y, user_label)
        ensemble_score_list.append(ensemble_score)
        print(f"reuse accuracy (vote_by_prob): {ensemble_score}\n")

        select_list.append(acc_list[0])
        avg_list.append(np.mean(acc_list))
        improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))

    logger.info(
        "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
        % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
    )
    logger.info(
        "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
    )
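
# Full workflow: prepare data, train uploader models, then build the market and run the search.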
if __name__ == "__main__":
    logger.info("=" * 40)
    logger.info(f"n_uploaders:\t{n_uploaders}")
    logger.info(f"n_users:\t{n_users}")
    logger.info("=" * 40)

    prepare_data()
    prepare_model()
    test_search(load_market=False)
File renamed without changes. (11 files)