Skip to content

Commit

Permalink
refactor: auto-labeling
Browse files Browse the repository at this point in the history
  • Loading branch information
QIN2DIM committed Oct 20, 2023
1 parent 1c8ea44 commit 61e6b6a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion automation/assets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def merge(self, fd: Path, td: Path):


def run():
- sources = "hat"
+ sources = "https://github.com/QIN2DIM/hcaptcha-challenger/issues/851"
am = AssetsManager.from_sources(sources)
am.execute()

Expand Down
23 changes: 14 additions & 9 deletions automation/auto_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from pathlib import Path
from typing import List, Tuple

- import torch
from PIL import Image
from hcaptcha_challenger import split_prompt_message, prompt2task, label_cleaning
from tqdm import tqdm
- from transformers import pipeline


@dataclass
Expand All @@ -26,11 +24,16 @@ class AutoLabeling:
pending_tasks: List[Path] = field(default_factory=list)

checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
- device = "cuda" if torch.cuda.is_available() else "cpu"
- task = "zero-shot-image-classification"

def load_zero_shot_model(self):
- detector = pipeline(task=self.task, model=self.checkpoint, device=self.device, batch_size=8)
+ import torch
+ from transformers import pipeline
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ task = "zero-shot-image-classification"
+
+ detector = pipeline(task=task, model=self.checkpoint, device=device, batch_size=8)
+
return detector

@classmethod
Expand Down Expand Up @@ -120,11 +123,13 @@ def run(prompt: str, negative_labels: List[str], **kwargs):
al = AutoLabeling.from_prompt(positive_label, candidate_labels, images_dir)
output_dir = al.execute(limit=kwargs.get("limit"))

- if "win32" in sys.platform:
+ if "win32" in sys.platform and output_dir:
os.startfile(output_dir)


if __name__ == "__main__":
- run(
- prompt="vr_headset", negative_labels=["phone", "keyboard", "drone", "3d printer"], limit=500
- )
+ # prompt to negative labels
+ prompt2neg = {"motorized machine": ["plant", "mountain", "natural landscape"]}
+
+ for p, nl in prompt2neg.items():
+ run(p, nl, limit=500)
2 changes: 1 addition & 1 deletion automation/continue_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ def run(prompt: str, model_name: str | None = None):


if __name__ == "__main__":
- run("hat")
+ run("pair_of_roller_skates")
5 changes: 3 additions & 2 deletions automation/mini_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ def upgrade_objects(aid):
# "smartwatch": "smartwatch2309",
# "hat": "hat2310",
# "vineyard": "vineyard2309",
- "nested_electronic_device": "nested_electronic_device2309",
+ # "pair_of_roller_skates": "pair_of_roller_skates2310",
+ "motorized_machine": "motorized_machine2309"
}
# fmt:on

quick_train()
aid = quick_development()
- # upgrade_objects(aid)
+ upgrade_objects(aid)
print(aid)

0 comments on commit 61e6b6a

Please sign in to comment.