From e13899f8b5d36299a86392c16a19c72b0cc01305 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 May 2024 04:13:00 -0700 Subject: [PATCH 1/3] Backport Enzyme changes (#2375) Co-authored-by: Valentin Churavy Co-authored-by: Tim Besard --- Project.toml | 5 + ext/EnzymeCoreExt.jl | 196 +++++++++++++++++++++++++++++++++++++++ src/initialization.jl | 3 + test/Project.toml | 3 + test/libraries/enzyme.jl | 56 +++++++++++ test/setup.jl | 3 + 6 files changed, 266 insertions(+) create mode 100644 ext/EnzymeCoreExt.jl create mode 100644 test/libraries/enzyme.jl diff --git a/Project.toml b/Project.toml index c93a53b43a..35c19d89e6 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" @@ -37,10 +38,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] ChainRulesCoreExt = "ChainRulesCore" +EnzymeCoreExt = "EnzymeCore" SpecialFunctionsExt = "SpecialFunctions" [compat] @@ -55,6 +58,7 @@ ChainRulesCore = "1" Crayons = "4" DataFrames = "1" ExprTools = "0.1" +EnzymeCore = "0.7.1" GPUArrays = "10.0.1" GPUCompiler = "0.24, 0.25, 0.26" KernelAbstractions = "0.9.2" @@ -81,4 +85,5 @@ julia = "1.8" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl new file mode 100644 index 0000000000..4ca3c77614 --- /dev/null +++ b/ext/EnzymeCoreExt.jl @@ -0,0 +1,196 @@ +# compatibility with EnzymeCore + +module EnzymeCoreExt + +using CUDA +import CUDA: GPUCompiler, CUDABackend + +if isdefined(Base, :get_extension) + using EnzymeCore + using EnzymeCore.EnzymeRules +else + using ..EnzymeCore + using ..EnzymeCore.EnzymeRules +end + +function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type)) + mi = GPUCompiler.methodinstance(F, TT) + return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device())) +end + +function metaf(fn, args::Vararg{Any, N}) where N + EnzymeCore.autodiff_deferred(Forward, fn, Const, args...) + nothing +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, + ::Type{<:Duplicated}, f::Const{F}, + tt::Const{TT}; kwargs...) where {F,TT} + res = ofn.val(f.val, tt.val; kwargs...) + return Duplicated(res, res) +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, + ::Type{BatchDuplicated{T,N}}, f::Const{F}, + tt::Const{TT}; kwargs...) where {F,TT,T,N} + res = ofn.val(f.val, tt.val; kwargs...) + return BatchDuplicated(res, ntuple(Val(N)) do _ + res + end) +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)}, + ::Type{RT}, x::IT) where {RT, IT} + if RT <: Duplicated + return Duplicated(ofn.val(x.val), ofn.val(x.dval)) + elseif RT <: Const + return ofn.val(x.val) + elseif RT <: DuplicatedNoNeed + return ofn.val(x.val) + else + tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i + Base.@_inline_meta + ofn.val(x.dval[i]) + end + if RT <: BatchDuplicated + return BatchDuplicated(ofv.val(x.val), tup) + else + return tup + end + end +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)}, + ::Type{RT}, args::NTuple{N, EnzymeCore.Annotation}; kwargs...) where {RT, N} + pargs = ntuple(Val(N)) do i + Base.@_inline_meta + args.val + end + res = ofn.val(pargs...; kwargs...) + + if RT <: Duplicated + return Duplicated(res, res) + elseif RT <: Const + return res + elseif RT <: DuplicatedNoNeed + return res + else + tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i + Base.@_inline_meta + res + end + if RT <: BatchDuplicated + return BatchDuplicated(res, tup) + else + return tup + end + end +end + +function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}}, + ::Type{Const{Nothing}}, args...; + kwargs...) where {F,TT} + + GC.@preserve args begin + args = ((cudaconvert(a) for a in args)...,) + T2 = (F, (typeof(a) for a in args)...) + TT2 = Tuple{T2...} + cuf = cufunction(metaf, TT2) + res = cuf(ofn.val.f, args...; kwargs...) + end + + return nothing +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} + if A isa Const || A isa Duplicated || A isa BatchDuplicated + ofn.val(A.val, x.val) + end + + if A isa Duplicated || A isa DuplicatedNoNeed + ofn.val(A.dval, x isa Const ? zero(T) : x.dval) + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + ofn.val(A.dval[i], x isa Const ? zero(T) : x.dval[i]) + nothing + end + end + + if RT <: Duplicated + return A + elseif RT <: Const + return A.val + elseif RT <: DuplicatedNoNeed + return A.dval + elseif RT <: BatchDuplicated + return A + else + return A.dval + end +end + + +function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} + if A isa Const || A isa Duplicated || A isa BatchDuplicated + ofn.val(A.val, x.val) + end + + if !(T <: AbstractFloat) + if A isa Duplicated || A isa DuplicatedNoNeed + ofn.val(A.dval, zero(T)) + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + ofn.val(A.dval[i], zero(T)) + nothing + end + end + end + + primal = if EnzymeRules.needs_primal(config) + A.val + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + A.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, tape, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x::EnzymeCore.Annotation{T2}) where {RT, T <: CUDA.MemsetCompatTypes, T2} + dx = if x isa Active + if A isa Duplicated || A isa DuplicatedNoNeed + T2(sum(A.dval)) + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + T2(sum(A.dval[i])) + end + end + else + nothing + end + + # re-zero shadow + if (T <: AbstractFloat) + if A isa Duplicated || A isa DuplicatedNoNeed + ofn.val(A.dval, zero(T)) + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + ofn.val(A.dval[i], zero(T)) + nothing + end + end + end + + return (nothing, dx) +end + +end # module + diff --git a/src/initialization.jl b/src/initialization.jl index 212ee0b13c..12425a129b 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -155,6 +155,9 @@ function __init__() @require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin include("../ext/SpecialFunctionsExt.jl") end + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" begin + include("../ext/EnzymeCoreExt.jl") + end end # ensure that operations executed by the REPL back-end finish before returning, diff --git a/test/Project.toml b/test/Project.toml index 0da63da920..04b15f52b8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,8 +7,11 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" diff --git a/test/libraries/enzyme.jl b/test/libraries/enzyme.jl new file mode 100644 index 0000000000..8a994232c1 --- /dev/null +++ b/test/libraries/enzyme.jl @@ -0,0 +1,56 @@ +using EnzymeCore +using GPUCompiler +using Enzyme + +@testset "compiler_job_from_backend" begin + @test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob +end + +function square_kernel!(x) + i = threadIdx().x + x[i] *= x[i] + sync_threads() + return nothing +end + +# basic squaring on GPU +function square!(x) + @cuda blocks = 1 threads = length(x) square_kernel!(x) + return nothing +end + +@testset "Forward Kernel" begin + A = CUDA.rand(64) + dA = CUDA.ones(64) + A .= (1:1:64) + dA .= 1 + Enzyme.autodiff(Forward, square!, Duplicated(A, dA)) + @test all(dA .≈ (2:2:128)) + + A = CUDA.rand(32) + dA = CUDA.ones(32) + dA2 = CUDA.ones(32) + A .= (1:1:32) + dA .= 1 + dA2 .= 3 + Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2))) + @test all(dA .≈ (2:2:64)) + @test all(dA2 .≈ 3*(2:2:64)) +end + +@testset "Forward Fill!" begin + A = CUDA.ones(64) + dA = CUDA.ones(64) + Enzyme.autodiff(Forward, fill!, Duplicated(A, dA), Duplicated(2.0, 3.0)) + @test all(A .≈ 2.0) + @test all(dA .≈ 3.0) +end + +@testset "Reverse Fill!" begin + A = CUDA.zeros(64) + dA = CUDA.ones(64) + res = Enzyme.autodiff(Reverse, fill!, Const, Duplicated(A, dA), Active(1.0))[1][2] + @test res ≈ 64 + @test all(A .≈ 1) + @test all(dA .≈ 0) +end diff --git a/test/setup.jl b/test/setup.jl index 492bd1cba0..c2ce7c1fbe 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -1,6 +1,9 @@ using Distributed, Test, CUDA using CUDA: i32 +# ensure CUDA.jl is functional +@assert CUDA.functional(true) + # GPUArrays has a testsuite that isn't part of the main package. # Include it directly. import GPUArrays From 8b6a2a43aa3c1947080eee84709d722ec31fe03b Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 15 May 2024 13:14:03 +0200 Subject: [PATCH 2/3] Bump version. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 35c19d89e6..1b5d7c95e1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CUDA" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.3.3" +version = "5.3.4" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From c373258510cabf295600e9e075140947ea73f407 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 15 May 2024 12:08:12 -0700 Subject: [PATCH 3/3] Remove EnzymeCore dependency (#2382) --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1b5d7c95e1..db2558153c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"