From 076f14b73771babf2abb1c5765d880364b7e20d4 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 4 Jun 2019 11:20:44 -0400 Subject: [PATCH] Make it possible to override the name of the compiled kernel (#414) --- Project.toml | 2 +- src/compiler/common.jl | 7 ++++--- src/compiler/irgen.jl | 8 ++++++-- src/device/runtime.jl | 2 +- src/execution.jl | 11 +++++++---- src/reflection.jl | 3 ++- test/device/execution.jl | 6 ++++++ 7 files changed, 27 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index f86653f17e..4b847fd24f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CUDAnative" uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" -version = "2.1.2" +version = "2.1.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/compiler/common.jl b/src/compiler/common.jl index 021b762ee0..79b9acd4ac 100644 --- a/src/compiler/common.jl +++ b/src/compiler/common.jl @@ -12,11 +12,12 @@ struct CompilerJob maxthreads::Union{Nothing,CuDim} blocks_per_sm::Union{Nothing,Integer} maxregs::Union{Nothing,Integer} + name::Union{Nothing,String} - CompilerJob(f, tt, cap, kernel; + CompilerJob(f, tt, cap, kernel; name=nothing, minthreads=nothing, maxthreads=nothing, blocks_per_sm=nothing, maxregs=nothing) = - new(f, tt, cap, kernel, minthreads, maxthreads, blocks_per_sm, maxregs) + new(f, tt, cap, kernel, minthreads, maxthreads, blocks_per_sm, maxregs, name) end # global job reference @@ -26,7 +27,7 @@ current_job = nothing function signature(job::CompilerJob) - fn = typeof(job.f).name.mt.name + fn = something(job.name, nameof(job.f)) args = join(job.tt.parameters, ", ") return "$fn($(join(job.tt.parameters, ", ")))" end diff --git a/src/compiler/irgen.jl b/src/compiler/irgen.jl index 766b3b95b1..f9c26751e9 100644 --- a/src/compiler/irgen.jl +++ b/src/compiler/irgen.jl @@ -12,7 +12,7 @@ end # make function names safe for PTX safe_fn(fn::String) = replace(fn, r"[^A-Za-z0-9_]"=>"_") -safe_fn(f::Core.Function) = safe_fn(String(typeof(f).name.mt.name)) +safe_fn(f::Core.Function) = safe_fn(String(nameof(f))) safe_fn(f::LLVM.Function) = safe_fn(LLVM.name(f)) # generate a pseudo-backtrace from a stack of methods being emitted @@ -211,7 +211,11 @@ function irgen(job::CompilerJob, method_instance::Core.MethodInstance, world) end # rename the entry point - llvmfn = replace(LLVM.name(entry), r"_\d+$"=>"") + if job.name !== nothing + llvmfn = safe_fn(string("julia_", job.name)) + else + llvmfn = replace(LLVM.name(entry), r"_\d+$"=>"") + end ## append a global unique counter global globalUnique globalUnique += 1 diff --git a/src/device/runtime.jl b/src/device/runtime.jl index ee6094773b..b3addaf05e 100644 --- a/src/device/runtime.jl +++ b/src/device/runtime.jl @@ -67,7 +67,7 @@ get(name::Symbol) = methods[name] # you can always specify `llvm_name` to influence that. Never use an LLVM name that starts # with `julia_` or the function might clash with other compiled functions. function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=nothing; - name=typeof(def).name.mt.name, llvm_name="ptx_$name") + name=nameof(def), llvm_name="ptx_$name") meth = RuntimeMethodInstance(def, return_type, types, name, llvm_return_type, llvm_types, llvm_name) diff --git a/src/execution.jl b/src/execution.jl index e9daa2fdec..32001b8425 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -9,7 +9,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize # the code it generates, or the execution function split_kwargs(kwargs) macro_kws = [:dynamic] - compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs] + compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name] call_kws = [:cooperative, :blocks, :threads, :shmem, :stream] macro_kwargs = [] compiler_kwargs = [] @@ -338,12 +338,13 @@ The following keyword arguments are supported: multiprocessor - `maxregs`: the maximum number of registers to be allocated to a single thread (only supported on LLVM 4.0+) +- `name`: override the name that the kernel will have in the generated code The output of this function is automatically cached, i.e. you can simply call `cufunction` in a hot path without degrading performance. New code will be generated automatically, when when function changes, or when different types or keyword arguments are provided. """ -@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) +@generated function cufunction(f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...) tt = Base.to_tuple_type(tt.parameters[1]) sig = Base.signature_type(f, tt) t = Tuple(tt.parameters) @@ -367,6 +368,7 @@ when function changes, or when different types or keyword arguments are provided ctx = CuCurrentContext() key = hash(age, $precomp_key) key = hash(ctx, key) + key = hash(name, key) key = hash(kwargs, key) for nf in 1:nfields(f) # mix in the values of any captured variable @@ -375,13 +377,14 @@ when function changes, or when different types or keyword arguments are provided if !haskey(compilecache, key) dev = device(ctx) cap = supported_capability(dev) - fun, mod = compile(:cuda, cap, f, tt; kwargs...) + fun, mod = compile(:cuda, cap, f, tt; name=name, kwargs...) kernel = HostKernel{f,tt}(ctx, mod, fun) @debug begin ver = version(kernel) mem = memory(kernel) reg = registers(kernel) - """Compiled $f to PTX $(ver.ptx) for SM $(ver.binary) using $reg registers. + fn = something(name, nameof(f)) + """Compiled $fn to PTX $(ver.ptx) for SM $(ver.binary) using $reg registers. Memory usage: $(Base.format_bytes(mem.local)) local, $(Base.format_bytes(mem.shared)) shared, $(Base.format_bytes(mem.constant)) constant""" end compilecache[key] = kernel diff --git a/src/reflection.jl b/src/reflection.jl index 167098d999..3c040ad017 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -282,7 +282,8 @@ Evaluates the expression `ex` and dumps all intermediate forms of code to the di macro device_code(ex...) only(xs) = (@assert length(xs) == 1; first(xs)) function hook(job::CompilerJob; dir::AbstractString) - fn = "$(typeof(job.f).name.mt.name)_$(globalUnique+1)" + name = something(job.name, nameof(job.f)) + fn = "$(name)_$(globalUnique+1)" mkpath(dir) open(joinpath(dir, "$fn.lowered.jl"), "w") do io diff --git a/test/device/execution.jl b/test/device/execution.jl index 0d93fb9c41..74e03eb04b 100644 --- a/test/device/execution.jl +++ b/test/device/execution.jl @@ -69,6 +69,12 @@ end end @test occursin("Body::Union{}", err) end + + # set name of kernel + @test occursin("ptxcall_mykernel", sprint(io->(@device_code_llvm io=io begin + k = cufunction(dummy, name="mykernel") + k() + end))) end