Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
RihabFekii committed Apr 9, 2023
1 parent ee9f8ea commit 54e3328
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 21 deletions.
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
dvc==2.51.0
ultralytics==8.0.58
python-dotenv==1.0.0
mlflow==2.2.2


5 changes: 2 additions & 3 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,12 @@

# save model
model_path = save_model(experiment_name=params['name'])
# log model path with mlflow
# log model artifact to mlflow
mlflow.log_artifact(model_path)

# save metrics csv file and training params
save_metrics_and_params(experiment_name=params['name'])
# convert metrics from csv to json
convert_metrics_csv_to_json(metrics_path, params['name'])




Expand Down
20 changes: 2 additions & 18 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import shutil
from pathlib import Path
import pandas as pd
import json


ROOT_DIR = Path(__file__).resolve().parents[1] # root directory absolute path

Expand All @@ -13,20 +12,7 @@ def save_model(experiment_name: str):
model_weights = experiment_name + "/weights/best.pt"
path_model_weights = os.path.join(ROOT_DIR, "runs/detect", model_weights)

return shutil.copy(src=path_model_weights, dst=f'{ROOT_DIR}/models/model.pt')


def csv_to_json(csv_file_path:str, dest_file_path: str):
df = pd.read_csv(csv_file_path)
df.to_json(dest_file_path, orient="index")


def convert_metrics_csv_to_json(dest_file_path: str, experiment_name:str):
# convert metrics from csv to json format in order to track the with DVC
if os.path.isdir('runs'):
path_metrics = os.path.join(ROOT_DIR, "runs/detect", experiment_name)
path_metrics = os.path.join(path_metrics, 'results.csv')
csv_to_json(path_metrics, dest_file_path)
shutil.copy(src=path_model_weights, dst=f'{ROOT_DIR}/models/model.pt')


def save_metrics_and_params(experiment_name: str) -> None:
Expand All @@ -43,5 +29,3 @@ def save_metrics_and_params(experiment_name: str) -> None:
# save training params
shutil.copy(src=f'{path_metrics}/args.yaml', dst=f'{ROOT_DIR}/reports/train_params.yaml')


save_model('yolov8s_exp_v0')

0 comments on commit 54e3328

Please sign in to comment.