Skip to content

Commit

Permalink
[ENH] Modify UserInfo Class
Browse files Browse the repository at this point in the history
  • Loading branch information
bxdd committed Apr 4, 2023
1 parent 3fe2e2a commit 1e5dc48
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 37 deletions.
Binary file modified docs/_static/img/classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 7 additions & 1 deletion docs/references/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Market
.. autoclass:: learnware.market.BaseMarket
:members:

.. autoclass:: learnware.market.SerialMarket
:members:

.. autoclass:: learnware.market.AnchoredMarket
:members:

Expand All @@ -23,6 +26,9 @@ Market
.. autoclass:: learnware.market.BaseUserInfo
:members:

.. autoclass:: learnware.market.SerialUserInfo
:members:

Learnware
====================

Expand All @@ -41,5 +47,5 @@ Specification
.. autoclass:: learnware.specification.BaseStatSpecification
:members:

.. autoclass:: learnware.specification.RKMESpecification
.. autoclass:: learnware.specification.RKMEStatSpecification
:members:
4 changes: 2 additions & 2 deletions learnware/learnware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


class Learnware:
def __init__(self, id: str, name: str, model: Union[BaseModel, dict], specification: Specification, desc: str):
def __init__(self, id: str, name: str, model: BaseModel, specification: Specification, desc: str):
self.id = id
self.name = name
self.model = self._import_model(model)
self.model = model
self.specification = specification
self.desc = desc

Expand Down
1 change: 1 addition & 0 deletions learnware/market/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base import BaseUserInfo, BaseMarket
from .anchor import AnchoredUserInfo, AnchoredMarket
from .evolve import EvolvedMarket
from .serial import SerialMarket, SerialUserInfo
5 changes: 3 additions & 2 deletions learnware/market/anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Tuple, Any, List, Union, Dict

from ..learnware import Learnware
from .base import BaseUserInfo, BaseMarket
from .base import BaseMarket
from .serial import SerialUserInfo


class AnchoredUserInfo(BaseUserInfo):
class AnchoredUserInfo(SerialUserInfo):
"""
User Information for searching learnware (add the anchor design)
Expand Down
27 changes: 2 additions & 25 deletions learnware/market/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,11 @@
class BaseUserInfo:
"""User Information for searching learnware"""

def __init__(self, id: str, semantic_spec: dict = dict(), stat_info: dict = dict()):
"""Initializing user information
Parameters
----------
id : str
user id
semantic_spec : dict, optional
semantic_spec selected by user, by default dict()
stat_info : dict, optional
statistical information uploaded by user, by default dict()
"""
self.id = id
self.semantic_spec = semantic_spec
self.stat_info = stat_info

def get_semantic_spec(self) -> dict:
"""Return user semantic specifications
Returns
-------
dict
user semantic specifications
"""
return self.semantic_spec
raise NotImplementedError("get_semantic_spec is Not Implemented")

def get_stat_info(self, name: str):
return self.stat_info.get(name, None)
raise NotImplementedError("get_stat_info is Not Implemented")


class BaseMarket:
Expand Down
29 changes: 29 additions & 0 deletions learnware/market/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,35 @@
from ..learnware import Learnware
from ..specification import RKMEStatSpecification

class SerialUserInfo(BaseUserInfo):
def __init__(self, id: str, semantic_spec: dict = dict(), stat_info: dict = dict()):
"""Initializing user information
Parameters
----------
id : str
user id
semantic_spec : dict, optional
semantic_spec selected by user, by default dict()
stat_info : dict, optional
statistical information uploaded by user, by default dict()
"""
self.id = id
self.semantic_spec = semantic_spec
self.stat_info = stat_info

def get_semantic_spec(self) -> dict:
"""Return user semantic specifications
Returns
-------
dict
user semantic specifications
"""
return self.semantic_spec

def get_stat_info(self, name: str):
return self.stat_info.get(name, None)

class SerialMarket(BaseMarket):
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion learnware/specification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def load(self, filepath: str):
class Specification:
def __init__(self, semantic_spec=None):
self.semantic_spec = semantic_spec
self.stat_spec = {} # stat_spec should be dict
self.stat_spec = BaseStatSpecification() # stat_spec should be dict

def get_stat_spec(self):
return self.stat_spec
Expand Down
12 changes: 6 additions & 6 deletions learnware/specification/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from .base import BaseStatSpecification
from .rkme import RKMESpecification
from .rkme import RKMEStatSpecification


def generate_rkme_spec(
Expand All @@ -13,10 +13,10 @@ def generate_rkme_spec(
nonnegative_beta: bool = True,
reduce: bool = True,
cuda_idx: int = -1,
) -> RKMESpecification:
) -> RKMEStatSpecification:
"""
Interface for users to generate Reduced-set Kernel Mean Embedding (RKME) specification.
Return a RKMESpecification object, use .save() method to save as json file.
Return a RKMEStatSpecification object, use .save() method to save as json file.
Parameters
Expand All @@ -41,10 +41,10 @@ def generate_rkme_spec(
Returns
-------
RKMESpecification
A RKMESpecification object
RKMEStatSpecification
A RKMEStatSpecification object
"""
rkme_spec = RKMESpecification(gamma=gamma, cuda_idx=cuda_idx)
rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx)
rkme_spec.generate_stat_spec_from_data(X, K, step_size, steps, nonnegative_beta, reduce)
return rkme_spec

Expand Down

0 comments on commit 1e5dc48

Please sign in to comment.