# Source code for homura.register

import contextlib
import functools
import types
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar

# Generic placeholder for whatever kind of object a registry stores.
T = TypeVar("T")


class Registry(object):
    """Registry of models, datasets and anything you like.

    Registries are shared per name: constructing ``Registry`` with a name that
    is already in use returns the existing instance, with all of its
    registered entries intact.

    ::

        model_registry = Registry('model')

        @model_registry.register
        def your_model(*args, **kwargs):
            return ...

        your_model_instance = model_registry('your_model')(...)
        model_registry2 = Registry('model')
        model_registry is model_registry2

    :param name: name of registry. If name is already used, return that registry.
    :param type: type of registees. If type is not None, registees are type
        checked when registered. Ignored when an existing registry is returned.
    """

    # Maps registry name -> Registry instance, shared by the whole process.
    _available_registries = {}

    def __new__(cls,
                name: str,
                type: Optional[Type[T]] = None
                ):
        """Return the registry already created under ``name``, or a new one."""
        existing = Registry._available_registries.get(name)
        if existing is not None:
            return existing
        return object.__new__(cls)

    def __init__(self,
                 name: str,
                 type: Optional[Type[T]] = None):
        """Initialise a fresh registry; leave an existing one untouched."""
        # Python calls __init__ even when __new__ handed back an existing
        # instance. Re-initialising here would silently drop every entry that
        # was registered so far, so bail out early in that case.
        if Registry._available_registries.get(name) is self:
            return
        self.name = name
        self.type = type
        self._registry = {}
        Registry._available_registries[name] = self

    def register_from_dict(self,
                           name_to_func: Dict[str, T]):
        """Register every value of ``name_to_func`` under its key.

        :param name_to_func: mapping from registration name to object.
        :raises KeyError: if one of the names is already registered.
        :raises TypeError: if an object fails the registry's type check.
        """
        for key, value in name_to_func.items():
            self.register(value, name=key)

    def register(self,
                 func: T = None,
                 *,
                 name: Optional[str] = None
                 ) -> T:
        """Add ``func`` to the registry; usable as a decorator.

        Called without ``func`` (e.g. ``@registry.register(name='x')``) it
        returns a decorator that registers its argument under ``name``.

        :param func: object to register.
        :param name: registration name. Defaults to ``func.__name__``.
        :return: ``func`` itself, unchanged.
        :raises TypeError: if the registry has a type and ``func`` is neither
            a plain function, an instance of that type, nor a subclass of it.
        :raises KeyError: if ``name`` is already registered.
        """
        if func is None:
            return functools.partial(self.register, name=name)

        expected = self.type
        # Plain functions are exempt from the type check.
        if expected is not None and not isinstance(func, types.FunctionType):
            # ``type`` is the builtin here. issubclass() raises its own
            # TypeError for non-class arguments, so only ask it about classes.
            is_subclass = isinstance(func, type) and issubclass(func, expected)
            if not (isinstance(func, expected) or is_subclass):
                raise TypeError(f'`func` is expected to be subclass of {expected}.')

        if name is None:
            name = func.__name__

        if name in self._registry:
            raise KeyError(f'Name {name} is already used, try another name!')

        self._registry[name] = func
        return func

    def __call__(self,
                 name: str,
                 *args,
                 **kwargs):
        """Look up ``name``, optionally calling the entry with the given arguments.

        An exact match is tried first, then a case-insensitive one.

        :param name: registration name to look up.
        :return: the registered object, or the result of calling it with
            ``*args, **kwargs`` when any arguments are supplied.
        :raises KeyError: if no entry matches ``name``.
        """
        entry = self._registry.get(name)
        if entry is None:
            lowered = {key.lower(): value for key, value in self._registry.items()}
            entry = lowered.get(name.lower())
        if entry is None:
            raise KeyError(f'Unknown {name} is called!')

        if args or kwargs:
            # if args and kwargs is specified, instantiate
            entry = entry(*args, **kwargs)

        return entry

    @classmethod
    def available_registries(cls,
                             detailed: bool = False):
        """Print the known registries.

        :param detailed: if True, print each registry together with its
            registered names; otherwise print only the registry names.
        """
        if detailed:
            for registry in cls._available_registries.values():
                registry.help()
        else:
            print(list(cls._available_registries.keys()))

    def help(self):
        """Print this registry's name followed by its registered names."""
        print(f"{self.name} {list(self._registry.keys())}")

    def choices(self):
        """Return the registered names as a tuple, in registration order."""
        return tuple(self._registry.keys())

    @staticmethod
    def import_modules(package_name: str) -> None:
        """Import ``package_name`` and, recursively, the modules found under it.

        The current working directory is put at the front of ``sys.path`` for
        the duration of the import.

        :param package_name: dotted name of the package or module to import.
        """
        import importlib
        import pkgutil

        importlib.invalidate_caches()
        with Registry._push_python_path("."):
            module = importlib.import_module(package_name)
            # Plain modules have no ``__path__``; then there is nothing to walk.
            path = getattr(module, '__path__', [])
            path_string = "" if not path else path[0]

            for module_finder, name, _ in pkgutil.walk_packages(path):
                # Skip entries whose finder does not point at this package's
                # first path entry.
                if path_string and module_finder.path != path_string:
                    continue
                sub_package = f'{package_name}.{name}'
                Registry.import_modules(sub_package)

    @staticmethod
    @contextlib.contextmanager
    def _push_python_path(path: str):
        """Context manager that temporarily prepends ``path`` to ``sys.path``.

        :param path: directory to add; it is resolved to an absolute path.
        """
        # https://github.com/allenai/allennlp/blob/v1.0.0/allennlp/common/util.py
        import sys

        resolved = str(Path(path).resolve())
        sys.path.insert(0, resolved)
        try:
            yield
        finally:
            sys.path.remove(resolved)