Skip to content

Commit

Permalink
cfunction macro: extend cfunction capabilities
Browse files Browse the repository at this point in the history
Provide static support for handling dynamic calls and closures
  • Loading branch information
vtjnash committed Apr 4, 2018
1 parent 2ba69e8 commit 81f6524
Show file tree
Hide file tree
Showing 35 changed files with 1,105 additions and 405 deletions.
44 changes: 38 additions & 6 deletions base/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,56 @@ respectively.
"""
cglobal

struct CFunction
ptr::Ptr{Cvoid}
f::Any
_1::Ptr{Cvoid}
_2::Ptr{Cvoid}
let construtor = false end
end
unsafe_convert(::Type{Ptr{Cvoid}}, cf::CFunction) = cf.ptr

"""
cfunction(f::Function, returntype::Type, argtypes::Type) -> Ptr{Cvoid}
@cfunction(callable, ReturnType, (ArgumentTypes...,)) -> Ptr{Cvoid}
@cfunction(\$callable, ReturnType, (ArgumentTypes...,)) -> CFunction
Generate a C-callable function pointer from the Julia function `closure`
for the given type signature.
Note that the argument type tuple must be a literal tuple, and not a tuple-valued variable or expression
(although it can include a splat expression). And that these arguments will be evaluated in global scope
during compile-time (not deferred until runtime).
Adding a `\$` in front of the function argument changes this to instead create a runtime closure
over the local variable `callable`.
Generate C-callable function pointer from the Julia function `f`. Type annotation of the return
value in the callback function is a must for situations where Julia cannot infer the return
type automatically.
See [manual section on ccall and cfunction usage](@ref Calling-C-and-Fortran-Code).
# Examples
```julia-repl
julia> function foo(x::Int, y::Int)
return x + y
end
julia> cfunction(foo, Int, Tuple{Int,Int})
julia> @cfunction(foo, Int, (Int, Int))
Ptr{Cvoid} @0x000000001b82fcd0
```
"""
cfunction(f, r, a) = ccall(:jl_function_ptr, Ptr{Cvoid}, (Any, Any, Any), f, r, a)
macro cfunction(f, at, rt)
if !(isa(rt, Expr) && rt.head === :tuple)
throw(ArgumentError("@cfunction argument types must be a literal tuple"))
end
rt.head = :call
pushfirst!(rt.args, GlobalRef(Core, :svec))
if isa(f, Expr) && f.head === :$
fptr = f.args[1]
typ = CFunction
else
fptr = QuoteNode(f)
typ = Ptr{Cvoid}
end
cfun = Expr(:cfunction, typ, fptr, at, rt, QuoteNode(:ccall))
return esc(cfun)
end

if ccall(:jl_is_char_signed, Ref{Bool}, ())
const Cchar = Int8
Expand Down
25 changes: 21 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,8 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt
return abstract_call_gf_by_type(f, argtypes, atype, sv)
end

function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState)
argtypes = Any[abstract_eval(a, vtypes, sv) for a in e.args]
# wrapper around `abstract_call` for first computing if `f` is available
function abstract_eval_call(fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
#print("call ", e.args[1], argtypes, "\n\n")
for x in argtypes
x === Bottom && return Bottom
Expand All @@ -689,7 +689,7 @@ function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState)
end
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv)
end
return abstract_call(f, e.args, argtypes, vtypes, sv)
return abstract_call(f, fargs, argtypes, vtypes, sv)
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
Expand Down Expand Up @@ -730,6 +730,18 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
return T
end

function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
f = abstract_eval(e.args[2], vtypes, sv)
# rt = sp_type_rewrap(e.args[3], sv.linfo, true)
at = Any[ sp_type_rewrap(argt, sv.linfo, false) for argt in e.args[4]::SimpleVector ]
pushfirst!(at, f)
# this may be the wrong world for the call,
# but some of the result is likely to be valid anyways
# and that may help generate better codegen
abstract_eval_call((), at, vtypes, sv)
nothing
end

function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
if isa(e, QuoteNode)
return AbstractEvalConstant((e::QuoteNode).value)
Expand All @@ -748,7 +760,8 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
end
e = e::Expr
if e.head === :call
t = abstract_eval_call(e, vtypes, sv)
argtypes = Any[ abstract_eval(a, vtypes, sv) for a in e.args ]
t = abstract_eval_call(e.args, argtypes, vtypes, sv)
elseif e.head === :new
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
for i = 2:length(e.args)
Expand All @@ -767,6 +780,10 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
t = Bottom
end
end
elseif e.head === :cfunction
t = e.args[1]
isa(t, Type) || (t = Any)
abstract_eval_cfunction(e, vtypes, sv)
elseif e.head === :static_parameter
n = e.args[1]
t = Any
Expand Down
15 changes: 12 additions & 3 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -769,17 +769,26 @@ function substitute!(
head = e.head
if head === :static_parameter
return quoted(spvals[e.args[1]])
elseif head === :cfunction
@assert !isa(spsig, UnionAll) || !isempty(spvals)
if !(e.args[2] isa QuoteNode) # very common no-op
e.args[2] = substitute!(e.args[2], na, argexprs, spsig, spvals, offset, boundscheck)
end
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt
in e.args[4] ]...)
elseif head === :foreigncall
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
argtuple = Any[
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt
in e.args[3] ]
e.args[3] = svec(argtuple...)
in e.args[3] ]...)
elseif i == 4
@assert isa((e.args[4]::QuoteNode).value, Symbol)
elseif i == 5
Expand Down
11 changes: 7 additions & 4 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}(
:meta => 0:typemax(Int),
:global => 1:1,
:foreigncall => 3:typemax(Int),
:cfunction => 6:6,
:isdefined => 1:1,
:simdloop => 0:0,
:gc_preserve_begin => 0:typemax(Int),
Expand Down Expand Up @@ -139,9 +140,11 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
end
validate_val!(x.args[1])
elseif head === :call || head === :invoke || head == :gc_preserve_end || head === :meta ||
head === :inbounds || head === :foreigncall || head === :const || head === :enter ||
head === :leave || head === :method || head === :global || head === :static_parameter ||
head === :new || head === :thunk || head === :simdloop || head === :throw_undef_if_not || head === :unreachable
head === :inbounds || head === :foreigncall || head === :cfunction ||
head === :const || head === :enter || head === :leave ||
head === :method || head === :global || head === :static_parameter ||
head === :new || head === :thunk || head === :simdloop ||
head === :throw_undef_if_not || head === :unreachable
validate_val!(x)
else
push!(errors, InvalidCodeError("invalid statement", x))
Expand Down Expand Up @@ -221,7 +224,7 @@ end

function is_valid_rvalue(lhs, x)
is_valid_argument(x) && return true
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :gc_preserve_begin)
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin)
return true
# TODO: disallow `globalref = call` when .typ field is removed
#return isa(lhs, SSAValue) || isa(lhs, Slot)
Expand Down
6 changes: 6 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ end

# PR #23066
@deprecate cfunction(f, r, a::Tuple) cfunction(f, r, Tuple{a...})
@noinline function cfunction(f, r, a)
@nospecialize(f, r, a)
depwarn("The function `cfunction` is now written as a macro `@cfunction`.", :cfunction)
return ccall(:jl_function_ptr, Ptr{Cvoid}, (Any, Any, Any), f, r, a)
end
export cfunction

# PR 23341
@eval GMP @deprecate gmp_version() version() false
Expand Down
2 changes: 1 addition & 1 deletion base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ export
withenv,

# C interface
cfunction,
@cfunction,
cglobal,
disable_sigint,
pointer,
Expand Down
19 changes: 13 additions & 6 deletions base/libuv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,20 @@ function process_events(block::Bool)
end
end

function uv_alloc_buf end
function uv_readcb end
function uv_writecb_task end
function uv_return_spawn end
function uv_asynccb end
function uv_timercb end

function reinit_stdio()
global uv_jl_alloc_buf = cfunction(uv_alloc_buf, Cvoid, Tuple{Ptr{Cvoid}, Csize_t, Ptr{Cvoid}})
global uv_jl_readcb = cfunction(uv_readcb, Cvoid, Tuple{Ptr{Cvoid}, Cssize_t, Ptr{Cvoid}})
global uv_jl_writecb_task = cfunction(uv_writecb_task, Cvoid, Tuple{Ptr{Cvoid}, Cint})
global uv_jl_return_spawn = cfunction(uv_return_spawn, Cvoid, Tuple{Ptr{Cvoid}, Int64, Int32})
global uv_jl_asynccb = cfunction(uv_asynccb, Cvoid, Tuple{Ptr{Cvoid}})
global uv_jl_timercb = cfunction(uv_timercb, Cvoid, Tuple{Ptr{Cvoid}})
global uv_jl_alloc_buf = @cfunction(uv_alloc_buf, Cvoid, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}))
global uv_jl_readcb = @cfunction(uv_readcb, Cvoid, (Ptr{Cvoid}, Cssize_t, Ptr{Cvoid}))
global uv_jl_writecb_task = @cfunction(uv_writecb_task, Cvoid, (Ptr{Cvoid}, Cint))
global uv_jl_return_spawn = @cfunction(uv_return_spawn, Cvoid, (Ptr{Cvoid}, Int64, Int32))
global uv_jl_asynccb = @cfunction(uv_asynccb, Cvoid, (Ptr{Cvoid},))
global uv_jl_timercb = @cfunction(uv_timercb, Cvoid, (Ptr{Cvoid},))

global uv_eventloop = ccall(:jl_global_event_loop, Ptr{Cvoid}, ())
global stdin = init_stdio(ccall(:jl_stdin_stream, Ptr{Cvoid}, ()))
Expand Down
61 changes: 32 additions & 29 deletions base/threadcall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ const max_ccall_threads = parse(Int, get(ENV, "UV_THREADPOOL_SIZE", "4"))
const thread_notifiers = Union{Condition, Nothing}[nothing for i in 1:max_ccall_threads]
const threadcall_restrictor = Semaphore(max_ccall_threads)

function notify_fun(idx)
global thread_notifiers
notify(thread_notifiers[idx])
return
end

"""
@threadcall((cfunc, clib), rettype, (argtypes...), argvals...)
Expand All @@ -36,62 +30,71 @@ macro threadcall(f, rettype, argtypes, argvals...)
argvals = map(esc, argvals)

