Skip to content

Commit

Permalink
Merge pull request #292 from Learnware-LAMDA/offline_check
Browse files Browse the repository at this point in the history
feat(backend): verify learnware offline
  • Loading branch information
zouxiaochuan committed Dec 28, 2023
2 parents e0bc3ec + 66857fa commit 1452016
Show file tree
Hide file tree
Showing 16 changed files with 443 additions and 85 deletions.
2 changes: 2 additions & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ def update(self, *args, **kwargs):
"user_agreement_file": "",
"privacy_policy_file": "",
"datasets_path": os.path.join(DATA_PATH, "datasets"),
"env_path": os.path.join(DATA_PATH, "envs"),
"learnware_checker_type": "conda"
}
1 change: 1 addition & 0 deletions backend/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def init_backend():
os.makedirs(config.backup_path, exist_ok=True)
os.makedirs(config.datasets_path, exist_ok=True)
os.makedirs(config.log_path, exist_ok=True)
os.makedirs(config.env_path, exist_ok=True)
pass


Expand Down
59 changes: 55 additions & 4 deletions backend/lib/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,74 @@

from learnware import market, specification
from learnware.market.heterogeneous import utils as learnware_utils
from learnware.learnware import Learnware
from learnware.learnware.utils import get_stat_spec_from_config
from learnware.config import C as learnware_config
from learnware.client.utils import install_environment, system_execute
from learnware.market import BaseChecker
from flask import jsonify, g
from datetime import datetime, timedelta
from collections import defaultdict
import functools
import os, json, time
import hashlib
import os
import json
import traceback
import tempfile
import zipfile
import learnware.config
import yaml
from learnware.learnware.utils import get_stat_spec_from_config
from learnware.config import C as learnware_config
from . import common_utils
from . import sensitive_words_utils
from . import kubernetes_utils
import uuid
import shutil
from typing import Tuple
import shortuuid


class OfflineChecker(market.BaseChecker):
    """Checker that validates a learnware outside the current process.

    A temporary conda environment is built for the learnware under the
    configured ``env_path`` directory, and the check itself is delegated to
    ``kubernetes_utils.run_check``, which receives the name of the checker
    class to run.
    """

    def __init__(self, inner_checker_class_name, **kwargs):
        """Store the checker class name; remaining kwargs go to the base class."""
        self.inner_checker_class_name = inner_checker_class_name
        super().__init__(**kwargs)

    def __call__(self, learnware: Learnware) -> Tuple[int, str]:
        """Check ``learnware`` and return ``(check_result, message)``.

        ``check_result`` is ``BaseChecker.USABLE_LEARNWARE`` when
        ``run_check`` reports success and ``BaseChecker.NONUSABLE_LEARNWARE``
        otherwise. Any exception raised while preparing the environment or
        running the check is caught and reported through ``message``.
        """
        envs_root = context.config["env_path"]

        # The environment directory only lives for the duration of the check.
        with tempfile.TemporaryDirectory(prefix="env_", dir=envs_root) as env_path:
            try:
                # Step 1: build the learnware's conda environment at env_path.
                install_environment(learnware.get_dirpath(), None, conda_prefix=env_path)

                # The default environment did not install torch, so add it via pip.
                pip_install_torch = [
                    "conda",
                    "run",
                    "--prefix",
                    env_path,
                    "--no-capture-output",
                    "python",
                    "-m",
                    "pip",
                    "install",
                    "torch",
                ]
                system_execute(args=pip_install_torch)

                # Step 2: run the named checker against the learnware in a pod.
                passed, message = kubernetes_utils.run_check(
                    env_path, learnware.get_dirpath(), self.inner_checker_class_name
                )
            except Exception as err:
                passed = False
                message = f"kubernetes error: {err}"

        check_result = BaseChecker.USABLE_LEARNWARE if passed else BaseChecker.NONUSABLE_LEARNWARE
        return check_result, message


def cache(seconds: int, maxsize: int = 128, typed: bool = False):
Expand Down
51 changes: 51 additions & 0 deletions backend/lib/kubernetes_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from kubernetes import client, config, utils
import yaml
import time
import shortuuid
import os


