diff --git a/docs/en/user_guides/index.rst b/docs/en/user_guides/index.rst index 0a9582a4c7d..7986451893b 100644 --- a/docs/en/user_guides/index.rst +++ b/docs/en/user_guides/index.rst @@ -32,3 +32,4 @@ Useful Tools visualization.md robustness_benchmarking.md deploy.md + label_studio.md diff --git a/docs/en/user_guides/label_studio.md b/docs/en/user_guides/label_studio.md new file mode 100644 index 00000000000..07a1e84a2e2 --- /dev/null +++ b/docs/en/user_guides/label_studio.md @@ -0,0 +1,256 @@ +# Semi-automatic Object Detection Annotation with MMDetection and Label-Studio + +Annotation data is a time-consuming and laborious task. This article introduces how to perform semi-automatic annotation using the RTMDet algorithm in MMDetection in conjunction with Label-Studio software. Specifically, using RTMDet to predict image annotations and then refining the annotations with Label-Studio. Community users can refer to this process and methodology and apply it to other fields. + +- RTMDet: RTMDet is a high-precision single-stage object detection algorithm developed by OpenMMLab, open-sourced in the MMDetection object detection toolbox. Its open-source license is Apache 2.0, and it can be used freely without restrictions by industrial users. + +- [Label Studio](https://github.com/heartexlabs/label-studio) is an excellent annotation software covering the functionality of dataset annotation in areas such as image classification, object detection, and segmentation. + +In this article, we will use [cat](https://download.openmmlab.com/mmyolo/data/cat_dataset.zip) images for semi-automatic annotation. + +## Environment Configuration + +To begin with, you need to create a virtual environment and then install PyTorch and MMCV. In this article, we will specify the versions of PyTorch and MMCV. Next, you can install MMDetection, Label-Studio, and label-studio-ml-backend using the following steps: + +Create a virtual environment: + +```shell +conda create -n rtmdet python=3.9 -y +conda activate rtmdet +``` + +Install PyTorch: + +```shell +# Linux and Windows CPU only +pip install torch==1.10.1+cpu torchvision==0.11.2+cpu torchaudio==0.10.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html +# Linux and Windows CUDA 11.3 +pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html +# OSX +pip install torch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 +``` + +Install MMCV: + +```shell +pip install -U openmim +mim install "mmcv>=2.0.0rc0" +# Installing mmcv will automatically install mmengine +``` + +Install MMDetection: + +```shell +git clone https://github.com/open-mmlab/mmdetection -b dev-3.x +cd mmdetection +pip install -v -e . +``` + +Install Label-Studio and label-studio-ml-backend: + +```shell +# Installing Label-Studio may take some time, if the version is not found, please use the official source +pip install label-studio==1.7.2 +pip install label-studio-ml==1.0.9 +``` + +Download the rtmdet weights: + +```shell +cd path/to/mmetection +mkdir work_dirs +cd work_dirs +wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth +``` + +## Start the Service + +Start the RTMDet backend inference service: + +```shell +cd path/to/mmetection + +label-studio-ml start projects/LabelStudio/backend_template --with \ +config_file=configs/rtmdet/rtmdet_m_8xb32-300e_coco.py \ +checkpoint_file=./work_dirs/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth \ +device=cpu \ +--port 8003 +# Set device=cpu to use CPU inference, and replace cpu with cuda:0 to use GPU inference. +``` + +![](https://cdn.vansin.top/picgo20230330131601.png) + +The RTMDet backend inference service has now been started. To configure it in the Label-Studio web system, use http://localhost:8003 as the backend inference service. + +Now, start the Label-Studio web service: + +```shell +label-studio start +``` + +![](https://cdn.vansin.top/picgo20230330132913.png) + +Open your web browser and go to http://localhost:8080/ to see the Label-Studio interface. + +![](https://cdn.vansin.top/picgo20230330133118.png) + +Register a user and then create an RTMDet-Semiautomatic-Label project. + +![](https://cdn.vansin.top/picgo20230330133333.png) + +Download the example cat images by running the following command and import them using the Data Import button: + +```shell +cd path/to/mmetection +mkdir data && cd data + +wget https://download.openmmlab.com/mmyolo/data/cat_dataset.zip && unzip cat_dataset.zip +``` + +![](https://cdn.vansin.top/picgo20230330133628.png) + +![](https://cdn.vansin.top/picgo20230330133715.png) + +Then, select the Object Detection With Bounding Boxes template. + +![](https://cdn.vansin.top/picgo20230330133807.png) + +```shell +airplane +apple +backpack +banana +baseball_bat +baseball_glove +bear +bed +bench +bicycle +bird +boat +book +bottle +bowl +broccoli +bus +cake +car +carrot +cat +cell_phone +chair +clock +couch +cow +cup +dining_table +dog +donut +elephant +fire_hydrant +fork +frisbee +giraffe +hair_drier +handbag +horse +hot_dog +keyboard +kite +knife +laptop +microwave +motorcycle +mouse +orange +oven +parking_meter +person +pizza +potted_plant +refrigerator +remote +sandwich +scissors +sheep +sink +skateboard +skis +snowboard +spoon +sports_ball +stop_sign +suitcase +surfboard +teddy_bear +tennis_racket +tie +toaster +toilet +toothbrush +traffic_light +train +truck +tv +umbrella +vase +wine_glass +zebra +``` + +Then, copy and add the above categories to Label-Studio and click Save. + +![](https://cdn.vansin.top/picgo20230330134027.png) + +In the Settings, click Add Model to add the RTMDet backend inference service. + +![](https://cdn.vansin.top/picgo20230330134320.png) + +Click Validate and Save, and then click Start Labeling. + +![](https://cdn.vansin.top/picgo20230330134424.png) + +If you see Connected as shown below, the backend inference service has been successfully added. + +![](https://cdn.vansin.top/picgo20230330134554.png) + +## Start Semi-Automatic Labeling + +Click on Label to start labeling. + +![](https://cdn.vansin.top/picgo20230330134804.png) + +We can see that the RTMDet backend inference service has successfully returned the predicted results and displayed them on the image. However, we noticed that the predicted bounding boxes for the cats are a bit too large and not very accurate. + +![](https://cdn.vansin.top/picgo20230403104419.png) + +We manually adjust the position of the cat bounding box, and then click Submit to complete the annotation of this image. + +![](https://cdn.vansin.top/picgo/20230403105923.png) + +After submitting all images, click export to export the labeled dataset in COCO format. + +![](https://cdn.vansin.top/picgo20230330135921.png) + +Use VS Code to open the unzipped folder to see the labeled dataset, which includes the images and the annotation files in JSON format. + +![](https://cdn.vansin.top/picgo20230330140321.png) + +At this point, the semi-automatic labeling is complete. We can use this dataset to train a more accurate model in MMDetection and then continue semi-automatic labeling on newly collected images with this model. This way, we can iteratively expand the high-quality dataset and improve the accuracy of the model. + +## Use MMYOLO as the Backend Inference Service + +If you want to use Label-Studio in MMYOLO, you can refer to replacing the config_file and checkpoint_file with the configuration file and weight file of MMYOLO when starting the backend inference service. + +```shell +cd path/to/mmetection + +label-studio-ml start projects/LabelStudio/backend_template --with \ +config_file= path/to/mmyolo_config.py \ +checkpoint_file= path/to/mmyolo_weights.pth \ +device=cpu \ +--port 8003 +# device=cpu is for using CPU inference. If using GPU inference, replace cpu with cuda:0. +``` + +Rotation object detection and instance segmentation are still under development, please stay tuned. diff --git a/docs/zh_cn/user_guides/index.rst b/docs/zh_cn/user_guides/index.rst index 0c413db58f0..5abc50ad1cd 100644 --- a/docs/zh_cn/user_guides/index.rst +++ b/docs/zh_cn/user_guides/index.rst @@ -31,3 +31,4 @@ MMDetection 在 `Model Zoo =2.0.0rc0" +# 安装 mmcv 的过程中会自动安装 mmengine +``` + +安装 MMDetection + +```shell +git clone https://github.com/open-mmlab/mmdetection -b dev-3.x +cd mmdetection +pip install -v -e . +``` + +安装 Label-Studio 和 label-studio-ml-backend + +```shell +# 安装 label-studio 需要一段时间,如果找不到版本请使用官方源 +pip install label-studio==1.7.2 +pip install label-studio-ml==1.0.9 +``` + +下载rtmdet权重 + +```shell +cd path/to/mmetection +mkdir work_dirs +cd work_dirs +wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth +``` + +## 启动服务 + +启动 RTMDet 后端推理服务: + +```shell +cd path/to/mmetection + +label-studio-ml start projects/LabelStudio/backend_template --with \ +config_file=configs/rtmdet/rtmdet_m_8xb32-300e_coco.py \ +checkpoint_file=./work_dirs/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth \ +device=cpu \ +--port 8003 +# device=cpu 为使用 CPU 推理,如果使用 GPU 推理,将 cpu 替换为 cuda:0 +``` + +![](https://cdn.vansin.top/picgo20230330131601.png) + +此时,RTMDet 后端推理服务已经启动,后续在 Label-Studio Web 系统中配置 http://localhost:8003 后端推理服务即可。 + +现在启动 Label-Studio 网页服务: + +```shell +label-studio start +``` + +![](https://cdn.vansin.top/picgo20230330132913.png) + +打开浏览器访问 [http://localhost:8080/](http://localhost:8080/) 即可看到 Label-Studio 的界面。 + +![](https://cdn.vansin.top/picgo20230330133118.png) + +我们注册一个用户,然后创建一个 RTMDet-Semiautomatic-Label 项目。 + +![](https://cdn.vansin.top/picgo20230330133333.png) + +我们通过下面的方式下载好示例的喵喵图片,点击 Data Import 导入需要标注的猫图片。 + +```shell +cd path/to/mmetection +mkdir data && cd data + +wget https://download.openmmlab.com/mmyolo/data/cat_dataset.zip && unzip cat_dataset.zip +``` + +![](https://cdn.vansin.top/picgo20230330133628.png) + +![](https://cdn.vansin.top/picgo20230330133715.png) + +然后选择 Object Detection With Bounding Boxes 模板 + +![](https://cdn.vansin.top/picgo20230330133807.png) + +```shell +airplane +apple +backpack +banana +baseball_bat +baseball_glove +bear +bed +bench +bicycle +bird +boat +book +bottle +bowl +broccoli +bus +cake +car +carrot +cat +cell_phone +chair +clock +couch +cow +cup +dining_table +dog +donut +elephant +fire_hydrant +fork +frisbee +giraffe +hair_drier +handbag +horse +hot_dog +keyboard +kite +knife +laptop +microwave +motorcycle +mouse +orange +oven +parking_meter +person +pizza +potted_plant +refrigerator +remote +sandwich +scissors +sheep +sink +skateboard +skis +snowboard +spoon +sports_ball +stop_sign +suitcase +surfboard +teddy_bear +tennis_racket +tie +toaster +toilet +toothbrush +traffic_light +train +truck +tv +umbrella +vase +wine_glass +zebra +``` + +然后将上述类别复制添加到 Label-Studio,然后点击 Save。 + +![](https://cdn.vansin.top/picgo20230330134027.png) + +然后在设置中点击 Add Model 添加 RTMDet 后端推理服务。 + +![](https://cdn.vansin.top/picgo20230330134320.png) + +点击 Validate and Save,然后点击 Start Labeling。 + +![](https://cdn.vansin.top/picgo20230330134424.png) + +看到如下 Connected 就说明后端推理服务添加成功。 + +![](https://cdn.vansin.top/picgo20230330134554.png) + +## 开始半自动化标注 + +点击 Label 开始标注 + +![](https://cdn.vansin.top/picgo20230330134804.png) + +我们可以看到 RTMDet 后端推理服务已经成功返回了预测结果并显示在图片上,我们可以发现这个喵喵预测的框有点大。 + +![](https://cdn.vansin.top/picgo20230403104419.png) + +我们手工拖动框,修正一下框的位置,得到以下修正过后的标注,然后点击 Submit,本张图片就标注完毕了。 + +![](https://cdn.vansin.top/picgo/20230403105923.png) + +我们 submit 完毕所有图片后,点击 exprot 导出 COCO 格式的数据集,就能把标注好的数据集的压缩包导出来了。 + +![](https://cdn.vansin.top/picgo20230330135921.png) + +用 vscode 打开解压后的文件夹,可以看到标注好的数据集,包含了图片和 json 格式的标注文件。 + +![](https://cdn.vansin.top/picgo20230330140321.png) + +到此半自动化标注就完成了,我们可以用这个数据集在 MMDetection 训练精度更高的模型了,训练出更好的模型,然后再用这个模型继续半自动化标注新采集的图片,这样就可以不断迭代,扩充高质量数据集,提高模型的精度。 + +## 使用 MMYOLO 作为后端推理服务 + +如果想在 MMYOLO 中使用 Label-Studio,可以参考在启动后端推理服务时,将 config_file 和 checkpoint_file 替换为 MMYOLO 的配置文件和权重文件即可。 + +```shell +cd path/to/mmetection + +label-studio-ml start projects/LabelStudio/backend_template --with \ +config_file= path/to/mmyolo_config.py \ +checkpoint_file= path/to/mmyolo_weights.pth \ +device=cpu \ +--port 8003 +# device=cpu 为使用 CPU 推理,如果使用 GPU 推理,将 cpu 替换为 cuda:0 +``` + +旋转目标检测和实例分割还在支持中,敬请期待。 diff --git a/projects/LabelStudio/backend_template/_wsgi.py b/projects/LabelStudio/backend_template/_wsgi.py new file mode 100644 index 00000000000..1f8fb68cdf8 --- /dev/null +++ b/projects/LabelStudio/backend_template/_wsgi.py @@ -0,0 +1,145 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import logging +import logging.config +import os + +logging.config.dictConfig({ + 'version': 1, + 'formatters': { + 'standard': { + 'format': + '[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s' # noqa E501 + } + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'DEBUG', + 'stream': 'ext://sys.stdout', + 'formatter': 'standard' + } + }, + 'root': { + 'level': 'ERROR', + 'handlers': ['console'], + 'propagate': True + } +}) + +_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json') + + +def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH): + if not os.path.exists(config_path): + return dict() + with open(config_path) as f: + config = json.load(f) + assert isinstance(config, dict) + return config + + +if __name__ == '__main__': + + from label_studio_ml.api import init_app + + from projects.LabelStudio.backend_template.mmdetection import MMDetection + + parser = argparse.ArgumentParser(description='Label studio') + parser.add_argument( + '-p', + '--port', + dest='port', + type=int, + default=9090, + help='Server port') + parser.add_argument( + '--host', dest='host', type=str, default='0.0.0.0', help='Server host') + parser.add_argument( + '--kwargs', + '--with', + dest='kwargs', + metavar='KEY=VAL', + nargs='+', + type=lambda kv: kv.split('='), + help='Additional LabelStudioMLBase model initialization kwargs') + parser.add_argument( + '-d', + '--debug', + dest='debug', + action='store_true', + help='Switch debug mode') + parser.add_argument( + '--log-level', + dest='log_level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + default=None, + help='Logging level') + parser.add_argument( + '--model-dir', + dest='model_dir', + default=os.path.dirname(__file__), + help='Directory models are store', + ) + parser.add_argument( + '--check', + dest='check', + action='store_true', + help='Validate model instance before launching server') + + args = parser.parse_args() + + # setup logging level + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value): + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs(): + param = dict() + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == 'True' or v == 'true': + param[k] = True + elif v == 'False' or v == 'False': + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + MMDetection.__name__ + '" instance creation..') + model = MMDetection(**kwargs) + + app = init_app( + model_class=MMDetection, + model_dir=os.environ.get('MODEL_DIR', args.model_dir), + redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'), + redis_host=os.environ.get('REDIS_HOST', 'localhost'), + redis_port=os.environ.get('REDIS_PORT', 6379), + **kwargs) + + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # for uWSGI use + app = init_app( + model_class=MMDetection, + model_dir=os.environ.get('MODEL_DIR', os.path.dirname(__file__)), + redis_queue=os.environ.get('RQ_QUEUE_NAME', 'default'), + redis_host=os.environ.get('REDIS_HOST', 'localhost'), + redis_port=os.environ.get('REDIS_PORT', 6379)) diff --git a/projects/LabelStudio/backend_template/mmdetection.py b/projects/LabelStudio/backend_template/mmdetection.py new file mode 100644 index 00000000000..f25e80e8fc9 --- /dev/null +++ b/projects/LabelStudio/backend_template/mmdetection.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import io +import json +import logging +import os +from urllib.parse import urlparse + +import boto3 +from botocore.exceptions import ClientError +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size, + get_single_tag_keys) +from label_studio_tools.core.utils.io import get_data_dir + +from mmdet.apis import inference_detector, init_detector + +logger = logging.getLogger(__name__) + + +class MMDetection(LabelStudioMLBase): + """Object detector based on https://github.com/open-mmlab/mmdetection.""" + + def __init__(self, + config_file=None, + checkpoint_file=None, + image_dir=None, + labels_file=None, + score_threshold=0.5, + device='cpu', + **kwargs): + + super(MMDetection, self).__init__(**kwargs) + config_file = config_file or os.environ['config_file'] + checkpoint_file = checkpoint_file or os.environ['checkpoint_file'] + self.config_file = config_file + self.checkpoint_file = checkpoint_file + self.labels_file = labels_file + # default Label Studio image upload folder + upload_dir = os.path.join(get_data_dir(), 'media', 'upload') + self.image_dir = image_dir or upload_dir + logger.debug( + f'{self.__class__.__name__} reads images from {self.image_dir}') + if self.labels_file and os.path.exists(self.labels_file): + self.label_map = json_load(self.labels_file) + else: + self.label_map = {} + + self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys( # noqa E501 + self.parsed_label_config, 'RectangleLabels', 'Image') + schema = list(self.parsed_label_config.values())[0] + self.labels_in_config = set(self.labels_in_config) + + # Collect label maps from `predicted_values="airplane,car"` attribute in