cc @vchuravy
Requires next EnzymeCore bump.
```julia
using Enzyme, CUDA
using Enzyme: EnzymeRules

# Enzyme.API.printall!(true)  # uncomment for verbose Enzyme debug output

# Kernel: each thread squares one element in place
function square_kernel!(x)
    i = threadIdx().x
    x[i] *= x[i]
    sync_threads()
    return nothing
end

# basic squaring on GPU: one block, one thread per element
function square!(x)
    @cuda blocks = (1,) threads = length(x) square_kernel!(x)
    return x
end

# Pair interleaved (primal, shadow) arguments into Duplicated values
@inline makedup() = ()
@inline makedup(a, b, args...) = (Duplicated(a, b), makedup(args...)...)

# Flatten Duplicated arguments back into interleaved (primal, shadow) values
@inline dedup() = ()
@inline dedup(a::Duplicated{T}, args...) where T = (a.val, a.dval, dedup(args...)...)

# Device-side wrapper: runs `fn` in forward mode on the paired arguments
function metaf(fn, args...)
    Enzyme.autodiff_deferred(Forward, fn, Const, makedup(args...)...)
    nothing
end

# Rule for `cufunction`: compile `metaf` for the interleaved argument types,
# but hand back a HostKernel typed for the original kernel function
function EnzymeRules.forward(ofn::Const{typeof(CUDA.cufunction)}, ::Type{<:Duplicated},
                             f::Const{F}, tt::Const{TT}; kwargs...) where {F, TT}
    T2 = (F,)
    for p in tt.val.parameters
        T2 = (T2..., p, p)  # primal type followed by its shadow type
    end
    TT2 = Tuple{T2...}
    res = ofn.val(typeof(metaf), TT2; kwargs...)
    pres = CUDA.HostKernel{F, tt.val}(f.val, res.fun, res.state)
    return Duplicated(pres, pres)
end

# Rule for launching the kernel: call the compiled `metaf` kernel with the
# original function plus the interleaved primal and shadow arguments
function EnzymeRules.forward(ofn::Duplicated{CUDA.HostKernel{F, TT}}, ::Type{Const{Nothing}},
                             args...; kwargs...) where {F, TT}
    T2 = (F,)
    for p in TT.parameters
        T2 = (T2..., p, p)
    end
    TT2 = Tuple{T2...}
    c2 = CUDA.HostKernel{typeof(metaf), TT2}(metaf, ofn.val.fun, ofn.val.state)
    res = c2(ofn.val.f, dedup(args...)...; kwargs...)
    return nothing
end

x = CUDA.rand(3)
dx = CUDA.ones(3)  # tangent seed
# square!(x)
@show x
Enzyme.autodiff(Forward, square!, Duplicated(x, dx))
@show x, dx
```
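Both rules share one step: they build the argument-type tuple for `metaf` by repeating each kernel argument type, so the compiled kernel receives `(fn, primal1, shadow1, primal2, shadow2, ...)` and `makedup` can pair them back up. The plain-Julia sketch below isolates that step so it can be run without a GPU; the helper name `interleave_shadow_types` is hypothetical, and `typeof(sin)` and `Vector{Float32}` are stand-ins for the kernel function type and a device array type.

```julia
# Hypothetical helper mirroring the T2/TT2 loop in both forward rules.
# Stand-ins: typeof(sin) plays the kernel function type F,
# Vector{Float32} plays a device array argument type.
function interleave_shadow_types(::Type{F}, ::Type{TT}) where {F, TT<:Tuple}
    T2 = (F,)
    for p in TT.parameters
        # each primal argument type is followed by an identical shadow type
        T2 = (T2..., p, p)
    end
    return Tuple{T2...}
end

TT2 = interleave_shadow_types(typeof(sin), Tuple{Vector{Float32}, Int})
@show TT2
```

Because the shadow has the same type as the primal, only the argument count changes; the function type stays in front so `metaf(fn, args...)` sees the kernel as its first argument.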
Output:

```
x = Float32[0.99226356, 0.6712043, 0.34854883]
(x, dx) = (Float32[0.98458695, 0.45051524, 0.121486284], Float32[1.9845271, 1.3424087, 0.69709766])
```
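Since `square_kernel!` replaces each `x[i]` with `x[i]^2` and the seed `dx` is all ones, forward mode should leave `x[i]^2` in the primal and `2 * x[i]` in the tangent. A quick CPU-only check against the printed values (copied verbatim; `rtol` absorbs Float32 rounding):

```julia
# Values printed before and after the forward-mode call
x0       = Float32[0.99226356, 0.6712043, 0.34854883]
x_after  = Float32[0.98458695, 0.45051524, 0.121486284]
dx_after = Float32[1.9845271, 1.3424087, 0.69709766]

# primal: x^2; tangent: d/dx(x^2) * seed = 2x * 1
@assert isapprox(x0 .^ 2, x_after; rtol = 1e-5)
@assert isapprox(2 .* x0, dx_after; rtol = 1e-5)
println("forward-mode tangent matches 2x")
```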
I guess this is covered by some of the Enzyme PRs?
No, this is distinct (and should be rebased).