Skip to content

Commit

Permalink
[MNT] Modify details
Browse files Browse the repository at this point in the history
  • Loading branch information
bxdd committed Apr 3, 2023
1 parent a793913 commit eb56e9d
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 201 deletions.
Binary file modified classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/img/classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions examples/examples1/example_rkme.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
spec2 = specification.rkme.RKMESpecification()
spec1.generate_stat_spec_from_data(data_X)
spec1.save("spec.json")

beta = spec1.get_beta()
z = spec1.get_z()
print(type(beta), beta.shape)
print(type(z), z.shape)

spec2.load("spec.json")
beta = spec1.get_beta()
z = spec1.get_z()
print(type(beta), beta.shape)
print(type(z), z.shape)

print(spec1.inner_prod(spec2))
print(spec1.dist(spec2))
print(spec1.dist(spec2))
23 changes: 7 additions & 16 deletions examples/examples2/example_learnware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,26 @@ def prepare_learnware():
clf = svm.SVC()
clf.fit(data_X, data_y)
joblib.dump(clf, "./svm/svm.pkl")
spec= specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)

spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
spec.save("./svm/spec.json")


def test_API():
text_X = np.random.randn(100, 20)
text_X = np.random.randn(100, 20)
svm = SVM()
pred_y1 = svm.predict(text_X)
print(type(svm))

model = {
"module_path": "./svm/__init__.py",
"class_name": "SVM"
}

model = {"module_path": "./svm/__init__.py", "class_name": "SVM"}
spec = specification.rkme.RKMESpecification()
spec.load("./svm/spec.json")
learnware = Learnware(
id="A0",
name="SVM",
model=model,
specification=spec,
desc="svm"
)
learnware = Learnware(id="A0", name="SVM", model=model, specification=spec, desc="svm")
pred_y2 = learnware.predict(text_X)
print(type(learnware.model))
print(f"diff: {np.sum(pred_y1 != pred_y2)}")


if __name__ == "__main__":
prepare_learnware()
test_API()
test_API()
4 changes: 2 additions & 2 deletions examples/examples2/svm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class SVM:
def __init__(self):
dir_path = os.path.dirname(os.path.abspath(__file__))
self.model = joblib.load(os.path.join(dir_path, 'svm.pkl'))
self.model = joblib.load(os.path.join(dir_path, "svm.pkl"))

def fit(self, X: np.ndarray, y: np.ndarray):
pass
Expand All @@ -15,4 +15,4 @@ def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)

def fintune(self, X: np.ndarray, y: np.ndarray):
pass
pass
6 changes: 3 additions & 3 deletions learnware/learnware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _import_model(self, model: Union[BaseModel, dict]) -> BaseModel:

def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)

def get_model(self) -> BaseModel:
return self.model

Expand All @@ -57,10 +57,10 @@ def get_specification(self) -> Specification:

def get_info(self):
return self.desc

def update_stat_spec(self, name, new_stat_spec: BaseStatSpecification):
self.specification.update_stat_spec(name, new_stat_spec)

def update(self):
# Empty Interface.
raise NotImplementedError("'update' Method is NOT Implemented.")
2 changes: 1 addition & 1 deletion learnware/market/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .base import BaseUserInfo, BaseMarket
from .anchor import AnchoredUserInfo, AnchoredMarket
from .evolve import EvolvedMarket
from .evolve import EvolvedMarket
18 changes: 9 additions & 9 deletions learnware/market/anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

class AnchoredUserInfo(BaseUserInfo):
"""
User Information for searching learnware (add the anchor design)
- UserInfo contains the anchor list acquired from the market
- UserInfo can update stat_info based on anchors
User Information for searching learnware (add the anchor design)
- UserInfo contains the anchor list acquired from the market
- UserInfo can update stat_info based on anchors
"""

def __init__(self, id: str, property: dict = dict(), stat_info: dict = dict()):
super(AnchoredUserInfo, self).__init__(id, property, stat_info)
def __init__(self, id: str, semantic_spec: dict = dict(), stat_info: dict = dict()):
super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
self.anchor_learnware_list = {} # id: Learnware

def add_anchor_learnware(self, learnware_id: str, learnware: Learnware):
Expand Down Expand Up @@ -79,7 +79,7 @@ def _delete_anchor_learnware(self, anchor_id: str) -> bool:
-------
bool
True if the target anchor learnware is deleted successfully.
Raises
------
Exception
Expand Down Expand Up @@ -107,7 +107,7 @@ def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, Lis
Parameters
----------
user_info : AnchoredUserInfo
- user_info with properties and statistical information
- user_info with semantic specifications and statistical information
- some statistical information calculated on previous anchor learnwares
Returns
Expand All @@ -126,7 +126,7 @@ def search_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learn
Parameters
----------
user_info : AnchoredUserInfo
- user_info with properties and statistical information
- user_info with semantic specifications and statistical information
- some statistical information calculated on anchor learnwares
Returns
Expand Down
Loading

0 comments on commit eb56e9d

Please sign in to comment.