Skip to content

Commit

Permalink
Merge pull request JuliaGPU#204 from JuliaGPU/tb/unsafe_wrappers
Browse files Browse the repository at this point in the history
Generate unsafe wrappers that return the status instead of throwing.
  • Loading branch information
maleadt committed Feb 6, 2020
2 parents 181808f + 9ccfc4b commit e47a654
Show file tree
Hide file tree
Showing 10 changed files with 1,155 additions and 1,140 deletions.
6 changes: 4 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ version = "0.2.0"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "56a813440ac98a1aa64672ab460a1512552211a7"
git-tree-sha1 = "7248357aafe1755b5416ae72053d50b59f78eaa6"
repo-rev = "08e9c61"
repo-url = "https://github.com/JuliaGPU/CUDAapi.jl.git"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "2.1.0"
version = "3.0.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "CUDAdrv"
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
version = "5.1"
version = "5.1.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -9,7 +9,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
CEnum = "0.2"
CUDAapi = "2.1"
CUDAapi = "3.0"
julia = "1.0"

[extras]
Expand Down
28 changes: 23 additions & 5 deletions res/wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,36 @@ mutable struct State
edits::Vector{Edit}
end

# insert `@check` before each `ccall` when it returns a checked type
# insert `@checked` before each function with a `ccall` returning a checked type`
const checked_types = [
"CUresult",
]
function insert_check(x, state)
if x isa CSTParser.EXPR && x.typ == CSTParser.Call && x.args[1].val == "ccall"
if x isa CSTParser.EXPR && x.typ == CSTParser.FunctionDef
_, def, body, _ = x.args
@assert body isa CSTParser.EXPR && body.typ == CSTParser.Block
@assert length(body.args) == 1

# Clang.jl-generated ccalls should be directly part of a function definition
call = body.args[1]
@assert call isa CSTParser.EXPR && call.typ == CSTParser.Call && call.args[1].val == "ccall"

# get the ccall return type
rv = x.args[5]
rv = call.args[5]

if rv.val in checked_types
push!(state.edits, Edit(state.offset, "@check "))
push!(state.edits, Edit(state.offset, "@checked "))
end
end
end

# rewrite ordinary `ccall`s to `@runtime_ccall`
function rewrite_ccall(x, state)
if x isa CSTParser.EXPR && x.typ == CSTParser.Call && x.args[1].val == "ccall"
push!(state.edits, Edit(state.offset, "@runtime_"))
end
end


## indenting passes

Expand Down Expand Up @@ -171,7 +186,7 @@ function wrap_at_comma(x, state, indent, offset, column)
end

function indent_ccall(x, state)
if x isa CSTParser.EXPR && x.typ == CSTParser.Call && x.args[1].val == "ccall"
if x isa CSTParser.EXPR && x.typ == CSTParser.MacroCall && x.args[1].args[2].val == "runtime_ccall"
# figure out how much to indent by looking at where the expr starts
line = findlast(y -> state.offset >= y[2], state.lines) # index, not the actual number
line_indent, line_offset = state.lines[line]
Expand Down Expand Up @@ -235,6 +250,9 @@ function process(name, headers...; kwargs...)
state.offset = 0
pass(ast, state, insert_check)

state.offset = 0
pass(ast, state, rewrite_ccall)

# apply
state.offset = 0
sort!(state.edits, lt = (a,b) -> first(a.loc) < first(b.loc), rev = true)
Expand Down
2 changes: 0 additions & 2 deletions src/CUDAdrv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ include("error.jl")
include("libcuda.jl")
include("libcuda_aliases.jl")

include("util.jl")

# high-level wrappers
include("version.jl")
include("devices.jl")
Expand Down
30 changes: 23 additions & 7 deletions src/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@ end

Base.show(io::IO, err::CuError) = print(io, "CuError($(err.code))")

# define shorthands that give CuError objects
for code in instances(cudaError_enum)
local name = String(Symbol(code))
shorthand = Symbol(name[6:end]) # strip the CUDA_ prefix
@eval const $shorthand = CuError($code)
end
@enum_without_prefix cudaError_enum CUDA_


## API call wrapper
Expand Down Expand Up @@ -136,6 +131,27 @@ end
throw(CuError(res))
end

macro checked(ex)
# parse the function definition
@assert Meta.isexpr(ex, :function)
sig = ex.args[1]
@assert Meta.isexpr(sig, :call)
body = ex.args[2]
@assert Meta.isexpr(body, :block)
@assert length(body.args) == 2 # line number node and a single call

# generate a "safe" version that performs a check
safe_body = Expr(:block, body.args[1], :(@check $(body.args[2])))
safe_sig = Expr(:call, sig.args[1], sig.args[2:end]...)
safe_def = Expr(:function, safe_sig, safe_body)

# generate a "unsafe" version that returns the error code instead
unsafe_sig = Expr(:call, Symbol("unsafe_", sig.args[1]), sig.args[2:end]...)
unsafe_def = Expr(:function, unsafe_sig, body)

return esc(:($safe_def, $unsafe_def))
end

macro check(ex)
fun = Symbol(decode_ccall_function(ex))
init = if !in(fun, preinit_apicalls)
Expand All @@ -145,7 +161,7 @@ macro check(ex)
$init

res = $(esc(ex))
if res != CUDA_SUCCESS
if res != SUCCESS
throw_api_error(res)
end

Expand Down
Loading

0 comments on commit e47a654

Please sign in to comment.