Skip to content

Commit

Permalink
Simplify some more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 20, 2020
1 parent 61f4bd2 commit 9b4d267
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
5 changes: 2 additions & 3 deletions test/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ end
@testset "stripping invariant.load" begin
function kernel(ptr, x)
i = CUDA.threadIdx_x()
@inbounds unsafe_store!(ptr, x[i], 1)
@inbounds ptr[] = x[i]
return
end

arr = CuArray(zeros(Float64))
ptr = pointer(arr)

@cuda kernel(ptr, (1., 2., ))
@cuda kernel(arr, (1., 2., ))
@test Array(arr)[] == 1.
end

Expand Down
16 changes: 7 additions & 9 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ len = prod(dims)
function kernel(input, output)
i = (blockIdx().x-1) * blockDim().x + threadIdx().x

val = unsafe_load(input, i)
unsafe_store!(output, val, i)
val = input[i]
output[i] = val

return
end
Expand All @@ -183,7 +183,7 @@ len = prod(dims)
input_dev = CuArray(input)
output_dev = CuArray(output)

@cuda threads=len kernel(pointer(input_dev), pointer(output_dev))
@cuda threads=len kernel(input_dev, output_dev)
@test input Array(output_dev)
end

Expand Down Expand Up @@ -493,21 +493,19 @@ end
@testset "argument count" begin
val = [0]
val_dev = CuArray(val)
cuda_ptr = pointer(val_dev)
ptr = CUDA.DevicePtr{Int}(cuda_ptr)
for i in (1, 10, 20, 35)
for i in (1, 10, 20, 34)
variables = ('a':'z'..., 'A':'Z'...)
params = [Symbol(variables[j]) for j in 1:i]
# generate a kernel
body = quote
function kernel($(params...))
unsafe_store!($ptr, $(Expr(:call, :+, params...)))
function kernel(arr, $(params...))
arr[] = $(Expr(:call, :+, params...))
return
end
end
eval(body)
args = [j for j in 1:i]
call = Expr(:call, :kernel, args...)
call = Expr(:call, :kernel, val_dev, args...)
cudacall = :(@cuda $call)
eval(cudacall)
@test Array(val_dev)[1] == sum(args)
Expand Down

0 comments on commit 9b4d267

Please sign in to comment.