# construct non-allocating wrapper to call C function
wrapper = :(function wrapper(args_ptr::Ptr{Cvoid}, retval_ptr::Ptr{Cvoid})
wrapper = :(function (args_ptr::Ptr{Cvoid}, retval_ptr::Ptr{Cvoid})
p = args_ptr
# the rest of the body is created below
end)
body = wrapper.args[2].args
args = Symbol[]
for (i,T) in enumerate(argtypes)
for (i, T) in enumerate(argtypes)
arg = Symbol("arg", i)
push!(body, :($arg = unsafe_load(convert(Ptr{$T}, p))))
push!(body, :(p += sizeof($T)))
push!(body, :(p += Core.sizeof($T)))
push!(args, arg)
end
push!(body, :(ret = ccall($f, $rettype, ($(argtypes...),), $(args...))))
push!(body, :(unsafe_store!(convert(Ptr{$rettype}, retval_ptr), ret)))
push!(body, :(return sizeof($rettype)))
push!(body, :(return Int(Core.sizeof($rettype))))

# return code to generate wrapper function and send work request thread queue
:(let
$wrapper
do_threadcall(wrapper, $rettype, Any[$(argtypes...)], Any[$(argvals...)])
wrapper = Expr(Symbol("hygienic-scope"), wrapper, @__MODULE__)
return :(let fun_ptr = @cfunction($wrapper, Int, (Ptr{Cvoid}, Ptr{Cvoid}))
do_threadcall(fun_ptr, $rettype, Any[$(argtypes...)], Any[$(argvals...)])
end)
end

function do_threadcall(wrapper::Function, rettype::Type, argtypes::Vector, argvals::Vector)
function do_threadcall(fun_ptr::Ptr{Cvoid}, rettype::Type, argtypes::Vector, argvals::Vector)
# generate function pointer
fun_ptr = cfunction(wrapper, Int, Tuple{Ptr{Cvoid}, Ptr{Cvoid}})
c_notify_fun = cfunction(notify_fun, Cvoid, Tuple{Cint})
c_notify_fun = @cfunction(
function notify_fun(idx)
global thread_notifiers
notify(thread_notifiers[idx])
return
end, Cvoid, (Cint,))

# cconvert, root and unsafe_convert arguments
roots = Any[]
args_size = isempty(argtypes) ? 0 : sum(sizeof, argtypes)
args_size = isempty(argtypes) ? 0 : sum(Core.sizeof, argtypes)
args_arr = Vector{UInt8}(undef, args_size)
ptr = pointer(args_arr)
for (T, x) in zip(argtypes, argvals)
isbits(T) || throw(ArgumentError("threadcall requires isbits argument types"))
y = cconvert(T, x)
push!(roots, y)
unsafe_store!(convert(Ptr{T}, ptr), unsafe_convert(T, y))
ptr += sizeof(T)
unsafe_store!(convert(Ptr{T}, ptr), unsafe_convert(T, y)::T)
ptr += Core.sizeof(T)
end

# create return buffer
ret_arr = Vector{UInt8}(undef, sizeof(rettype))
ret_arr = Vector{UInt8}(undef, Core.sizeof(rettype))

# wait for a worker thread to be available
acquire(threadcall_restrictor)
idx = findfirst(isequal(nothing), thread_notifiers)::Int
thread_notifiers[idx] = Condition()

# queue up the work to be done
ccall(:jl_queue_work, Cvoid,
(Ptr{Cvoid}, Ptr{UInt8}, Ptr{UInt8}, Ptr{Cvoid}, Cint),
fun_ptr, args_arr, ret_arr, c_notify_fun, idx)
GC.@preserve args_arr ret_arr roots begin
# queue up the work to be done
ccall(:jl_queue_work, Cvoid,
(Ptr{Cvoid}, Ptr{UInt8}, Ptr{UInt8}, Ptr{Cvoid}, Cint),
fun_ptr, args_arr, ret_arr, c_notify_fun, idx)

# wait for a result & return it
wait(thread_notifiers[idx])
thread_notifiers[idx] = nothing
release(threadcall_restrictor)
# wait for a result & return it
wait(thread_notifiers[idx])
thread_notifiers[idx] = nothing
release(threadcall_restrictor)

unsafe_load(convert(Ptr{rettype}, pointer(ret_arr)))
r = unsafe_load(convert(Ptr{rettype}, pointer(ret_arr)))
end
return r
end
2 changes: 1 addition & 1 deletion doc/src/base/c.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
```@docs
ccall
Core.Intrinsics.cglobal
Base.cfunction
Base.@cfunction
Base.unsafe_convert
Base.cconvert
Base.unsafe_load
Expand Down
Loading

0 comments on commit 81f6524

Please sign in to comment.