def run_check(env_path, learnware_path, checker_name):
    """Run a learnware check in a dedicated pod and report the outcome.

    The pod manifest is read from ``learnware-check.yaml`` (a path relative to
    the current working directory) and its ``{{NAME}}``, ``{{LEARNWARE_PATH}}``,
    ``{{CHECKER_NAME}}``, ``{{ENV_PATH}}`` and ``{{IMAGE}}`` placeholders are
    filled in. The image is taken from the first container of the pod named by
    the ``HOSTNAME`` environment variable.

    Args:
        env_path: value substituted for ``{{ENV_PATH}}``.
        learnware_path: value substituted for ``{{LEARNWARE_PATH}}``.
        checker_name: value substituted for ``{{CHECKER_NAME}}``.

    Returns:
        ``(True, "Success")`` when the pod ends in phase ``Succeeded``;
        otherwise ``(False, logs)`` where ``logs`` is the pod's log output.
    """
    namespace = "learnware"

    config.load_incluster_config()

    k8s_client = client.ApiClient()
    v1 = client.CoreV1Api()

    template_file = os.path.join("learnware-check.yaml")
    with open(template_file) as fin:
        template_content = fin.read()

    current_pod_name = os.environ["HOSTNAME"]
    current_pod = v1.read_namespaced_pod(name=current_pod_name, namespace=namespace)
    current_pod_image = current_pod.spec.containers[0].image

    pod_name = str(shortuuid.uuid()).lower()
    # Substitutions are applied in this order, one placeholder at a time.
    substitutions = {
        "{{NAME}}": pod_name,
        "{{LEARNWARE_PATH}}": learnware_path,
        "{{CHECKER_NAME}}": checker_name,
        "{{ENV_PATH}}": env_path,
        "{{IMAGE}}": current_pod_image,
    }
    for placeholder, value in substitutions.items():
        template_content = template_content.replace(placeholder, value)

    template_dict = yaml.safe_load(template_content)

    pod = utils.create_from_dict(k8s_client, template_dict)[0]
    created_pod_name = pod.metadata.name

    try:
        # Poll once per second until the pod leaves the Pending/Running phases.
        while True:
            pod = v1.read_namespaced_pod(name=created_pod_name, namespace=namespace)
            print(f"Pod status: {pod.status.phase}")
            if pod.status.phase not in ("Pending", "Running"):
                break
            time.sleep(1)

        if pod.status.phase == "Succeeded":
            return True, "Success"

        # Any other terminal phase: hand the pod's log back as the message.
        return False, v1.read_namespaced_pod_log(name=created_pod_name, namespace=namespace)
    finally:
        # Always delete the pod, including when polling or log retrieval
        # raises, so that a failed check never leaves a pod behind.
        v1.delete_namespaced_pod(name=created_pod_name, namespace=namespace)
3 changes: 2 additions & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ shortuuid
pysocks
redis
fast_pytorch_kmeans
concurrent-log-handler
concurrent-log-handler
kubernetes
20 changes: 15 additions & 5 deletions backend/scripts/monitor_learnware_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@
from lib import sensitive_words_utils


