Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expand nested AD support to TrackedArrays #805

Closed
wants to merge 1 commit into from

Conversation

avik-pal
Copy link
Member

TODOs

  • Initial Support
  • Expand the tests
  • Proper support for parameter gradients
  • Documentation updates

@avik-pal avik-pal changed the base branch from main to ap/testing_update July 29, 2024 05:30
Copy link

codecov bot commented Jul 29, 2024

Codecov Report

Attention: Patch coverage is 7.69231% with 12 lines in your changes missing coverage. Please review.

Project coverage is 81.54%. Comparing base (5f2cd25) to head (8dfd0a4).
Report is 6 commits behind head on main.

Files Patch % Lines
src/helpers/nested_ad.jl 0.00% 9 Missing ⚠️
src/forwarddiff/nested_ad.jl 0.00% 2 Missing ⚠️
ext/LuxZygoteExt/LuxZygoteExt.jl 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (5f2cd25) and HEAD (8dfd0a4). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (5f2cd25) HEAD (8dfd0a4)
43 37
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #805       +/-   ##
===========================================
- Coverage   95.94%   81.54%   -14.40%     
===========================================
  Files          54       54               
  Lines        2834     2850       +16     
===========================================
- Hits         2719     2324      -395     
- Misses        115      526      +411     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Base automatically changed from ap/testing_update to main July 29, 2024 06:20
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 666e728 Previous: 59402fe Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3649.375 ns 3654.375 ns 1.00
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 6742.5 ns 7423.833333333333 ns 0.91
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21099 ns 20889 ns 1.01
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9960.6 ns 9879 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9173.2 ns 9033 ns 1.02
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4532.25 ns 4503.375 ns 1.01
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 4968 ns 4685 ns 1.06
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1039.1761006289307 ns 1129.4184397163122 ns 0.92
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1058.9378881987577 ns 1183.8444444444444 ns 0.89
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1834.304347826087 ns 1806.017543859649 ns 1.02
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.6586741889986 ns 180.11344537815125 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17302 ns 17302 ns 1
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 12859.5 ns 17042 ns 0.75
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 40155 ns 36989 ns 1.09
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29455 ns 29144 ns 1.01
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20128 ns 21400 ns 0.94
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17092 ns 17453 ns 0.98
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 25708 ns 25457 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 1433.75 ns 3867.8125 ns 0.37
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 1502.8 ns 3941.125 ns 0.38
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 5062.428571428572 ns 4884.928571428572 ns 1.04
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1652.1 ns 1651.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 41281366 ns 38881796 ns 1.06
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 81650131 ns 58827311 ns 1.39
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 74702412 ns 67404189 ns 1.11
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 83501279 ns 90726308 ns 0.92
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 76199365 ns 72685526 ns 1.05
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12379610 ns 12054419 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 88476051 ns 88616409 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 13108690.5 ns 7678423 ns 1.71
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 13082438 ns 7587021 ns 1.72
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12352324 ns 10033851 ns 1.23
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6392433 ns 6401934.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 956998749.5 ns 695594050 ns 1.38
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 3132990623 ns 2591199265 ns 1.21
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 135905941 ns 145450054 ns 0.93
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 844957836.5 ns 856061822 ns 0.99
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2619641286 ns 2884249763 ns 0.91
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 200574138.5 ns 246325046.5 ns 0.81
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 779144261.5 ns 714210079 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 3378147776 ns 2711805409 ns 1.25
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 145184189 ns 127679277 ns 1.14
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 313742836 ns 172995580.5 ns 1.81
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 958140791 ns 651366871 ns 1.47
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 44772184 ns 45176285.5 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 314911863 ns 164593550 ns 1.91
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 952191410.5 ns 640463071 ns 1.49
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 43484498 ns 30093452 ns 1.44
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 202896030 ns 204835624 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 768590524 ns 720482648.5 ns 1.07
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35973525.5 ns 35672885 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1322910338.5 ns 1309190313 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1883620719 ns 1883793155 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2076680907.5 ns 2353304817 ns 0.88
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2417173032 ns 2438276652 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1812230966 ns 1867266298 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 1882161961 ns 2034616539 ns 0.93
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 310347613.5 ns 335394624.5 ns 0.93
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 314287672 ns 332856663 ns 0.94
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 348275884 ns 459008363 ns 0.76
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11909521 ns 11905264.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 28758051 ns 18205658.5 ns 1.58
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19213625 ns 19288136 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23935716 ns 23962152.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17890192 ns 18014189 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1147539 ns 1167993 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 30336085 ns 23074801.5 ns 1.31
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 3833100.5 ns 2441973 ns 1.57
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 3824597 ns 2232336 ns 1.71
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2080426 ns 2074976 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 198431 ns 199804 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 296484 ns 293199.5 ns 1.01
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 212307 ns 266209 ns 0.80
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 378467 ns 370003 ns 1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 419423 ns 411882 ns 1.02
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 282969 ns 275706.5 ns 1.03
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 413963 ns 409858 ns 1.01
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 399286 ns 396372 ns 1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 63989 ns 81443 ns 0.79
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 64590.5 ns 82123 ns 0.79
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 89928 ns 87013 ns 1.03
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104475 ns 104516 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 201561304 ns 197115275 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 337961507 ns 331248547.5 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 416899844 ns 427260324 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 494023070.5 ns 483420905 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 372638935 ns 386517355 ns 0.96
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 350916068 ns 338431242 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 436991512 ns 469112435 ns 0.93
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 49979608 ns 47409570 ns 1.05
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 50373882 ns 46895813 ns 1.07
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 50686798.5 ns 56448294.5 ns 0.90
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28140781 ns 28438815 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18964434 ns 19130483.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 23836283.5 ns 19606525 ns 1.22
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23235237 ns 23552054 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24255620 ns 24232581.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19727463 ns 19724625 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 21023299 ns 20980827 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 7570364 ns 6536465 ns 1.16
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 7902635 ns 6533604 ns 1.21
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6560144 ns 6545116 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Tracker and ReverseDiff both reuse the ChainRules rrule mechanism
and use Zygote to differentiate the nested AD part. If Zygote
isn't loaded it throws an error instructing users to install and
load Zygote.
@avik-pal avik-pal marked this pull request as draft August 2, 2024 03:00
@avik-pal
Copy link
Member Author

Doing this on a case by case basis. see #848

@avik-pal avik-pal closed this Aug 17, 2024
@avik-pal avik-pal deleted the ap/nested_ad_tracker branch August 17, 2024 17:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant