Add EnzymeRules support for CUDA.jl (for forward mode here) #1811

Closed
wsmoses opened this issue Mar 19, 2023 · 2 comments
Labels: enhancement (New feature or request)

Comments

wsmoses (Contributor) commented Mar 19, 2023

cc @vchuravy

Requires the next EnzymeCore version bump.

```julia
using Enzyme, CUDA
using Enzyme: EnzymeRules

# Enzyme.API.printall!(true)

function square_kernel!(x)
    i = threadIdx().x
    x[i] *= x[i]
    sync_threads()
    return nothing
end

# basic squaring on GPU
function square!(x)
    @cuda blocks = (1,) threads = length(x) square_kernel!(x)
    return x
end

# Group consecutive (primal, shadow) arguments into Duplicated values (used on the device)
@inline makedup() = ()
@inline makedup(a, b, args...) = (Duplicated(a, b), makedup(args...)...)

# Inverse of makedup: flatten Duplicated values into (primal, shadow) arguments (used on the host)
@inline dedup() = ()
@inline dedup(a::Duplicated{T}, args...) where T = (a.val, a.dval, dedup(args...)...)

# Device-side wrapper: forward-mode differentiate `fn` over the regrouped arguments
function metaf(fn, args...)
    Enzyme.autodiff_deferred(Forward, fn, Const, makedup(args...)...)
    nothing
end

# Rule for `cufunction`: compile `metaf` for a signature in which every kernel
# argument type appears twice (primal, shadow), then wrap the compiled function
# in a HostKernel typed as the original kernel
function EnzymeRules.forward(ofn::Const{typeof(CUDA.cufunction)}, ::Type{<:Duplicated}, f::Const{F}, tt::Const{TT}; kwargs...) where {F, TT}
    T2 = (F,)
    for p in tt.val.parameters
        T2 = (T2..., p, p)
    end
    TT2 = Tuple{T2...}
    res = ofn.val(typeof(metaf), TT2; kwargs...)
    pres = CUDA.HostKernel{F, tt.val}(f.val, res.fun, res.state)
    return Duplicated(pres, pres)
end

# Rule for launching the kernel: reinterpret the compiled function as a `metaf`
# kernel and call it with the original kernel function and the flattened arguments
function EnzymeRules.forward(ofn::Duplicated{CUDA.HostKernel{F, TT}}, ::Type{Const{Nothing}}, args...; kwargs...) where {F, TT}
    T2 = (F,)
    for p in TT.parameters
        T2 = (T2..., p, p)
    end
    TT2 = Tuple{T2...}
    c2 = CUDA.HostKernel{typeof(metaf), TT2}(metaf, ofn.val.fun, ofn.val.state)

    res = c2(ofn.val.f, dedup(args...)...; kwargs...)

    return nothing
end

x = CUDA.rand(3)
dx = CUDA.ones(3)
# square!(x)

@show x
Enzyme.autodiff(Forward, square!, Duplicated(x, dx))
@show x, dx
```

Output:

```
x = Float32[0.99226356, 0.6712043, 0.34854883]
(x, dx) = (Float32[0.98458695, 0.45051524, 0.121486284], Float32[1.9845271, 1.3424087, 0.69709766])
```
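Both rules depend on doubling the argument list: the compiled kernel takes the original kernel function followed by a (primal, shadow) pair for every original argument. The sketch below mirrors only that bookkeeping in plain Julia so it runs without Enzyme or a GPU. `Dup`, `doubled_tuple_type`, `flatten`, and `pair_up` are hypothetical names introduced for illustration; they are not part of Enzyme or CUDA.jl.

```julia
# Hypothetical stand-in for Enzyme's `Duplicated`, so this runs without Enzyme.
struct Dup{T}
    val::T
    dval::T
end

# Same loop as in the rules above: after the function type, every
# argument type appears twice (once for the primal, once for the shadow).
function doubled_tuple_type(::Type{F}, ::Type{TT}) where {F, TT<:Tuple}
    T2 = (F,)
    for p in TT.parameters
        T2 = (T2..., p, p)
    end
    return Tuple{T2...}
end

# Host side (like `dedup`): flatten pairs into a plain argument list.
flatten() = ()
flatten(a::Dup, args...) = (a.val, a.dval, flatten(args...)...)

# Device side (like `makedup`): regroup consecutive values into pairs.
pair_up() = ()
pair_up(a, b, args...) = (Dup(a, b), pair_up(args...)...)

x, dx = Float32[1, 2, 3], Float32[1, 1, 1]
flat = flatten(Dup(x, dx))   # (x, dx)
paired = pair_up(flat...)    # (Dup(x, dx),)
```

In the issue's code, `dedup` plays the role of `flatten` when the kernel is launched from the host, and `makedup` plays the role of `pair_up` inside `metaf` on the device, so the two helpers must stay exact inverses of each other.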
wsmoses added the enhancement label on Mar 19, 2023
maleadt (Member) commented Apr 27, 2024

I guess this is covered by some of the Enzyme PRs?

maleadt closed this as completed on Apr 27, 2024
wsmoses (Contributor, author) commented Apr 27, 2024

No, this is distinct (and should be rebased).
