Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PROC-309] | code changes for creating yolov7 a new package #1

Open
wants to merge 17 commits into
base: release
Choose a base branch
from
Prev Previous commit
Next Next commit
code refractoring
  • Loading branch information
sajal-infoedge committed Sep 1, 2023
commit f7bb04d24de06b7845c6ba4c655b305bcf5331fe
2 changes: 1 addition & 1 deletion yolov7/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class ObjectDetectionConstants(Enum):
OBJECT_DETECTION_PROCESSOR = 'cpu'
OBJECT_DETECTION_MODEL_NAME = 'yolov7.pt'
OBJECT_DETECTION_IMAGE_DIMENSIONS = 3
OBJECT_DETECTION_OBJECTS = {'cell phone': []}
OBJECT_DETECTION_OBJECTS = ['cell phone']
S3_BUCKET_NAME = os.environ.get('ASSET_BUCKET')

22 changes: 12 additions & 10 deletions yolov7/object_detection_yolov7.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ class Yolov7:
objects_detected_confidence_mapping = []

def __init__(self, objects_to_be_detected=None):

self.image_size = ObjectDetectionConstants.OBJECT_DETECTION_IMAGE_SIZE.value
self.device = select_device(ObjectDetectionConstants.OBJECT_DETECTION_PROCESSOR.value)
self.convert_image = self.device.type != ObjectDetectionConstants.OBJECT_DETECTION_PROCESSOR.value
model_root_path = module_path.replace(os.getcwd() + "/", "")
model_path = model_root_path + f'/{ObjectDetectionConstants.OBJECT_DETECTION_MODEL_NAME.value}'
self.load_model_weights(model_path)
self.model = attempt_load([model_path],
map_location=self.device)
self.model = attempt_load([model_path], map_location=self.device)

if objects_to_be_detected is None:
self.objects_detected_confidence_mapping = copy.deepcopy(ObjectDetectionConstants.OBJECT_DETECTION_OBJECTS.value)
else:
self.objects_detected_confidence_mapping = self.get_object_confidence_mapping(objects_to_be_detected)
objects_to_be_detected = copy.deepcopy(ObjectDetectionConstants.OBJECT_DETECTION_OBJECTS.value)
self.objects_to_be_detected = objects_to_be_detected
self.objects_detected_confidence_mapping = self.get_object_confidence_mapping()

self.stride = int(self.model.stride.max()) # model stride


Expand All @@ -51,14 +52,14 @@ def load_model_weights(self, model_path):
os.system(f"aws s3 cp s3:https://{ObjectDetectionConstants.S3_BUCKET_NAME.value}/{ObjectDetectionConstants.OBJECT_DETECTION_MODEL_NAME.value} {model_path}")


def get_object_confidence_mapping(self, objects_to_be_detected):
def get_object_confidence_mapping(self):
"""get object and confidence mapping for objects list

:param objects_to_be_detected: list
:return: dict
"""
objects_detected_confidence_mapping = {}
for _object in objects_to_be_detected:
for _object in self.objects_to_be_detected:
objects_detected_confidence_mapping[_object] = []
return objects_detected_confidence_mapping

Expand Down Expand Up @@ -103,11 +104,12 @@ def detect_object(self, img):
for i in range(4):
rect[i] = int(rect[i])

objects_detected_confidence_mapping = copy.deepcopy(self.objects_detected_confidence_mapping)
for i, item in enumerate(prediction):
label = item[-1]
if label in list(self.objects_detected_confidence_mapping.keys()):
if label in list(objects_detected_confidence_mapping.keys()):
# item list containes boundary box till index 4 and index 4 is confidence
self.objects_detected_confidence_mapping[label].append({'bbox': item[:4], 'confidence': item[4]})
objects_detected_confidence_mapping[label].append({'bbox': item[:4], 'confidence': item[4]})

return self.objects_detected_confidence_mapping
return objects_detected_confidence_mapping