def verify_learnware_with_conda_checker(
def get_stat_checker():
    """Build the statistical checker selected by ``learnware_checker_type``.

    Returns:
        ``CondaChecker(inner_checker=EasyStatChecker())`` when the configured
        type is ``"conda"``; ``engine_utils.OfflineChecker("EasyStatChecker")``
        when it is ``"kubernetes"``.

    Raises:
        ValueError: if the configured type is neither of the above.
    """
    checker_type = context.config["learnware_checker_type"]

    if checker_type == "conda":
        return CondaChecker(inner_checker=EasyStatChecker())
    if checker_type == "kubernetes":
        return engine_utils.OfflineChecker("EasyStatChecker")

    # Fail here with a clear message instead of returning None, which would
    # only surface later as "'NoneType' object is not callable".
    raise ValueError(f"unsupported learnware_checker_type: {checker_type!r}")


def verify_learnware_with_checker(
learnware_id: str, learnware_path: str, semantic_specification: dict
) -> Tuple[bool, str]:
verify_sucess = True
Expand Down Expand Up @@ -60,9 +70,9 @@ def verify_learnware_with_conda_checker(
command_output += "\n" + check_message

# check stat spec
stat_checker = CondaChecker(inner_checker=EasyStatChecker())
stat_checker = get_stat_checker()
check_result, check_message = stat_checker(learnware=learnware)
if verify_sucess and check_result == EasyStatChecker.INVALID_LEARNWARE:
if verify_sucess and check_result != BaseChecker.USABLE_LEARNWARE:
verify_sucess = False
command_output = "conda checker does not pass"
command_output += "\n" + check_message
Expand Down Expand Up @@ -136,7 +146,7 @@ def worker_process_func(q: queue.Queue, env: dict):
learnware_filename, extract_path, learnware_id, semantic_specification
)

verify_success, command_output = verify_learnware_with_conda_checker(
verify_success, command_output = verify_learnware_with_checker(
learnware_id, extract_path, semantic_specification
)
# the learnware may be deleted
Expand Down Expand Up @@ -170,7 +180,7 @@ def worker_process_func(q: queue.Queue, env: dict):
os.remove(learnware_zippath)
learnware.utils.zip_learnware_folder(learnware_dirpath, learnware_zippath)

verify_success, command_output = verify_learnware_with_conda_checker(
verify_success, command_output = verify_learnware_with_checker(
learnware_id, learnware_dirpath, semantic_specification
)
pass
Expand Down
36 changes: 36 additions & 0 deletions backend/scripts/verify_learnware_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""this script is executed in k8s pod
"""
import argparse

from learnware.market import EasyStatChecker, BaseChecker
from learnware.learnware import get_learnware_from_dirpath
import os
import json


def verify_learnware(learnware_path, checker_name):
    """Run the named checker against the learnware stored in ``learnware_path``.

    Args:
        learnware_path: directory of the learnware; it must contain
            ``semantic_specification.json``.
        checker_name: name of a checker class available in this module,
            e.g. ``"EasyStatChecker"``.

    Raises:
        ValueError: if ``checker_name`` does not name a ``BaseChecker``
            subclass available in this module.
        RuntimeError: if the checker's result is not ``1``; the exception
            carries the checker's message.
    """
    # checker_name comes from the command line. Resolve it through a validated
    # lookup rather than eval(), which would execute arbitrary expressions.
    checker_class = globals().get(checker_name)
    if not (isinstance(checker_class, type) and issubclass(checker_class, BaseChecker)):
        raise ValueError(f"unknown checker: {checker_name!r}")
    checker = checker_class()

    semantic_path = os.path.join(learnware_path, "semantic_specification.json")
    with open(semantic_path, "r") as f:
        semantic_spec = json.load(f)

    learnware = get_learnware_from_dirpath("testid", semantic_spec=semantic_spec, learnware_dirpath=learnware_path)

    result, message = checker(learnware)
    # The literal 1 stands in for BaseChecker.USABLE_LEARNWARE: per the
    # original author's note, the learnware package is not updated here.
    if result != 1:
        raise RuntimeError(message)


if __name__ == "__main__":
    # Command-line entry point: parse the two string options and run the check.
    cli = argparse.ArgumentParser(description="Verify learnware script")
    for flag, help_text in (
        ("--learnware_path", "learnware path"),
        ("--checker_name", "checker name"),
    ):
        cli.add_argument(flag, type=str, help=help_text)

    options = cli.parse_args()
    verify_learnware(options.learnware_path, options.checker_name)
69 changes: 69 additions & 0 deletions backend/tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest
from scripts import main
import multiprocessing
import context
from context import config as C
import requests
import os
import shutil
from tests import common_test_operations as testops
import time
import restful.utils as utils
import hashlib


class TestDatasets(unittest.TestCase):
    """Tests for the dataset listing and download endpoints.

    The backend is started in a spawned process with ``datasets_path`` set to
    the ``tests`` folder, so ``tests/data/stat.json`` is the dataset under test.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Start the backend server and register the test user."""
        testops.cleanup_folder()
        super().setUpClass()
        testops.set_config("datasets_path", os.path.join("tests"))

        mp_context = multiprocessing.get_context("spawn")
        cls.server_process = mp_context.Process(target=main.main)
        cls.server_process.start()
        testops.wait_port_open(C.listen_port, 10)

        context.init_database()
        testops.clear_db()

        cls.username = "test"
        cls.email = "test@localhost"
        cls.password = "test"
        testops.url_request(
            "auth/register",
            {
                "username": cls.username,
                "password": cls.password,
                "email": cls.email,
                "confirm_email": False,
            },
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Stop the backend server and restore folders and configuration."""
        super().tearDownClass()
        cls.server_process.kill()
        testops.cleanup_folder()
        testops.reset_config()

    def test_list(self):
        """The dataset listing succeeds and contains ``data/stat.json``."""
        headers = testops.login(TestDatasets.email, TestDatasets.password)
        result = testops.url_request("datasets/list_datasets", headers=headers)

        self.assertEqual(result["code"], 0)
        self.assertIn("data/stat.json", result["data"]["datasets"])

    def test_download(
        self,
    ):
        """The downloaded dataset matches the file on disk byte for byte."""
        result = testops.url_request(
            "datasets/download_datasets", {"dataset": "data/stat.json"}, method="get", return_response=True
        )

        with open(os.path.join("tests", "data", "stat.json"), "rb") as fin:
            file_content = fin.read()

        self.assertTrue(testops.check_bytes_same(result.content, file_content))


if __name__ == "__main__":
unittest.main()
61 changes: 61 additions & 0 deletions backend/tests/test_kubernetes_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import context
import unittest
import multiprocessing
from scripts import main
from context import config as C
import tempfile
import zipfile
import common_test_operations as testops
from learnware.config import C as learnware_config
from learnware.client.utils import install_environment, system_execute
from lib import engine as engine_utils
from lib import kubernetes_utils
from scripts.monitor_learnware_verify import verify_learnware_with_conda_checker
import shortuuid


class TestKubernetesUtils(unittest.TestCase):
    """Integration test for ``kubernetes_utils.run_check``."""

    @classmethod
    def setUpClass(cls) -> None:
        """No class-level setup is required."""

    @classmethod
    def tearDownClass(cls) -> None:
        """No class-level teardown is required."""

    def test_verify_valid_learnware(self):
        """The bundled test learnware passes ``EasyStatChecker`` via ``run_check``."""
        learnware_path = os.path.join("tests", "data", "test_learnware.zip")
        semantic_spec = testops.test_learnware_semantic_specification_table()
        learnware_id = "testid"

        with tempfile.TemporaryDirectory() as learnware_folder:
            engine_utils.repack_learnware_folder(learnware_path, learnware_folder, learnware_id, semantic_spec)
            env_path = str(shortuuid.uuid())
            env_path = os.path.join(context.config["env_path"], env_path)
            try:
                install_environment(learnware_folder, None, conda_prefix=env_path)
                # Add torch with pip inside the freshly created environment.
                system_execute(
                    args=[
                        "conda",
                        "run",
                        "--prefix",
                        env_path,
                        "--no-capture-output",
                        "python",
                        "-m",
                        "pip",
                        "install",
                        "torch",
                    ]
                )

                result, message = kubernetes_utils.run_check(env_path, learnware_folder, "EasyStatChecker")
            finally:
                # Remove the conda env even when a step above raises, so that
                # failed runs do not accumulate environments on disk.
                os.system(f"conda env remove --prefix {env_path}")

            # Pass the returned message along so a failure shows the pod log.
            self.assertTrue(result, message)

    def test_verify_invalid_learnware(self):
        """Placeholder: no assertions are implemented yet."""
        pass


if __name__ == "__main__":
unittest.main()
19 changes: 0 additions & 19 deletions backend/tests/test_monitor_learnware_verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,25 +224,6 @@ def test_add_folder_learnware(self):
testops.delete_learnware(learnware_id, headers)
pass

def test_add_learnware_sensitive_words(
self,
):
headers = testops.login(TestMonitorLearnwareVerify.email, TestMonitorLearnwareVerify.password)
semantic_specification = testops.test_learnware_semantic_specification()
semantic_specification["Description"]["Values"] += "黄色小电影"
learnware_id = testops.add_test_learnware_unverified(
TestMonitorLearnwareVerify.email,
TestMonitorLearnwareVerify.password,
semantic_specification=semantic_specification,
)

status = self.wait_verify_end(learnware_id)

self.assertEqual(status, LearnwareVerifyStatus.FAIL.value)

testops.delete_learnware(learnware_id, headers)
pass


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 1452016

Please sign in to comment.