Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training reproducibility improvements #8213

Merged
merged 34 commits into from
Jul 7, 2022

Conversation

AyushExel
Copy link
Contributor

@AyushExel AyushExel commented Jun 14, 2022

Followed suggestions from:
pytorch/pytorch#7068 (comment)
https://www.mldawn.com/reproducibility-in-pytorch/

๐Ÿ› ๏ธ PR Summary

Made with โค๏ธ by Ultralytics Actions

๐ŸŒŸ Summary

Introduced a global training seed option for enhanced reproducibility.

๐Ÿ“Š Key Changes

  • Added a new --seed command-line argument to specify the global training seed.
  • Modified init_seeds method to accept a deterministic parameter and implemented deterministic behavior when activated.
  • Updated init_seeds to use PyTorch's use_deterministic_algorithms() and set the CUBLAS_WORKSPACE_CONFIG environment variable for PyTorch versions >= 1.12.0.

๐ŸŽฏ Purpose & Impact

  • ๐Ÿ‘จโ€๐Ÿ”ฌ Enhanced Reproducibility: The changes allow for more reproducible training results, which is particularly important for experiments and comparisons.
  • ๐Ÿงฎ Consistent Training Behaviour: Users can expect consistent model performance when retraining with the same seed, reducing variability due to random processes.
  • ๐Ÿค– Developer Convenience: By introducing a command-line argument for seeding, developers have an easier time setting and managing the seeds within their training scripts.

@glenn-jocher
Copy link
Member

@AyushExel thanks for the PR!

Can you please provide before and after results, i.e. 3 runs from master and 3 runs from PR? The scenario can be small, i.e. COCO128 YOLOv5s 30 epochs, but it's important to compare changes to the current baseline. Thanks!

@AyushExel
Copy link
Contributor Author

Screenshot from 2022-06-15 16-21-47
@glenn-jocher

  • All red lines are master runs and blue are runs from this branch.
  • The metrics section shows similar variance for both as the scores are very small, only 3rd decimal places
  • But for train and val metrics this branch shows much less variance than master.
    Again, due to small scale of the dataset and small numerical values involved, this test needs to be verified using a larger dataset.
    Dashboard

@glenn-jocher
Copy link
Member

@AyushExel got it, perfect!

I will check out this branch and run full YOLOv5s COCO trainings on the 8 GPUs today.

@AyushExel
Copy link
Contributor Author

@glenn-jocher nice. You also have the same test for master branch already right?

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 15, 2022

@AyushExel yes, these are the differences between min and max [email protected]:0.95:

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 15, 2022

@AyushExel seems to show identical variation to master at epoch 8 (about 0.4%), so seems like no change in randomness.

What happened to the torch.use_deterministic_algorithms() that I sugested?

@glenn-jocher
Copy link
Member

@AyushExel also your dataloader init function seems to be lacking python and torch seed inits as in this example: https://discuss.pytorch.org/t/reproducibility-with-all-the-bells-and-whistles/81097

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 15, 2022

@glenn-jocher The dataloader init_fn only requires np random seed as mentioned in official pytorch issues here and here

The torch.use_deterministic_algorithms() is not exception safe. There are many operations that'll just throw runtime error when this is enabled. Also to work for CUDA 10.2 and above correctly some environ variables need to be set or it'll cause runtime exceptions. More details here in the last section https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
Enabling this seems way too risky

@AyushExel
Copy link
Contributor Author

@glenn-jocher Also after reading a lot of discussions on various platforms( github, kaggle, pytorch discourse) I haven't found anyone who has actually been able to accurately reproduce their large model experiments. All solutions are just there to reduce the variance
It might happen due to the numerical instability of due to the estimation of floating point gradients.

@glenn-jocher
Copy link
Member

@AyushExel ok I'm going to cancel this training and try new experiments:

  1. Train with workers=0
  2. Train with AMP disabled (needs new branch)
  3. We can try the additional python and torch seeds to see if they help (please update this PR's init_fcn for this)

@AyushExel
Copy link
Contributor Author

@glenn-jocher I'll do the 3rd on this branch in some time.

@glenn-jocher
Copy link
Member

@AyushExel ok got it! --workers 0 experiment started in new project, tracking results in same comment #8213 (comment). Each epoch there takes 30 min so we'll have the epoch 10 results in about 5 hours.

The good news is the randomness at epoch 10 is a great benchmark, no need to wait 300 epochs.

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 15, 2022

@glenn-jocher I was just testing torch.use_deterministic_alg on my branch. The training ran successfully but there's one operation post training that throws runtime error. I'll test more to see if its actually deterministic. If so, we can change the implementation of the operation that throws error. If not, let's leave it alone.

 File "/home/yolov5/utils/torch_utils.py", line 205, in fuse_conv_and_bn
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 15, 2022

@glenn-jocher keep an eye on https://wandb.ai/cayush/use_deter?workspace=user-cayush
Training yolov5n from scratch for 50 epochs with deteministic_alg enabled

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 15, 2022

Screenshot from 2022-06-15 20-33-04

@glenn-jocher Okay so I've set the warn_only flag of torch.use_deterministic_alg to True which means it won't throw error after training. But this also means that if something goes wrong with reproduciblity ,it'll fail silently. We'll need to keep that in mind when creating new operations.
That aside, these are the best results that I've got. All metrics are coinciding perfectly
Dash

EDIT: It seems like additional seed inn dataloader init fn is not required. So leaving it as it is right now. It only affects DDP mode which I can't test locally

@AyushExel
Copy link
Contributor Author

@glenn-jocher You'll need to run these tests:

  • One same as yesterday. train 8 yolov5s on 8GPUs from this branch. These metrics should all perfectly match
  • if the above is successful, also test reproducibility in DDP mode. train 4 models in DDP across 8 GPUs( 2 for each training)
    The 2nd test will confirm if we need any changes in init_fn method of the dataloader. I've run tests on multiple devices and my results for single worker training is perfectly reproducible( to the 4th decimal place)

@glenn-jocher
Copy link
Member

@AyushExel got it. Running 8 YOLOv5s now at https://wandb.ai/glenn-jocher/test-reproduce-pr2

@AyushExel
Copy link
Contributor Author

@glenn-jocher great. How long does 1 epoch usually take? From the benchmark runs you posted above:

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 16, 2022

@glenn-jocher
good thing:

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 17, 2022

@AyushExel yes I see this also. Losses are all identical but mAP is 0. Usually when mAP is 0 it's due to AMP/CUDA/Windows/Conda issues, but I've recently added AMP checks and these are passing for PR trainings in https://wandb.ai/glenn-jocher/test-reproduce-pr3

AMP checks run inference on pretrained model (or YOLOv5n downloaded model if not pretrained model) to verify that AMP inference and default inference produce similar results. This was added in #7917 and improved in #7937

@AyushExel
Copy link
Contributor Author

@glenn-jocher No responses on the pytorch forum issue. I'm trying to debug this using pdb. Hopefully the bug is occurring during the calculation of maps with deterministic algorithms. I'll verify if the model is actually learning anything or not by plotting results in each epoch.

@AyushExel
Copy link
Contributor Author

AyushExel commented Jun 17, 2022

@glenn-jocher The error is with calculation. I plotted BBoxDebugger for 1st epoch in VOC training and most objects are detected correctly so map shouldn't be 0. https://wandb.ai/cayush/use_deter_s/
EDIT:
Okay found the issue. This line is always false when deterministic alg is set -> https://github.com/ultralytics/yolov5/blob/master/val.py#L265
This happens because stats[0].any() returns false. When not using deterministic alg, it returns true. So the bug is somewhere in the process_batch function.
The iou values on the first epoch are very low when deterministic_alg is set.
EDIT2:
Okay I tried a lot of tracebacks. No idea where things are going wrong

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 17, 2022

@AyushExel I overlaid a master run against current PR: train losses, val losses, learning rates are all identical, but all metrics are zero. Very strange. Obviously the latest commit 254d379 caused this. Looking at the commit it has two changes, so let's try to isolate one change at a time to identify the cause. I'll comment out one line and retry a new training.

Screenshot 2022-06-17 at 13 42 00

@glenn-jocher
Copy link
Member

@glenn-jocher The error is with calculation. I plotted BBoxDebugger for 1st epoch in VOC training and most objects are detected correctly so map shouldn't be 0. https://wandb.ai/cayush/use_deter_s/ EDIT: Okay found the issue. This line is always false when deterministic alg is set -> https://github.com/ultralytics/yolov5/blob/master/val.py#L265 This happens because stats[0].any() returns false. When not using deterministic alg, it returns true. So the bug is somewhere in the process_batch function. The iou values on the first epoch are very low when deterministic_alg is set. EDIT2: Okay I tried a lot of tracebacks. No idea where things are going wrong

I see what you're saying here. So this is good news, it means the models are actually learning and are identical, it's just the validation that seems problematic. But the validation is always deterministic anyways, it never varies so maybe we can set flags to enable/disable deterministic mode in val.py as a quick fix.

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 29, 2022

@UnglvKitDe okay got it. Thanks for pointing out. Do let us know about your findings. @glenn-jocher I think we need to run another test with torch 1.12, this time without resetting the deterministic operation after every epoch

@UnglvKitDe thanks for the feedback! I've made updates to only run the command once if torch 1.12 is installed. torch < 1.12 we'll leave alone.

EDIT: Will run a new training today with these settings.

@UnglvKitDe
Copy link
Contributor

@UnglvKitDe okay got it. Thanks for pointing out. Do let us know about your findings. @glenn-jocher I think we need to run another test with torch 1.12, this time without resetting the deterministic operation after every epoch

@UnglvKitDe thanks for the feedback! I've made updates to only run the command once if torch 1.12 is installed. torch < 1.12 we'll leave alone.

EDIT: Will run a new training today with these settings.

@glenn-jocher @AyushExel I did 5 run with coco128. In one of 5 runs I get the 0 results again. A similar picture on my custom data (1 of 8 has the 0 problem again). Very strange. I have set up a clean conda installation with torch 1.12 and cuda 11.6.
results

@glenn-jocher
Copy link
Member

@UnglvKitDe this is not the zero mAP problem. zero mAP means zero mAP at all times. In your training your validation loses are unstable and increasing leading to logically low mAP.

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 29, 2022

Tested PR in Colab with 1.12. Looks good, all 3 identical and high mAP.

!git clone https://github.com/AyushExel/yolov5 -b init_seeds  # clone
%cd yolov5
%pip install -qr requirements.txt torch==1.12 torchvision==0.13  # install

import torch
import utils
display = utils.notebook_init()  # checks

# Train YOLOv5s on COCO128 for 10 epochs
!python train.py --img 640 --batch 16 --epochs 10 --data coco128.yaml --weights yolov5s.pt --cache --seed 0
!python train.py --img 640 --batch 16 --epochs 10 --data coco128.yaml --weights yolov5s.pt --cache --seed 0
!python train.py --img 640 --batch 16 --epochs 10 --data coco128.yaml --weights yolov5s.pt --cache --seed 0

Screen Shot 2022-06-29 at 4 57 23 PM

@glenn-jocher
Copy link
Member

glenn-jocher commented Jun 29, 2022

Testing vs master on COCO in https://wandb.ai/glenn-jocher/test-reproduce-pr4

EDIT: Unable to test on Multi-GPU systems per torch 1.12 YOLOv5 bug in #8395

@UnglvKitDe
Copy link
Contributor

@UnglvKitDe this is not the zero mAP problem. zero mAP means zero mAP at all times. In your training your validation loses are unstable and increasing leading to logically low mAP.

@glenn-jocher Then we talked about other issues. When I used use_deterministic_algorithms, I got the problem described above (AP 0 at the end).

@glenn-jocher
Copy link
Member

W&B trainings have been cancelled because unable to test on Multi-GPU systems per torch 1.12 YOLOv5 bug in #8395

@UnglvKitDe
Copy link
Contributor

@AyushExel @glenn-jocher With torch 1.12 there is a issue with multi gpu training when I insert the reproducibility as described in the doc. For some reason, CUDA runs out of memory. On the master it works without problems (~5.1/12 GB VRAM).

@UnglvKitDe
Copy link
Contributor

UnglvKitDe commented Jun 30, 2022

ok, so i tried debugging this today, but i don't understand why the problem occurs. if i use a different number of workers it works. unfortunately ( at the moment ) i can't recreate it with any public record.

@UnglvKitDe okay got it. Thanks for pointing out. Do let us know about your findings. @glenn-jocher I think we need to run another test with torch 1.12, this time without resetting the deterministic operation after every epoch

@UnglvKitDe thanks for the feedback! I've made updates to only run the command once if torch 1.12 is installed. torch < 1.12 we'll leave alone.
EDIT: Will run a new training today with these settings.

@glenn-jocher @AyushExel I did 5 run with coco128. In one of 5 runs I get the 0 results again. A similar picture on my custom data (1 of 8 has the 0 problem again). Very strange. I have set up a clean conda installation with torch 1.12 and cuda 11.6. results

ok, so i tried debugging the above problem today, but i don't understand why the problem occurs. if i use a different number of workers it works. unfortunately (as of now) i can't recreate it with any public dataset. Very strange. @glenn-jocher Have you ever seen such a problem?

@glenn-jocher
Copy link
Member

@UnglvKitDe it's not uncommon for gradient/training instabilities to lead to higher losses and diverged results. This is just a fact of life with nonlinear optimization problems.

The reproducibility part we are trying to work on with this PR of course though.

@glenn-jocher
Copy link
Member

@AyushExel I think this PR is good to merge. I added a deterministic=False argument to init_seeds()

@glenn-jocher glenn-jocher merged commit 27d831b into ultralytics:master Jul 7, 2022
@glenn-jocher
Copy link
Member

@AyushExel PR is merged! The new deterministic policy is that init_seeds() defaults to False but we pass true in train.py.

init_seeds(opt.seed + 1 + RANK, deterministic=True)

I also added init_seeds to classifier.py and observed deterministic behavior without having to set deterministic=True, but also test it with True and saw no errors (strange). Anyway I think we are done here and can move on to other things!

Shivvrat pushed a commit to Shivvrat/epic-yolov5 that referenced this pull request Jul 12, 2022
* attempt at reproducibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use deterministic algs

* fix everything :)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert dataloader changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* process_batch as np

* remove newline

* Remove dataloader init fcn

* Update val.py

* Update train.py

* revert additional changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Add --seed arg

* Update general.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

* Update val.py

* Update train.py

* Update general.py

* Update general.py

* Add deterministic argument to init_seeds()

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
@konioy konioy mentioned this pull request Jul 21, 2022
1 task
@glenn-jocher glenn-jocher removed the TODO label Jul 30, 2022
ctjanuhowski pushed a commit to ctjanuhowski/yolov5 that referenced this pull request Sep 8, 2022
* attempt at reproducibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use deterministic algs

* fix everything :)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert dataloader changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* process_batch as np

* remove newline

* Remove dataloader init fcn

* Update val.py

* Update train.py

* revert additional changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Add --seed arg

* Update general.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

* Update val.py

* Update train.py

* Update general.py

* Update general.py

* Add deterministic argument to init_seeds()

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
@lijiajun3029
Copy link

good job

@glenn-jocher
Copy link
Member

@lijiajun3029 thank you! ๐Ÿ™ This is a team effort and your valuable feedback and testing have been instrumental in improving YOLOv5. We're always here if you have more questions or need further assistance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants