Skip to content

Commit

Permalink
[MNT] Rename the folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Googol2002 committed Dec 8, 2023
1 parent bae46ec commit 476c0e2
Show file tree
Hide file tree
Showing 19 changed files with 353 additions and 353 deletions.
144 changes: 0 additions & 144 deletions examples/dataset_cifar_workflow/main.py

This file was deleted.

File renamed without changes.
216 changes: 216 additions & 0 deletions examples/dataset_image_workflow(old)/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import numpy as np
import torch
from tqdm import tqdm

from get_data import *
import os
import random

from learnware.specification import RKMEImageSpecification
from learnware.reuse.averaging import AveragingReuser
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
from learnware.learnware import Learnware
import time

from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.market.easy import database_ops
from learnware.learnware import Learnware
import learnware.specification as specification
from learnware.logger import get_module_logger

from shutil import copyfile, rmtree
import zipfile

logger = get_module_logger("image_test", level="INFO")
origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "cifar10"
n_uploaders = 30
n_users = 20
n_classes = 10
data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)


semantic_specs = [
{
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {"Values": ["Classification"], "Type": "Class"},
"Library": {"Values": ["Pytorch"], "Type": "Class"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "String"},
"Name": {"Values": "learnware_1", "Type": "String"},
"Output": {"Dimension": 10},
}
]

user_semantic = {
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {"Values": ["Classification"], "Type": "Class"},
"Library": {"Values": ["Pytorch"], "Type": "Class"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "String"},
"Name": {"Values": "", "Type": "String"},
}


def prepare_data():
if dataset == "cifar10":
X_train, y_train, X_test, y_test = get_cifar10(data_root)
elif dataset == "mnist":
X_train, y_train, X_test, y_test = get_mnist(data_root)
else:
return
generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)


def prepare_model():
dataloader = ImageDataLoader(data_save_root, train=True)
for i in range(n_uploaders):
logger.info("Train on uploader: %d" % (i))
X, y = dataloader.get_idx_data(i)
model = train(X, y, out_classes=n_classes)
model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
torch.save(model.state_dict(), model_save_path)
logger.info("Model saved to '%s'" % (model_save_path))


def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
os.makedirs(save_root, exist_ok=True)
tmp_spec_path = os.path.join(save_root, "rkme.json")
tmp_model_path = os.path.join(save_root, "conv_model.pth")
tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
tmp_init_path = os.path.join(save_root, "__init__.py")
tmp_model_file_path = os.path.join(save_root, "model.py")
mmodel_file_path = "./example_files/model.py"

# Computing the specification from the whole dataset is too costly.
X = np.load(data_path)
indices = np.random.choice(len(X), size=2000, replace=False)
X_sampled = X[indices]

st = time.time()
user_spec = RKMEImageSpecification(cuda_idx=0)
user_spec.generate_stat_spec_from_data(X=X_sampled)
ed = time.time()
logger.info("Stat spec generated in %.3f s" % (ed - st))
user_spec.save(tmp_spec_path)
copyfile(model_path, tmp_model_path)
copyfile(yaml_path, tmp_yaml_path)
copyfile(init_file_path, tmp_init_path)
copyfile(mmodel_file_path, tmp_model_file_path)
zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
zip_obj.write(tmp_spec_path, "rkme.json")
zip_obj.write(tmp_model_path, "conv_model.pth")
zip_obj.write(tmp_yaml_path, "learnware.yaml")
zip_obj.write(tmp_init_path, "__init__.py")
zip_obj.write(tmp_model_file_path, "model.py")
rmtree(save_root)
logger.info("New Learnware Saved to %s" % (zip_file_name))
return zip_file_name


def prepare_market():
image_market = instantiate_learnware_market(market_id="cifar10", name="easy", rebuild=True)
try:
rmtree(learnware_pool_dir)
except:
pass
os.makedirs(learnware_pool_dir, exist_ok=True)
for i in tqdm(range(n_uploaders), total=n_uploaders, desc="Preparing..."):
data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
init_file_path = "./example_files/example_init.py"
yaml_file_path = "./example_files/example_yaml.yaml"
new_learnware_path = prepare_learnware(
data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
)
semantic_spec = semantic_specs[0]
semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
image_market.add_learnware(new_learnware_path, semantic_spec)

logger.info("Total Item: %d" % (len(image_market)))
curr_inds = image_market._get_ids()
logger.info("Available ids: " + str(curr_inds))


def test_search(gamma=0.1, load_market=True):
if load_market:
image_market = instantiate_learnware_market(market_id="cifar10", name="easy")
else:
prepare_market()
image_market = instantiate_learnware_market(market_id="cifar10", name="easy")
logger.info("Number of items in the market: %d" % len(image_market))

select_list = []
avg_list = []
improve_list = []
job_selector_score_list = []
ensemble_score_list = []
for i in tqdm(range(n_users), total=n_users, desc="Searching..."):
user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
user_data = np.load(user_data_path)
user_label = np.load(user_label_path)
user_stat_spec = RKMEImageSpecification(cuda_idx=0)
user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False)
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec})
logger.info("Searching Market for user: %d" % i)
search_result = image_market.search_learnware(user_info)
single_result = search_result.get_single_results()
acc_list = []
for idx, single_item in enumerate(single_result[:5]):
pred_y = single_item.learnware.predict(user_data)
acc = eval_prediction(pred_y, user_label)
acc_list.append(acc)
logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, single_item.score, single_item.learnware.id, acc))

# test reuse (job selector)
# reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
# reuse_predict = reuse_baseline.predict(user_data=user_data)
# reuse_score = eval_prediction(reuse_predict, user_label)
# job_selector_score_list.append(reuse_score)
# print(f"mixture reuse loss: {reuse_score}")

# test reuse (ensemble)
single_learnware_list = [single_item.learnware for single_item in single_result]
reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob")
ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
ensemble_score = eval_prediction(ensemble_predict_y, user_label)
ensemble_score_list.append(ensemble_score)
print(f"reuse accuracy (vote_by_prob): {ensemble_score}\n")

select_list.append(acc_list[0])
avg_list.append(np.mean(acc_list))
improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))

logger.info(
"Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
% (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
)
logger.info(
"Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
)


if __name__ == "__main__":
logger.info("=" * 40)
logger.info(f"n_uploaders:\t{n_uploaders}")
logger.info(f"n_users:\t{n_users}")
logger.info("=" * 40)

prepare_data()
prepare_model()
test_search(load_market=False)
File renamed without changes.
Loading

0 comments on commit 476c0e2

Please sign in to comment.