Skip to content

Commit

Permalink
feat(auo-labeling): laion/CLIP-ViT-H-14-laion2B-s32B-b79K (#67)
Browse files Browse the repository at this point in the history
* feat(auo-labeling): laion/CLIP-ViT-H-14-laion2B-s32B-b79K

Co-Authored-By: Bingjie YAN <[email protected]>

* Update auto_labeling.py

Co-Authored-By: Bingjie YAN <[email protected]>

* format

Co-Authored-By: Bingjie YAN <[email protected]>

---------

Co-authored-by: Bingjie YAN <[email protected]>
  • Loading branch information
QIN2DIM and beiyuouo committed Oct 20, 2023
1 parent d01c112 commit 1c8ea44
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 5 deletions.
130 changes: 130 additions & 0 deletions automation/auto_labeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
# Time : 2023/10/20 17:28
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description: zero-shot image classification
import os
import shutil
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Tuple

import torch
from PIL import Image
from hcaptcha_challenger import split_prompt_message, prompt2task, label_cleaning
from tqdm import tqdm
from transformers import pipeline


@dataclass
class AutoLabeling:
positive_label: str = field(default=str)
candidate_labels: List[str] = field(default_factory=list)
images_dir: Path = field(default=Path)
pending_tasks: List[Path] = field(default_factory=list)

checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
device = "cuda" if torch.cuda.is_available() else "cpu"
task = "zero-shot-image-classification"

def load_zero_shot_model(self):
detector = pipeline(task=self.task, model=self.checkpoint, device=self.device, batch_size=8)
return detector

@classmethod
def from_prompt(cls, positive_label: str, candidate_labels: List[str], images_dir: Path):
images_dir.mkdir(parents=True, exist_ok=True)

pending_tasks: List[Path] = []
for image_name in os.listdir(images_dir):
image_path = images_dir.joinpath(image_name)
if image_path.is_file():
pending_tasks.append(image_path)

return cls(
positive_label=positive_label,
candidate_labels=candidate_labels,
images_dir=images_dir,
pending_tasks=pending_tasks,
)

def valid(self):
if not self.pending_tasks:
print("No pending tasks")
return
if len(self.candidate_labels) <= 2:
print(f">> Please enter at least three class names - {self.candidate_labels=}")
return

return True

def mkdir(self) -> Tuple[Path, Path]:
__formats = ("%Y-%m-%d %H:%M:%S.%f", "%Y%m%d%H%M")
now = datetime.strptime(str(datetime.now()), __formats[0]).strftime(__formats[1])
yes_dir = self.images_dir.joinpath(now, "yes")
bad_dir = self.images_dir.joinpath(now, "bad")
yes_dir.mkdir(parents=True, exist_ok=True)
bad_dir.mkdir(parents=True, exist_ok=True)

return yes_dir, bad_dir

def execute(self, limit: int = None):
if not self.valid():
return

# Format datafolder
yes_dir, bad_dir = self.mkdir()

# Load zero-shot model
detector = self.load_zero_shot_model()

total = len(self.pending_tasks)
desc_in = f'"{self.checkpoint}/{self.images_dir.name}"'
limit = limit or total

with tqdm(total=total, desc=f"Labeling | {desc_in}") as progress:
for image_path in self.pending_tasks[:limit]:
image = Image.open(image_path)

# Binary Image classification
predictions = detector(image, candidate_labels=self.candidate_labels)

# Move positive cases to yes/
# Move negative cases to bad/
if predictions[0]["label"] == self.positive_label:
output_path = yes_dir.joinpath(image_path.name)
else:
output_path = bad_dir.joinpath(image_path.name)
shutil.move(image_path, output_path)

progress.update(1)

return yes_dir.parent


def run(prompt: str, negative_labels: List[str], **kwargs):
prompt = prompt.replace("_", " ")

task_name = prompt2task(prompt)

project_dir = Path(__file__).parent.parent
images_dir = project_dir.joinpath("database2309", task_name)

positive_label = split_prompt_message(label_cleaning(prompt), "en")
candidate_labels = [positive_label]
if isinstance(negative_labels, list) and len(negative_labels) != 0:
candidate_labels.extend(negative_labels)

al = AutoLabeling.from_prompt(positive_label, candidate_labels, images_dir)
output_dir = al.execute(limit=kwargs.get("limit"))

if "win32" in sys.platform:
os.startfile(output_dir)


if __name__ == "__main__":
run(
prompt="vr_headset", negative_labels=["phone", "keyboard", "drone", "3d printer"], limit=500
)
4 changes: 3 additions & 1 deletion automation/check_yolo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

def run():
model_name = "appears_only_once_2309_yolov8s-seg.onnx"
images_dir = "tmp_dir/image_label_area_select/please click on the object that appears only once/default"
images_dir = (
"tmp_dir/image_label_area_select/please click on the object that appears only once/default"
)

this_dir = Path(__file__).parent
output_dir = this_dir.joinpath("yolo_mocker")
Expand Down
2 changes: 1 addition & 1 deletion automation/continue_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Time : 2023/9/24 15:07
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description:
# Description: Continue labeling images using the exported ONNX ResNet model
import os
import shutil
import sys
Expand Down
2 changes: 0 additions & 2 deletions automation/requirements.txt

This file was deleted.

5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ torch>=2.0.1+cu118
scikit-learn
pillow
onnx
numpy~=1.24.1
numpy~=1.26.0
opencv-python~=4.8.0.76
pyyaml
loguru
fire
transformers
hcaptcha-challenger[sentinel]
bs4

0 comments on commit 1c8ea44

Please sign in to comment.