Improved W&B integration #2125

Merged: 84 commits, Mar 22, 2021
Changes from 1 commit
Commits (84)
ba39bfd
Init Commit
AyushExel Feb 2, 2021
5fcd1dc
new wandb integration
AyushExel Feb 3, 2021
c540e3b
Update
AyushExel Feb 3, 2021
8253b24
Use data_dict in test
AyushExel Feb 3, 2021
7f89535
Updates
AyushExel Feb 3, 2021
c149930
Update: scope of log_img
AyushExel Feb 3, 2021
49edb90
Update: scope of log_img
AyushExel Feb 3, 2021
7922683
Update
AyushExel Feb 3, 2021
e1e7179
Update: Fix logging conditions
AyushExel Feb 3, 2021
e632514
Add tqdm bar, support for .txt dataset format
AyushExel Feb 9, 2021
3e8f4ae
Improve Result table Logger
AyushExel Feb 21, 2021
cd094f3
Init Commit
AyushExel Feb 2, 2021
aa5231e
new wandb integration
AyushExel Feb 3, 2021
0fdf3d3
Update
AyushExel Feb 3, 2021
37a2ed6
Use data_dict in test
AyushExel Feb 3, 2021
ac7d4b1
Updates
AyushExel Feb 3, 2021
ebc1d18
Update: scope of log_img
AyushExel Feb 3, 2021
745a272
Update: scope of log_img
AyushExel Feb 3, 2021
7679454
Update
AyushExel Feb 3, 2021
b8210a7
Update: Fix logging conditions
AyushExel Feb 3, 2021
ac9a613
Add tqdm bar, support for .txt dataset format
AyushExel Feb 9, 2021
4f7c150
Improve Result table Logger
AyushExel Feb 21, 2021
b8bbfce
Merge branch 'wandb_clean' of https://github.com/AyushExel/yolov5 int…
AyushExel Feb 23, 2021
c1e6697
Add dataset creation in training script
AyushExel Feb 23, 2021
1948562
Change scope: self.wandb_run
AyushExel Feb 23, 2021
8848f3c
Add wandb-artifact:// natively
AyushExel Feb 25, 2021
deca116
Add suuport for logging dataset while training
AyushExel Feb 26, 2021
20185f2
Cleanup
AyushExel Feb 26, 2021
5287a79
Merge branch 'master' into wandb_clean
AyushExel Feb 26, 2021
e13994d
Fix: Merge conflict
AyushExel Feb 26, 2021
1080952
Fix: CI tests
AyushExel Feb 26, 2021
5a859d4
Automatically use wandb config
AyushExel Feb 27, 2021
519cb7d
Fix: Resume
AyushExel Feb 28, 2021
3242f52
Fix: CI
AyushExel Feb 28, 2021
8128216
Enhance: Using val_table
AyushExel Feb 28, 2021
043befa
More resume enhancement
AyushExel Feb 28, 2021
c2d98f0
FIX : CI
AyushExel Feb 28, 2021
dbb69f4
Add alias
AyushExel Feb 28, 2021
8505a58
Get useful opt config data
AyushExel Mar 1, 2021
04f8880
train.py cleanup
AyushExel Mar 2, 2021
27a33dd
Merge remote-tracking branch 'upstream/master' into wandb_clean
AyushExel Mar 2, 2021
54dee24
Cleanup train.py
AyushExel Mar 2, 2021
21a15a5
more cleanup
AyushExel Mar 2, 2021
d38c620
Cleanup| CI fix
AyushExel Mar 2, 2021
e5400ba
Reformat using PEP8
AyushExel Mar 3, 2021
45e2c55
FIX:CI
AyushExel Mar 3, 2021
75f31d0
Merge remote-tracking branch 'upstream/master' into wandb_clean
AyushExel Mar 6, 2021
613b102
rebase
AyushExel Mar 6, 2021
9772645
remove uneccesary changes
AyushExel Mar 6, 2021
cd1237e
remove uneccesary changes
AyushExel Mar 6, 2021
d172ba1
remove uneccesary changes
AyushExel Mar 6, 2021
7af0186
remove unecessary chage from test.py
AyushExel Mar 6, 2021
51dca6d
FIX: resume from local checkpoint
AyushExel Mar 8, 2021
1438483
FIX:resume
AyushExel Mar 8, 2021
e7d18c6
FIX:resume
AyushExel Mar 8, 2021
22d97a7
Reformat
AyushExel Mar 8, 2021
8e97cdf
Performance improvement
AyushExel Mar 9, 2021
2ffb643
Fix local resume
AyushExel Mar 9, 2021
7836d17
Fix local resume
AyushExel Mar 9, 2021
aa785ec
FIX:CI
AyushExel Mar 9, 2021
f97446e
Fix: CI
AyushExel Mar 9, 2021
807a0e1
Imporve image logging
AyushExel Mar 9, 2021
20b4450
(:(:Redo CI tests:):)
AyushExel Mar 9, 2021
db81c64
Remember epochs when resuming
AyushExel Mar 9, 2021
25ff6b8
Remember epochs when resuming
AyushExel Mar 9, 2021
819ebec
Update DDP location
glenn-jocher Mar 10, 2021
b23a902
merge master
glenn-jocher Mar 14, 2021
f742857
PEP8 reformat
glenn-jocher Mar 14, 2021
350b8ab
0.25 confidence threshold
glenn-jocher Mar 14, 2021
395379e
reset train.py plots syntax to previous
glenn-jocher Mar 14, 2021
a06b25c
reset epochs completed syntax to previous
glenn-jocher Mar 14, 2021
cc49f6a
reset space to previous
glenn-jocher Mar 14, 2021
2d56697
remove brackets
glenn-jocher Mar 14, 2021
ba859a6
reset comment to previous
glenn-jocher Mar 14, 2021
52e3e71
Update: is_coco check, remove unused code
AyushExel Mar 14, 2021
ad1ad8f
Remove redundant print statement
AyushExel Mar 14, 2021
72dd23b
Remove wandb imports
AyushExel Mar 14, 2021
ac955ab
remove dsviz logger from test.py
AyushExel Mar 14, 2021
8bded54
Remove redundant change from test.py
AyushExel Mar 14, 2021
1aca390
remove redundant changes from train.py
AyushExel Mar 14, 2021
4c1c9bf
reformat and improvements
AyushExel Mar 20, 2021
f4923b4
Fix typo
AyushExel Mar 21, 2021
af23506
Merge branch 'master' of https://github.com/ultralytics/yolov5 into w…
AyushExel Mar 21, 2021
ca06d31
Add tqdm tqdm progress when scanning files, naming improvements
AyushExel Mar 21, 2021
Fix: Resume
AyushExel committed Feb 28, 2021
commit 519cb7d8c81d2f7ff9a9f53598710a1a19a3cd40
46 changes: 21 additions & 25 deletions train.py
@@ -33,7 +33,7 @@
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, get_id_and_model_name, check_wandb_config_file
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file

logger = logging.getLogger(__name__)

@@ -64,22 +64,22 @@ def train(hyp, opt, device, tb_writer=None):
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
is_coco = opt.data.endswith('coco.yaml')
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(
data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
len(names), nc, opt.data) # check
# Logging- Doing this before checking the dataset. In case artifact links are being used, they need to be downloaded
# Logging- Doing this before checking the dataset. Might update data_dict
if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters
run_id = ckpt.get('wandb_id') if 'ckpt' in locals() else None
wandb_logger = WandbLogger(
opt, Path(opt.save_dir).stem, run_id, data_dict)
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
import wandb
weights = opt.weights # WandbLogger might update weights path
loggers = {'wandb': wandb_logger.wandb} # loggers dict

nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(
data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
len(names), nc, opt.data) # check
# Model
pretrained = weights.endswith('.pt')
if pretrained:
@@ -165,12 +165,10 @@ def lf(x): return (1 - x / (epochs - 1)) * \
if ckpt['optimizer'] is not None:
optimizer.load_state_dict(ckpt['optimizer'])
best_fitness = ckpt['best_fitness']

# EMA
if ema and ckpt.get('ema'):
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
ema.updates = ckpt['ema'][1]

# Results
if ckpt.get('training_results') is not None:
results_file.write_text(
@@ -555,7 +553,6 @@ def lf(x): return (1 - x / (epochs - 1)) * \
help='name of model artifact to resume training from.overwirtes local --weights file')
opt = parser.parse_args()

opt.data = check_wandb_config_file(opt.data)
# Set DDP variables
opt.world_size = int(os.environ['WORLD_SIZE']
) if 'WORLD_SIZE' in os.environ else 1
@@ -566,28 +563,26 @@ def lf(x): return (1 - x / (epochs - 1)) * \
check_requirements()

# Resume
if opt.resume: # resume an interrupted run
wandb_run = resume_and_get_id(opt)
if opt.resume and (not wandb_run): # resume an interrupted run
# specified or most recent path
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()
wandb_run, _ = get_id_and_model_name(opt.resume)
assert os.path.isfile(
ckpt) or wandb_run, 'ERROR: --resume checkpoint does not exist'
if not wandb_run: # resuming from local checkpoint
apriori = opt.global_rank, opt.local_rank
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(
**yaml.load(f, Loader=yaml.SafeLoader)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
logger.info('Resuming training from %s' % ckpt)
else:
opt.save_dir = increment_path(Path(
opt.project) / ('run_'+wandb_run.id), exist_ok=opt.exist_ok) # Resume run from wandb
apriori = opt.global_rank, opt.local_rank
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(
**yaml.load(f, Loader=yaml.SafeLoader)) # replace
opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
logger.info('Resuming training from %s' % ckpt)
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(
opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(
opt.weights), 'either --cfg or --weights must be specified'
print(opt.data)
opt.data = check_wandb_config_file(opt.data) #check if wandb config is present
# extend to 2 sizes (train, test)
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))
opt.name = 'evolve' if opt.evolve else opt.name
@@ -607,8 +602,9 @@ def lf(x): return (1 - x / (epochs - 1)) * \
opt.batch_size = opt.total_batch_size // opt.world_size

# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
if not isinstance(opt.hyp, dict):
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps

# Train
logger.info(opt)
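
Taken together, the train.py hunks above route --resume through W&B first and fall back to a local checkpoint only when no artifact run is found. Below is a condensed sketch of that control flow, simplified from this commit: DDP rank bookkeeping and most opt fields are omitted, the imported helpers are assumed to behave as defined in this PR branch and in utils.general, and resolve_resume plus the placeholder artifact path are illustrative names only.

import argparse
import os
from pathlib import Path

import yaml

from utils.general import check_file, get_latest_run  # existing repo helpers (assumed)
from utils.wandb_logging.wandb_utils import resume_and_get_id, check_wandb_config_file  # per this PR

def resolve_resume(opt):
    # Re-attach to a W&B run when --resume is a wandb-artifact:// path,
    # e.g. --resume wandb-artifact://<project>/<run_id> (hypothetical placeholder).
    wandb_run = resume_and_get_id(opt)  # returns None for non-artifact resumes
    if opt.resume and not wandb_run:
        # Resume from a local checkpoint: an explicit path or the most recent run.
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # restore saved opt
        opt.weights, opt.resume = ckpt, True
    else:
        # Fresh run, or a resume driven by the W&B artifact: validate the files
        # and prefer a *_wandb.yaml data config if one exists.
        opt.data = check_file(opt.data)
        opt.data = check_wandb_config_file(opt.data)
    return opt
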
69 changes: 33 additions & 36 deletions utils/wandb_logging/wandb_utils.py
@@ -6,7 +6,9 @@
import sys
from datetime import datetime
from pathlib import Path
import argparse
import yaml
import os

from tqdm import tqdm
import torch
@@ -26,34 +28,33 @@ def remove_prefix(from_string, prefix):
return from_string[len(prefix):]

def check_wandb_config_file(data_config_file):
wandb_config = data_config_file.replace('.', '_wandb.') # updated data.yaml path
wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
if Path(wandb_config).is_file():
return wandb_config
return data_config_file

def get_id_and_model_name(run_path):
def resume_and_get_id(opt):
# It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
if run_path.startswith(WANDB_ARTIFACT_PREFIX):
run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
run_id = run_path.stem
model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
assert wandb, 'install wandb to resume wandb runs'
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
run = wandb.init(id=run_id, resume='allow')
return run, model_artifact_name
return None, None

if isinstance(opt.resume,str):
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
run_id = run_path.stem
project = run_path.parent.stem
model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
assert wandb, 'install wandb to resume wandb runs'
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
run = wandb.init(id=run_id, project=project, resume='allow')
opt.resume_from_artifact = model_artifact_name
return run
opt.resume_from_artifact = ''
return None

class WandbLogger():
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
# Pre-training routine -- check for --resume and --upload_dataset
self.job_type = job_type
self.wandb, self.wandb_run = wandb, None
self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
if self.job_type == 'Training':
run, model_artifact_name = self.check_resume(opt)
if run:
opt.resume = model_artifact_name
opt.save_period = run.config.save_period
if opt.upload_dataset:
data_dict = self.check_and_upload_dataset(
opt, name, data_dict, job_type)
@@ -66,19 +67,15 @@ def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
name=name,
job_type=job_type,
id=run_id) if not wandb.run else wandb.run
self.setup_training(opt, data_dict)
self.wandb_run.config.opt = vars(opt)
self.data_dict = self.setup_training(opt, data_dict)
if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(
opt, name, data_dict, job_type)

def check_resume(self, opt):
if self.job_type == 'Training':
if isinstance(opt.resume, str):
return get_id_and_model_name(opt.resume)
return None, None

def check_and_upload_dataset(self, opt, name, data_dict, job_type):
assert wandb, 'Install wandb to upload dataset'
os.environ['WANDB_SILENT'] = 'true' #Reduce verbosity for dataset creation job
run = wandb.init(config=data_dict,
project='YOLOv5' if opt.project == 'runs/train' else Path(
opt.project).stem,
@@ -88,6 +85,7 @@ def check_and_upload_dataset(self, opt, name, data_dict, job_type):
opt.single_cls,
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
wandb.finish() # Finish dataset creation run| ensures the dataset has uploaded completely before training starts
os.environ['WANDB_SILENT'] = 'false'
print("Created dataset config file ", config_path)
with open(config_path) as f:
wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader)
@@ -100,14 +98,13 @@ def setup_training(self, opt, data_dict):
if opt.bbox_interval == -1:
opt.bbox_interval = (
opt.epochs // 10) if opt.epochs > 10 else opt.epochs
if opt.resume:
modeldir, _ = self.download_model_artifact(opt.resume)
if opt.resume_from_artifact:
modeldir, _ = self.download_model_artifact(opt)
if modeldir:
self.weights = Path(modeldir) / "best.pt"
self.weights = Path(modeldir) / "last.pt"
opt.weights = str(self.weights)
# Advantage: Eliminates the need for config file to resume
data_dict = self.wandb_run.config.data_dict

# Advantage: Eliminates the need for config file to resume
data_dict = self.wandb_run.config.data_dict
self.train_artifact_path, self.train_artifact = \
self.download_dataset_artifact(
data_dict.get('train'), opt.artifact_alias)
@@ -126,6 +123,7 @@ def setup_training(self, opt, data_dict):
"run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(
["epoch", "id", "prediction", "avg_confidence"])
return data_dict

def download_dataset_artifact(self, path, alias):
if path.startswith(WANDB_ARTIFACT_PREFIX):
@@ -136,10 +134,10 @@ def download_dataset_artifact(self, path, alias):
return datadir, dataset_artifact
return None, None

def download_model_artifact(self, name):
if name.startswith(WANDB_ARTIFACT_PREFIX):
model_artifact = wandb.use_artifact(
remove_prefix(name, WANDB_ARTIFACT_PREFIX) + ":latest")
def download_model_artifact(self, opt):
if opt.resume_from_artifact.startswith(WANDB_ARTIFACT_PREFIX):
model_artifact_name = 'run_' + Path(opt.resume).stem + '_model'
model_artifact = wandb.use_artifact(model_artifact_name + ":latest")
assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
modeldir = model_artifact.download()
epochs_trained = model_artifact.metadata.get('epochs_trained')
@@ -179,8 +177,7 @@ def create_dataset_artifact(self, data_file, single_cls, project, overwrite_conf
str(Path(project) / 'train')
if data.get('val'):
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
path = data_file if overwrite_config else data_file.replace(
'.', '_wandb.') # updated data.yaml path
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
# download via artifact instead of predefined field 'download:'
data.pop('download', None)
with open(path, 'w') as f:
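
One detail in the wandb_utils.py hunks that is easy to miss: the updated data.yaml path is now built with rsplit('.', 1) instead of str.replace. replace rewrites every dot in the path, which corrupts relative paths, while rsplit only touches the final extension dot. A standalone sketch of the difference, using a hypothetical path:

data_file = '../data/coco128.yaml'   # hypothetical relative path

# Previous behaviour: every '.' is replaced, mangling the leading '..'
old = data_file.replace('.', '_wandb.')
# -> '_wandb._wandb./data/coco128_wandb.yaml'

# New behaviour: only the extension dot is rewritten
new = '_wandb.'.join(data_file.rsplit('.', 1))
# -> '../data/coco128_wandb.yaml'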