diff --git a/ext/StaticArraysExt.jl b/ext/StaticArraysExt.jl
new file mode 100644
index 0000000000..7887a2244b
--- /dev/null
+++ b/ext/StaticArraysExt.jl
@@ -0,0 +1,18 @@
+# compatibility with StaticArrays
+
+module StaticArraysExt
+
+using ..CUDA
+using ..CUDA: @device_override, @print_and_throw
+
+import StaticArrays
+
+# same quirk as for some Base methods in src/device/quirks.jl
+#
+# Override of `StaticArrays.dimension_mismatch_fail` registered via `@device_override`:
+# it reports the mismatch through `@print_and_throw` with a fixed message, so the `SA`
+# and `a` arguments are accepted but never inspected.
+@device_override @noinline StaticArrays.dimension_mismatch_fail(::Type{SA}, a::AbstractArray) where {SA<:StaticArrays.StaticArray} =
+    @print_and_throw("DimensionMismatch while trying to convert to StaticArray: Expected and actual length of input array differ.")
+
+end # extension module
diff --git a/src/CUDA.jl b/src/CUDA.jl
index bd539273ff..1fb628f0f9 100644
--- a/src/CUDA.jl
+++ b/src/CUDA.jl
@@ -122,6 +122,9 @@ include("CUDAKernels.jl")
 import .CUDAKernels: CUDABackend
 export CUDABackend
 
+# StaticArrays is still a direct dependency, so directly include the extension
+include("../ext/StaticArraysExt.jl")
+
 include("precompile.jl")
 
 end
diff --git a/test/libraries/staticarrays.jl b/test/libraries/staticarrays.jl
new file mode 100644
index 0000000000..98f1435c0d
--- /dev/null
+++ b/test/libraries/staticarrays.jl
@@ -0,0 +1,37 @@
+using LinearAlgebra: mul!
+using StaticArrays
+
+@testset "StaticArrays" begin
+    # GPU method: for every index `i` along the last dimension, multiply the slice
+    # `ms[:, :, i]` with the column `vs[:, i]` and store the product in `out[:, i]`.
+    function batched_matvec(ms::CuArray, vs::CuArray)
+        # Each thread derives its slice index `i` from its block/thread position; the
+        # sizes arrive as `Val`s so they can serve as static-array type parameters.
+        function matvec_kernel(out, ms, vs, ::Val{N}, ::Val{M}) where {N, M}
+            i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
+            # Call constructors without @inbounds.
+            # This asserts that the @device_override
+            # for StaticArrays.dimension_mismatch_fail() works.
+            m = SMatrix{N, M, Float32}(@view ms[:, :, i])
+            v = SVector{M, Float32}(@view vs[:, i])
+            out[:, i] .= m * v
+            nothing
+        end
+
+        out = similar(ms, (size(ms, 1), size(ms, 3)))
+        # request one thread per slice (`size(ms, 3)` threads)
+        @cuda threads=size(ms, 3) matvec_kernel(out, ms, vs, Val(size(ms, 1)), Val(size(ms, 2)))
+        out
+    end
+
+    # Generic fallback used as the reference: the same products computed with `mul!`.
+    function batched_matvec(ms, vs)
+        out = similar(ms, (size(ms, 1), size(ms, 3)))
+        foreach((o, m, v) -> mul!(o, m, v), eachcol(out), eachslice(ms; dims=3), eachcol(vs))
+        out
+    end
+
+    # 4 slices of 3×2 matrices with matching length-2 vectors
+    ms, vs = randn(Float32, 3, 2, 4), randn(Float32, 2, 4)
+    @test batched_matvec(ms, vs) ≈ Array(batched_matvec(cu(ms), cu(vs)))
+end