Skip to content

Commit

Permalink
Use at-runtime_ccall to protect against missing libraries.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Nov 12, 2019
1 parent b465e2b commit 8db7c2f
Show file tree
Hide file tree
Showing 5 changed files with 818 additions and 805 deletions.
12 changes: 12 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.2.0"

[[CUDAapi]]
deps = ["Libdl", "Logging"]
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "2.0.0"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ version = "4.0.3"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
CEnum = "0.2"
julia = "1"
CUDAapi = "2.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
32 changes: 12 additions & 20 deletions src/CUDAdrv.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module CUDAdrv

using CUDAapi

using CEnum

using Printf
Expand Down Expand Up @@ -43,8 +45,16 @@ functional() = __initialized__[]

function __init__()
try
# barrier to avoid compiling `ccall`s to unavailable libraries
inferencebarrier(__hidden_init__)()
if haskey(ENV, "_") && basename(ENV["_"]) == "rr"
error("Running under rr, which is incompatible with CUDA")
end

cuInit(0)

if version() <= v"9"
@warn "CUDAdrv.jl only supports NVIDIA drivers for CUDA 9.0 or higher (yours is for CUDA $(version()))"
end

__initialized__[] = true
catch ex
# don't actually fail to keep the package loadable
Expand All @@ -53,22 +63,4 @@ function __init__()
end
end

if VERSION >= v"1.3.0-DEV.35"
using Base: inferencebarrier
else
inferencebarrier(@nospecialize(x)) = Ref{Any}(x)[]
end

function __hidden_init__()
if haskey(ENV, "_") && basename(ENV["_"]) == "rr"
error("Running under rr, which is incompatible with CUDA")
end

cuInit(0)

if version() <= v"9"
@warn "CUDAdrv.jl only supports NVIDIA drivers for CUDA 9.0 or higher (yours is for CUDA $(version()))"
end
end

end
17 changes: 12 additions & 5 deletions src/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,18 @@ end
const apicall_hook = Ref{Union{Nothing,Function}}(nothing)

macro check(ex)
# check is used in front of `ccall`s that work on a tuple (fun, lib)
@assert Meta.isexpr(ex, :call)
@assert ex.args[1] == :ccall
@assert Meta.isexpr(ex.args[2], :tuple)
fun = String(ex.args[2].args[1].value)
# check is used in front of `ccall` or `@runtime_ccall`s that work on a tuple (fun, lib)
if Meta.isexpr(ex, :call)
@assert ex.args[1] == :ccall
@assert Meta.isexpr(ex.args[2], :tuple)
fun = String(ex.args[2].args[1].value)
elseif Meta.isexpr(ex, :macrocall)
@assert ex.args[1] == Symbol("@runtime_ccall")
@assert Meta.isexpr(ex.args[3], :tuple)
fun = String(ex.args[3].args[1].value)
else
error("@check should prefix ccall or @runtime_ccall")
end

# strip any version tag (e.g. cuEventDestroy_v2 -> cuEventDestroy)
m = match(r"_v\d+$", fun)
Expand Down
Loading

0 comments on commit 8db7c2f

Please sign in to comment.