Skip to content

Commit

Permalink
Make it possible to override the name of the compiled kernel (JuliaGP…
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored and maleadt committed Jun 4, 2019
1 parent b891556 commit 076f14b
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "CUDAnative"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.1.2"
version = "2.1.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
7 changes: 4 additions & 3 deletions src/compiler/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ struct CompilerJob
maxthreads::Union{Nothing,CuDim}
blocks_per_sm::Union{Nothing,Integer}
maxregs::Union{Nothing,Integer}
name::Union{Nothing,String}

CompilerJob(f, tt, cap, kernel;
CompilerJob(f, tt, cap, kernel; name=nothing,
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing) =
new(f, tt, cap, kernel, minthreads, maxthreads, blocks_per_sm, maxregs)
new(f, tt, cap, kernel, minthreads, maxthreads, blocks_per_sm, maxregs, name)
end

# global job reference
Expand All @@ -26,7 +27,7 @@ current_job = nothing


function signature(job::CompilerJob)
fn = typeof(job.f).name.mt.name
fn = something(job.name, nameof(job.f))
args = join(job.tt.parameters, ", ")
return "$fn($(join(job.tt.parameters, ", ")))"
end
Expand Down
8 changes: 6 additions & 2 deletions src/compiler/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

# make function names safe for PTX
safe_fn(fn::String) = replace(fn, r"[^A-Za-z0-9_]"=>"_")
safe_fn(f::Core.Function) = safe_fn(String(typeof(f).name.mt.name))
safe_fn(f::Core.Function) = safe_fn(String(nameof(f)))
safe_fn(f::LLVM.Function) = safe_fn(LLVM.name(f))

# generate a pseudo-backtrace from a stack of methods being emitted
Expand Down Expand Up @@ -211,7 +211,11 @@ function irgen(job::CompilerJob, method_instance::Core.MethodInstance, world)
end

# rename the entry point
llvmfn = replace(LLVM.name(entry), r"_\d+$"=>"")
if job.name !== nothing
llvmfn = safe_fn(string("julia_", job.name))
else
llvmfn = replace(LLVM.name(entry), r"_\d+$"=>"")
end
## append a global unique counter
global globalUnique
globalUnique += 1
Expand Down
2 changes: 1 addition & 1 deletion src/device/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ get(name::Symbol) = methods[name]
# you can always specify `llvm_name` to influence that. Never use an LLVM name that starts
# with `julia_` or the function might clash with other compiled functions.
function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=nothing;
name=typeof(def).name.mt.name, llvm_name="ptx_$name")
name=nameof(def), llvm_name="ptx_$name")
meth = RuntimeMethodInstance(def,
return_type, types, name,
llvm_return_type, llvm_types, llvm_name)
Expand Down
11 changes: 7 additions & 4 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize
# the code it generates, or the execution
function split_kwargs(kwargs)
macro_kws = [:dynamic]
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs]
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name]
call_kws = [:cooperative, :blocks, :threads, :shmem, :stream]
macro_kwargs = []
compiler_kwargs = []
Expand Down Expand Up @@ -338,12 +338,13 @@ The following keyword arguments are supported:
multiprocessor
- `maxregs`: the maximum number of registers to be allocated to a single thread (only
supported on LLVM 4.0+)
- `name`: override the name that the kernel will have in the generated code
The output of this function is automatically cached, i.e. you can simply call `cufunction`
in a hot path without degrading performance. New code will be generated automatically
when the function changes, or when different types or keyword arguments are provided.
"""
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; kwargs...)
@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
tt = Base.to_tuple_type(tt.parameters[1])
sig = Base.signature_type(f, tt)
t = Tuple(tt.parameters)
Expand All @@ -367,6 +368,7 @@ when function changes, or when different types or keyword arguments are provided
ctx = CuCurrentContext()
key = hash(age, $precomp_key)
key = hash(ctx, key)
key = hash(name, key)
key = hash(kwargs, key)
for nf in 1:nfields(f)
# mix in the values of any captured variable
Expand All @@ -375,13 +377,14 @@ when function changes, or when different types or keyword arguments are provided
if !haskey(compilecache, key)
dev = device(ctx)
cap = supported_capability(dev)
fun, mod = compile(:cuda, cap, f, tt; kwargs...)
fun, mod = compile(:cuda, cap, f, tt; name=name, kwargs...)
kernel = HostKernel{f,tt}(ctx, mod, fun)
@debug begin
ver = version(kernel)
mem = memory(kernel)
reg = registers(kernel)
"""Compiled $f to PTX $(ver.ptx) for SM $(ver.binary) using $reg registers.
fn = something(name, nameof(f))
"""Compiled $fn to PTX $(ver.ptx) for SM $(ver.binary) using $reg registers.
Memory usage: $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, $(Base.format_bytes(mem.constant)) constant"""
end
compilecache[key] = kernel
Expand Down
3 changes: 2 additions & 1 deletion src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ Evaluates the expression `ex` and dumps all intermediate forms of code to the di
macro device_code(ex...)
only(xs) = (@assert length(xs) == 1; first(xs))
function hook(job::CompilerJob; dir::AbstractString)
fn = "$(typeof(job.f).name.mt.name)_$(globalUnique+1)"
name = something(job.name, nameof(job.f))
fn = "$(name)_$(globalUnique+1)"
mkpath(dir)

open(joinpath(dir, "$fn.lowered.jl"), "w") do io
Expand Down
6 changes: 6 additions & 0 deletions test/device/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ end
end
@test occursin("Body::Union{}", err)
end

# set name of kernel
@test occursin("ptxcall_mykernel", sprint(io->(@device_code_llvm io=io begin
k = cufunction(dummy, name="mykernel")
k()
end)))
end


Expand Down

0 comments on commit 076f14b

Please sign in to comment.