Skip to content

Commit

Permalink
needs more testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Mirtia committed Jul 3, 2023
1 parent 5fdc3e9 commit 6f336fb
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 9 deletions.
1 change: 1 addition & 0 deletions crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# The Crawler class checks if a directory and file exist, and creates them if they don't.
# It's the base class for the alternative crawlers.
# Eventually, I will put more functionalities in this class, this is just not a class.
class Crawler:

def __init__(self, output_dir, input_file, prefix):
Expand Down
2 changes: 1 addition & 1 deletion google_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from crawler import Crawler

# Unused
# The GoogleCrawler class is a subclass of Crawler that uses the Google Images Search API to crawl and
# download images based on input queries and search parameters. There is a limitation on the number of queries
# so I ended up crawling Wikimedia Commons.

class GoogleCrawler(Crawler):

def __init__(self, output_dir, input_file, prefix, parameters_file):
Expand Down
11 changes: 6 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse

from google_crawler import GoogleCrawler
from image_upscaler import ImageUpscaler
from wikimedia_crawler import WikimediaCrawler
from model import StyleModel
Expand All @@ -11,12 +10,10 @@ def main():
parser.add_argument("-u", "--upscale", action="store_true", help="Enable upscaling")
parser.add_argument("-c", "--crawl", action="store_true", help="Enable crawling")
parser.add_argument("-t", "--train", action="store_true", help="Train model")
parser.add_argument("-i", "--input", dest="input", help="Specify input image to test model")
parser.add_argument("-f", "--file", dest="file", help="Specify input file", required=True)
parser.add_argument("-o", "--output", dest="output", help="Specify output file", required=True)
args = parser.parse_args()
# Example parameters
# --file input/wikimedia_input
# --output output

if args.crawl:
crawler = WikimediaCrawler(args.file, args.output, "wikimedia")
Expand All @@ -28,7 +25,11 @@ def main():

if args.train:
model = StyleModel(args.file, args.output)
model.train_model()
model.train()

if args.input:
model = StyleModel(args.file, args.output)
model.classify(args.input)

if __name__ == "__main__":
main()
62 changes: 59 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import shutil
from imageai.Classification.Custom import ClassificationModelTrainer
from imageai.Classification.Custom import ClassificationModelTrainer, CustomImageClassification

# The `StyleModel` class is a Python class that provides methods for creating a directory structure
# for training and testing data, training a model using the ResNet50 architecture, and classifying
# images using a pre-trained ResNet50 model.
class StyleModel:

TRAIN_SET_SIZE = 1000
Expand All @@ -15,6 +18,11 @@ def __init__(self, input_dir, output_dir, skip=True):
self.model_trainer = ClassificationModelTrainer()

def __create_dir_structure(self):
"""
The function creates a directory structure for training and testing data, and moves files from
the input directory to the appropriate train and test directories based on a specified train set
size and test set size.
"""
os.makedirs(self.output_dir, exist_ok=True)
train_dir = os.path.join(self.output_dir, "train")
test_dir = os.path.join(self.output_dir, "test")
Expand Down Expand Up @@ -43,8 +51,56 @@ def __create_dir_structure(self):
break
count += 1

def train_model(self):
# https://github.com/OlafenwaMoses/ImageAI/blob/master/imageai/Classification/CUSTOMTRAINING.md
def train(self):
"""
The function trains a model using the ResNet50 architecture with a specified number of
experiments and batch size.
See https://github.com/OlafenwaMoses/ImageAI/blob/master/imageai/Classification/CUSTOMTRAINING.md
"""
self.model_trainer.setModelTypeAsResNet50()
self.model_trainer.setDataDirectory(self.output_dir)
self.model_trainer.trainModel(num_experiments=100, batch_size=32)

def __get_paths(self, path):
"""
The function "__get_paths" takes a path as input and returns the paths of a model file and a JSON
file within that path.
:param path: The `path` parameter is the directory path where the files are located
:return: two variables: model_path and json_path.
"""
model_path, json_path = None, None
if os.path.isdir(path):
files = os.listdir(path)
for file in files:
file_path = os.path.join(path, file)
if os.path.isfile(file_path) and file_path.endswith(".json"):
json_path = file_path
else:
model_path = file_path
return model_path, json_path


def classify(self, input_img):
"""
The `classify` function uses a pre-trained ResNet50 model to classify an input image and prints
the top 10 predictions along with their probabilities.
:param input_img: The input_img parameter is the image that you want to classify. It should be
the path to the image file or a PIL image object
"""
prediction = CustomImageClassification()
prediction.setModelTypeAsResNet50()

model_path, json_path = self.__get_paths(os.path.join(self.output_dir, "models"))

prediction.setModelPath(model_path)
prediction.setJsonPath(json_path)
prediction.loadModel()
predictions, probabilities = prediction.classifyImage(input_img, result_count=10)

print(f"Predictions for image: {input_img}")
print("============================================================")
for eachPrediction, eachProbability in zip(predictions, probabilities):
print(eachPrediction + " : " + str(eachProbability))
print("============================================================")
Binary file added tests/monet_impressionism.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 6f336fb

Please sign in to comment.