Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quality of Life Improvements #666

Merged
merged 5 commits into from
May 30, 2024
Merged

Quality of Life Improvements #666

merged 5 commits into from
May 30, 2024

Conversation

avik-pal
Copy link
Member

  1. Fixes Definition and implementation of 'Loss' in Linear Regression Tutorial "Julia & Lux for the Uninitiated" #664
  2. Takes the improvements from Update & Rewrite the DDIM example #661
  3. Optimisers.adjust! and Optimisers.adjust can be directly applied to TrainState.
  4. StatefulLuxLayer has pretty printing
  5. StatefulLuxLayer is compatible with Adapt, so gpu_device() / cpu_device() can be directly applied to them.
  6. SimpleChains model printing now displays the internal model

@avik-pal avik-pal force-pushed the ap/qol branch 2 times, most recently from 4bce458 to 8aabceb Compare May 28, 2024 03:55
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: b3da5d3 Previous: 711f0d4 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3654.25 ns 3709.5 ns 0.99
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7294.416666666667 ns 7380.5 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21019 ns 20849 ns 1.01
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9694.2 ns 9834.6 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8916 ns 9004.6 ns 0.99
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4522.25 ns 4457.125 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1173.889705882353 ns 1180.7964285714286 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1121.8471337579617 ns 1127.3529411764705 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1176.2176870748299 ns 1200.623076923077 ns 0.98
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1795.9435483870968 ns 1807.9791666666667 ns 0.99
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.82695035460992 ns 179.3215796897038 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17303 ns 17473 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17202 ns 17323 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37290 ns 37320 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28523 ns 28403 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20062.5 ns 19697 ns 1.02
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17163 ns 16961 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4379.571428571428 ns 4348.142857142857 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3891 ns 3891 ns 1
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3967.375 ns 3971.125 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4935 ns 4864.857142857143 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1665.2 ns 1651.1 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 49862783 ns 39357045 ns 1.27
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57951957.5 ns 58382754 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 81929525 ns 70740119.5 ns 1.16
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 91773047 ns 89445008.5 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 74541657 ns 73099358 ns 1.02
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11654050 ns 11995172 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 8462537.5 ns 8439129 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7028178 ns 7046254.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6989615 ns 6995964 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12029541 ns 10537813 ns 1.14
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6382755 ns 6397526.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 729188104 ns 709254300.5 ns 1.03
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2890332395 ns 2856579777 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 159364475 ns 146002369.5 ns 1.09
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 883959196 ns 772481590 ns 1.14
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3052476601 ns 2557737194 ns 1.19
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 208368389 ns 208875665 ns 1.00
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 752733522 ns 659452651.5 ns 1.14
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 3018991899 ns 2998638163 ns 1.01
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 140793341 ns 127137322 ns 1.11
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 173924803.5 ns 197330158 ns 0.88
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 647806082 ns 644170332 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45608327.5 ns 35399348.5 ns 1.29
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 164885232 ns 165807519 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 639563005 ns 641228971 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29964783 ns 30401122 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 201013882.5 ns 189053429 ns 1.06
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 832045776.5 ns 740776098 ns 1.12
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 36162677 ns 38366202 ns 0.94
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1314714489.5 ns 1222693120 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1868274177.5 ns 1869050800 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2073536593 ns 2084688657 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2574973290 ns 2406711056 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1918297829 ns 1866369263.5 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 351044004 ns 348144654 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 323271351.5 ns 320766566 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 321347820.5 ns 321154062 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 393478552.5 ns 425548155 ns 0.92
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11868524 ns 11906305 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17930742.5 ns 18033773.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19172899 ns 19282551 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23881234 ns 24037397.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17959870.5 ns 18076579 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1156183 ns 1142703 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2519708 ns 2526639 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2047810.5 ns 2044313 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2024605.5 ns 2028618.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2069556 ns 2071827 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 197940 ns 195835 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292857 ns 291794 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 267810 ns 265155 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 367497 ns 360262.5 ns 1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 410117 ns 406299 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275875.5 ns 274753 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 408032 ns 399305.5 ns 1.02
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83927 ns 83606 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 82084 ns 82545 ns 0.99
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82965 ns 84985 ns 0.98
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86952 ns 86852 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104435 ns 104315 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 205650770 ns 204251813 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 327642671.5 ns 329620520.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 433438644 ns 400412739 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 471778300.5 ns 437793364 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 381189262.5 ns 382057169 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 327065896 ns 341321756 ns 0.96
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 51282963.5 ns 51663673 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43554508 ns 44079076.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43355230.5 ns 43860337 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 56772107 ns 62124168 ns 0.91
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28220531 ns 29198781 ns 0.97
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18780773 ns 19867740 ns 0.95
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19420566 ns 19636061.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23239794.5 ns 23495500.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 23948241.5 ns 24241894 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19487271 ns 19729530 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6507789 ns 6555896 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6489038 ns 6592458 ns 0.98
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6451475 ns 6539830 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6472038 ns 6519358 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Copy link

codecov bot commented May 28, 2024

Codecov Report

Attention: Patch coverage is 78.72340% with 10 lines in your changes are missing coverage. Please review.

Project coverage is 86.60%. Comparing base (ca2c635) to head (b3da5d3).
Report is 2 commits behind head on main.

Files Patch % Lines
src/layers/display.jl 50.00% 5 Missing ⚠️
src/helpers/stateful.jl 81.25% 3 Missing ⚠️
src/layers/extension.jl 50.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #666      +/-   ##
==========================================
- Coverage   86.87%   86.60%   -0.27%     
==========================================
  Files          50       50              
  Lines        2491     2524      +33     
==========================================
+ Hits         2164     2186      +22     
- Misses        327      338      +11     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@avik-pal avik-pal merged commit baa2092 into main May 30, 2024
70 of 75 checks passed
@avik-pal avik-pal deleted the ap/qol branch May 30, 2024 00:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Definition and implementation of 'Loss' in Linear Regression Tutorial "Julia & Lux for the Uninitiated"
1 participant