Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store the optimizer in TrainState #731

Merged
merged 1 commit into from
Jun 26, 2024
Merged

Store the optimizer in TrainState #731

merged 1 commit into from
Jun 26, 2024

Conversation

avik-pal
Copy link
Member

Needed for a workaround in #673

Copy link

codecov bot commented Jun 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.08%. Comparing base (39cf6ca) to head (c39c344).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #731      +/-   ##
==========================================
+ Coverage   96.89%   97.08%   +0.18%     
==========================================
  Files          53       53              
  Lines        2708     2712       +4     
==========================================
+ Hits         2624     2633       +9     
+ Misses         84       79       -5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: c39c344 Previous: 39cf6ca Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3618 ns 3724.5 ns 0.97
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7153.5 ns 7263.5 ns 0.98
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21160 ns 21721 ns 0.97
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9676.2 ns 9884.4 ns 0.98
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8989.25 ns 9075 ns 0.99
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4447.125 ns 4489.625 ns 0.99
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1177.5074626865671 ns 1169.2426470588234 ns 1.01
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1182.3700787401574 ns 1171.953125 ns 1.01
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1162.2397260273972 ns 1192.2325581395348 ns 0.97
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1781.725806451613 ns 1789.8909090909092 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.89985895627643 ns 180.28094575799722 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17222 ns 17353 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16872 ns 16861 ns 1.00
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 38862 ns 37440 ns 1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29195 ns 29165 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 21260 ns 21460 ns 0.99
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17423 ns 17342 ns 1.00
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4328.142857142857 ns 4313.714285714285 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3882.25 ns 3863.375 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3967.5 ns 3919.875 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4934.857142857143 ns 4787.571428571428 ns 1.03
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1657.1 ns 1652.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 42385405 ns 48648775 ns 0.87
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 58300293 ns 58411617.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 83663155 ns 81123660.5 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 93128762.5 ns 103601977.5 ns 0.90
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 79070741.5 ns 89917151 ns 0.88
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12564244 ns 11791492 ns 1.07
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 18372226 ns 17772884 ns 1.03
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7121828 ns 7008179 ns 1.02
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7091531 ns 7033265.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12380297 ns 11422124 ns 1.08
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6432398 ns 6390720 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 754331676 ns 720403839 ns 1.05
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2589500980 ns 2579251187 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 150687753.5 ns 131924729 ns 1.14
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 786620041 ns 900116323 ns 0.87
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3272539737 ns 3346053799 ns 0.98
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 207105958 ns 201398890.5 ns 1.03
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 733733888.5 ns 746912708 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2974925512 ns 2864487060 ns 1.04
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 132546351 ns 146243031 ns 0.91
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 182541170 ns 173248440.5 ns 1.05
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 652642375 ns 642497244 ns 1.02
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 41983085 ns 34221042 ns 1.23
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 165695138.5 ns 163583313 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 643679219 ns 638831188 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30342826 ns 29493369 ns 1.03
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 203230011 ns 188574057.5 ns 1.08
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 870376038 ns 762725894.5 ns 1.14
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 36647314 ns 37965929 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1289843718.5 ns 1313433848 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1910259095.5 ns 1856950599 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2551998139 ns 2380861080 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2644757778 ns 2536885397 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 2011565890 ns 1911900503.5 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 569378507 ns 559300306.5 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 334964718.5 ns 316445254 ns 1.06
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 330079129 ns 315455657 ns 1.05
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 412502253 ns 374643217 ns 1.10
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12021359 ns 11689152.5 ns 1.03
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18110467 ns 17845202 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19444746 ns 19134222 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23984518.5 ns 23730204 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17880672 ns 17836748 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1153170 ns 1152863 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5906662 ns 5748354 ns 1.03
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2061110 ns 2035791.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2047945 ns 2018609 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2086438 ns 2060524 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 211607 ns 195646 ns 1.08
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292668 ns 292185 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 265387 ns 263102 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 365064 ns 361645 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 408409.5 ns 404565 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275305.5 ns 272068 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 405450 ns 404005 ns 1.00
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83346 ns 83025 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81643 ns 81051 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81362 ns 81031 ns 1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86602 ns 86181 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104375 ns 104525 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 209220777 ns 204779961 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 326450312.5 ns 323960068 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 430564282.5 ns 433985465 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 459581385.5 ns 476340195.5 ns 0.96
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 388198511 ns 384506130 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 327290692.5 ns 311083029 ns 1.05
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 103404306 ns 100034034.5 ns 1.03
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43902250 ns 43400919.5 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43769121 ns 43185736.5 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 57545724 ns 69730458 ns 0.83
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28435984 ns 27868008.5 ns 1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18888675 ns 18761703 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19545816 ns 19386864.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23395815 ns 23049373 ns 1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24170736 ns 23964828 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19627742.5 ns 19519177 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6536932 ns 6475391.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6549410 ns 6473828 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6514991 ns 6483596.5 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6514691 ns 6480540.5 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal merged commit 201bcb4 into main Jun 26, 2024
53 checks passed
@avik-pal avik-pal deleted the ap/store_optimizer branch June 26, 2024 04:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant