-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
79 lines (67 loc) · 2.94 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Classify images in a directory tree as NSFW or SFW and copy the NSFW ones out.

The script walks the input directory recursively. It classifies every PNG, JPEG, BMP and GIF file with a pre-trained image-classification model. Each image labelled NSFW is copied into the ``nsfw`` subdirectory of the output directory. If an error occurs while processing an image, that image is copied to the error directory instead.
Model: Falconsai/nsfw_image_detection
Usage: python main.py -I <input_directory> [-O <output_directory>] [-E <error_directory>]
Arguments:
-I, --input: Input directory containing the images to process (required; searched recursively).
-O, --output: Output directory; NSFW images are copied to <output_directory>/nsfw (default: output).
-E, --error: Directory that receives images that could not be processed (default: <output_directory>/error).
"""
import os
import shutil
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import argparse
# Get directories from command line arguments. This happens before the model
# is loaded so that --help or an invalid invocation fails immediately instead
# of first loading (and possibly downloading) the model weights.
parser = argparse.ArgumentParser()
parser.add_argument('-I', '--input', help='Input directory', required=True)
parser.add_argument('-O', '--output', help='Output directory')
parser.add_argument('-E', '--error', help='Error directory')
args = parser.parse_args()
# os.walk() silently yields nothing for a missing path, which would make the
# script "succeed" on zero images; reject a bad input directory up front.
if not os.path.isdir(args.input):
    parser.error(f"input directory does not exist or is not a directory: {args.input}")

# Set default values if directories are not specified.
output_directory = args.output if args.output else 'output'
nsfw_directory = os.path.join(output_directory, 'nsfw')
error_directory = args.error if args.error else os.path.join(output_directory, 'error')

# Create directories if they don't exist.
os.makedirs(output_directory, exist_ok=True)
os.makedirs(error_directory, exist_ok=True)
os.makedirs(nsfw_directory, exist_ok=True)

# Initialize model, preprocessor and device. A single constant keeps the model
# and its image processor from drifting apart.
MODEL_ID = 'Falconsai/nsfw_image_detection'
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model.eval()
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
model.to(device)
# Function to process each image and copy if necessary
def process_image(image_path):
    """Classify one image and copy it to the NSFW or error directory as needed.

    Uses the module-level ``model``, ``processor`` and ``device`` for
    inference and the module-level ``nsfw_directory`` / ``error_directory``
    as copy destinations.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A ``(image_path, label)`` tuple. ``label`` is the predicted label from
        ``model.config.id2label``, or ``'error'`` if the image could not be
        opened or classified (in which case it was copied to
        ``error_directory``).
    """
    try:
        # Image.open keeps the file open until the image is closed; the
        # context manager releases the handle promptly instead of leaking one
        # per image until garbage collection.
        with Image.open(image_path) as img:
            # The ViT processor expects 3-channel input; RGBA PNGs, palette
            # GIFs and grayscale images would otherwise fail preprocessing.
            inputs = processor(images=img.convert("RGB"), return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_label = logits.argmax(-1).item()
        label = model.config.id2label[predicted_label]
        if label.lower() == 'nsfw':
            shutil.copy(image_path, nsfw_directory)
        return image_path, label
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still stop the run instead of being recorded as a bad image.
        shutil.copy(image_path, error_directory)
        return image_path, 'error'
def process_images(image_paths):
    """Run ``process_image`` over every path on a thread pool, with a progress bar.

    Args:
        image_paths: Sized sequence of image paths; ``len()`` sets the
            progress-bar total.

    Returns:
        List of the ``(image_path, label)`` tuples returned by
        ``process_image``, in the same order as ``image_paths``.
    """
    with ThreadPoolExecutor(max_workers=None) as pool:
        # map() submits every path up front and yields results in input order.
        pending = pool.map(process_image, image_paths)
        collected = list(tqdm(pending, total=len(image_paths)))
    return collected
# Collect all image paths under the input directory. The NSFW and error
# directories are pruned from the walk: if they live inside the input tree
# (e.g. ``-I .`` with the default ``output``), a rerun would otherwise pick up
# previously copied images and try to copy them onto themselves.
_result_dirs = {os.path.realpath(d) for d in (nsfw_directory, error_directory)}
image_files = []
for root, dirs, files in os.walk(args.input):
    # Modifying ``dirs`` in place stops os.walk from descending into them.
    dirs[:] = [
        d for d in dirs
        if os.path.realpath(os.path.join(root, d)) not in _result_dirs
    ]
    image_files.extend(
        os.path.join(root, file)
        for file in files
        # Match the full extension, dot included, so names such as "notapng"
        # are not mistaken for images.
        if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
    )
results = process_images(image_files)