diff --git a/lib/cublas/CUBLAS.jl b/lib/cublas/CUBLAS.jl
index 5391f2932f..d119af70c6 100644
--- a/lib/cublas/CUBLAS.jl
+++ b/lib/cublas/CUBLAS.jl
@@ -71,9 +71,20 @@ function math_mode!(handle, mode)
     return
 end
 
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cublasHandle_t}()
-const idle_xt_handles = HandleCache{Any,cublasXtHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cublasCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cublasDestroy_v2(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cublasHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -86,20 +97,12 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cublasCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cublasDestroy_v2(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         cublasSetStream_v2(new_handle, cuda.stream)
 
-        math_mode!(new_handle, cuda.math_mode)
 
         (; handle=new_handle, cuda.stream, cuda.math_mode)
@@ -129,6 +132,34 @@ function handle()
     return state.handle
 end
 
+
+## xt handles
+
+function xt_handle_ctor(ctx)
+    context!(ctx) do
+        cublasXtCreate()
+    end
+end
+function xt_handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cublasXtDestroy(handle)
+    end
+end
+const idle_xt_handles =
+    HandleCache{CuContext,cublasXtHandle_t}(xt_handle_ctor, xt_handle_dtor)
+
+function devices!(devs::Vector{CuDevice})
+    task_local_storage(:CUBLASxt_devices, sort(devs; by=deviceid))
+    return
+end
+
+devices() = get!(task_local_storage(), :CUBLASxt_devices) do
+    # by default, select all devices
+    sort(collect(CUDA.devices()); by=deviceid)
+end::Vector{CuDevice}
+
+ndevices() = length(devices())
+
 function xt_handle()
     cuda = CUDA.active_state()
@@ -147,15 +178,9 @@ function xt_handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_xt_handles, cuda.context) do
-            cublasXtCreate()
-        end
-
+        new_handle = pop!(idle_xt_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_xt_handles, cuda.context, new_handle) do
-                # TODO: which context do we need to destroy this on?
-                cublasXtDestroy(new_handle)
-            end
+            push!(idle_xt_handles, cuda.context, new_handle)
         end
 
         devs = convert.(Cint, devices())
@@ -170,18 +195,6 @@ function xt_handle()
     return state.handle
 end
 
-function devices!(devs::Vector{CuDevice})
-    task_local_storage(:CUBLASxt_devices, sort(devs; by=deviceid))
-    return
-end
-
-devices() = get!(task_local_storage(), :CUBLASxt_devices) do
-    # by default, select all devices
-    sort(collect(CUDA.devices()); by=deviceid)
-end::Vector{CuDevice}
-
-ndevices() = length(devices())
-
 
 ## logging
diff --git a/lib/cudnn/Project.toml b/lib/cudnn/Project.toml
index 1b1da3192b..a596fde750 100644
--- a/lib/cudnn/Project.toml
+++ b/lib/cudnn/Project.toml
@@ -11,7 +11,7 @@ CUDNN_jll = "62b44479-cb7b-5706-934f-f13b2eb2e645"
 
 [compat]
 CEnum = "0.2, 0.3, 0.4, 0.5"
-CUDA = "~5.3, ~5.4"
+CUDA = "~5.4"
 CUDA_Runtime_Discovery = "0.2"
 CUDNN_jll = "~9.0"
 julia = "1.8"
diff --git a/lib/cudnn/src/cuDNN.jl b/lib/cudnn/src/cuDNN.jl
index 3dbfc4cbe5..ad5bfdaf7b 100644
--- a/lib/cudnn/src/cuDNN.jl
+++ b/lib/cudnn/src/cuDNN.jl
@@ -62,8 +62,20 @@ function math_mode(mode=CUDA.math_mode())
     end
 end
 
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cudnnHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cudnnCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cudnnDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cudnnHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -76,16 +88,9 @@ function handle()
 
     # get library state
    @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cudnnCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cudnnDestroy(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         cudnnSetStream(new_handle, cuda.stream)
diff --git a/lib/cufft/wrappers.jl b/lib/cufft/wrappers.jl
index 3ab01f1eb9..5d2c670082 100644
--- a/lib/cufft/wrappers.jl
+++ b/lib/cufft/wrappers.jl
@@ -43,7 +43,7 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region)
     end
     if ((region...,) == ((1:nrank)...,))
         # handle simple case, transforming the first nrank dimensions, ... simply! (for robustness)
-        # arguments are: plan, rank, transform-sizes, inembed, istride, idist, onembed, ostride, odist, type batch 
+        # arguments are: plan, rank, transform-sizes, inembed, istride, idist, onembed, ostride, odist, type batch
         cufftMakePlanMany(handle, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1, xtype, batch, worksize_ref)
     else
@@ -151,13 +151,12 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region)
 
     handle, worksize_ref[]
 end
 
-# plan cache
-const cufftHandleCacheKey = Tuple{CuContext, cufftType_t, Dims, Any}
-const idle_handles = HandleCache{cufftHandleCacheKey, cufftHandle}()
-function cufftGetPlan(args...)
-    ctx = context()
-    handle = pop!(idle_handles, (ctx, args...)) do
+## plan cache
+
+const cufftHandleCacheKey = Tuple{CuContext, cufftType_t, Dims, Any}
+function handle_ctor((ctx, args...))
+    context!(ctx) do
         # make the plan
         handle, worksize = cufftMakePlan(args...)
 
         # instead relying on the automatic allocation strategy.
         handle
     end
+end
+function handle_dtor((ctx, args...), handle)
+    context!(ctx; skip_destroyed=true) do
+        cufftDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{cufftHandleCacheKey, cufftHandle}(handle_ctor, handle_dtor)
+
+function cufftGetPlan(args...)
+    ctx = context()
+    handle = pop!(idle_handles, (ctx, args...))
 
-    # assign to the current stream
     cufftSetStream(handle, stream())
 
     return handle
 end
 
 function cufftReleasePlan(plan)
-    push!(idle_handles, plan) do
-        cufftDestroy(plan)
-    end
-
+    push!(idle_handles, plan)
 end
diff --git a/lib/curand/CURAND.jl b/lib/curand/CURAND.jl
index 1790652d04..10c8d4858e 100644
--- a/lib/curand/CURAND.jl
+++ b/lib/curand/CURAND.jl
@@ -21,8 +21,21 @@ include("wrappers.jl")
 
 # high-level integrations
 include("random.jl")
 
-# cache for created, but unused handles
-const idle_curand_rngs = HandleCache{CuContext,RNG}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        RNG()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        # no need to do anything, as the RNG is collected by its finalizer
+        # TODO: early free?
+    end
+end
+const idle_curand_rngs = HandleCache{CuContext,RNG}(handle_ctor, handle_dtor)
 
 function default_rng()
     cuda = CUDA.active_state()
@@ -35,17 +48,13 @@ function default_rng()
 
     # get library state
     @noinline function new_state(cuda)
-        new_rng = pop!(idle_curand_rngs, cuda.context) do
-            RNG()
-        end
-
+        new_rng = pop!(idle_curand_rngs, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_curand_rngs, cuda.context, new_rng) do
-                # no need to do anything, as the RNG is collected by its finalizer
-            end
+            push!(idle_curand_rngs, cuda.context, new_rng)
         end
 
         Random.seed!(new_rng)
+
         (; rng=new_rng)
     end
 
     state = get!(states, cuda.context) do
diff --git a/lib/cusolver/CUSOLVER.jl b/lib/cusolver/CUSOLVER.jl
index c91f32ada2..f923b91374 100644
--- a/lib/cusolver/CUSOLVER.jl
+++ b/lib/cusolver/CUSOLVER.jl
@@ -44,9 +44,21 @@ include("multigpu.jl")
 
 # high-level integrations
 include("linalg.jl")
 
-# cache for created, but unused handles
-const idle_dense_handles = HandleCache{CuContext,cusolverDnHandle_t}()
-const idle_sparse_handles = HandleCache{CuContext,cusolverSpHandle_t}()
+
+## dense handles
+
+function dense_handle_ctor(ctx)
+    context!(ctx) do
+        cusolverDnCreate()
+    end
+end
+function dense_handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cusolverDnDestroy(handle)
+    end
+end
+const idle_dense_handles =
+    HandleCache{CuContext,cusolverDnHandle_t}(dense_handle_ctor, dense_handle_dtor)
 
 function dense_handle()
     cuda = CUDA.active_state()
@@ -59,19 +71,13 @@ function dense_handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_dense_handles, cuda.context) do
-            cusolverDnCreate()
-        end
-
+        new_handle = pop!(idle_dense_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_dense_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cusolverDnDestroy(new_handle)
-                end
-            end
+            push!(idle_dense_handles, cuda.context, new_handle)
         end
 
         cusolverDnSetStream(new_handle, cuda.stream)
+
         (; handle=new_handle, cuda.stream)
     end
 
     state = get!(states, cuda.context) do
@@ -90,6 +96,22 @@ function dense_handle()
     return state.handle
 end
 
+
+## sparse handles
+
+function sparse_handle_ctor(ctx)
+    context!(ctx) do
+        cusolverSpCreate()
+    end
+end
+function sparse_handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cusolverSpDestroy(handle)
+    end
+end
+const idle_sparse_handles =
+    HandleCache{CuContext,cusolverSpHandle_t}(sparse_handle_ctor, sparse_handle_dtor)
+
 function sparse_handle()
     cuda = CUDA.active_state()
@@ -101,19 +123,13 @@ function sparse_handle()
 
     # get or create handle
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_sparse_handles, cuda.context) do
-            cusolverSpCreate()
-        end
-
+        new_handle = pop!(idle_sparse_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_sparse_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cusolverSpDestroy(new_handle)
-                end
-            end
+            push!(idle_sparse_handles, cuda.context, new_handle)
         end
 
         cusolverSpSetStream(new_handle, cuda.stream)
+
         (; handle=new_handle, cuda.stream)
     end
 
     state = get!(states, cuda.context) do
@@ -132,6 +148,23 @@ function sparse_handle()
     return state.handle
 end
 
+
+## mg handles
+
+function devices!(devs::Vector{CuDevice})
+    task_local_storage(:CUSOLVERmg_devices, sort(devs; by=deviceid))
+    return
+end
+
+devices() = get!(task_local_storage(), :CUSOLVERmg_devices) do
+    # by default, select only the first device
+    [first(CUDA.devices())]
+    # TODO: select all devices
+    #sort(collect(CUDA.devices()); by=deviceid)
+end::Vector{CuDevice}
+
+ndevices() = length(devices())
+
 function mg_handle()
     cuda = CUDA.active_state()
@@ -171,18 +204,4 @@ function mg_handle()
     return state.handle
 end
 
-function devices!(devs::Vector{CuDevice})
-    task_local_storage(:CUSOLVERmg_devices, sort(devs; by=deviceid))
-    return
-end
-
-devices() = get!(task_local_storage(), :CUSOLVERmg_devices) do
-    # by default, select only the first device
-    [first(CUDA.devices())]
-    # TODO: select all devices
-    #sort(collect(CUDA.devices()); by=deviceid)
-end::Vector{CuDevice}
-
-ndevices() = length(devices())
-
 end
diff --git a/lib/cusparse/CUSPARSE.jl b/lib/cusparse/CUSPARSE.jl
index 0aa9eb21fa..812059edde 100644
--- a/lib/cusparse/CUSPARSE.jl
+++ b/lib/cusparse/CUSPARSE.jl
@@ -54,8 +54,20 @@ include("reduce.jl")
 
 include("batched.jl")
 
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cusparseHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cusparseCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cusparseDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cusparseHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -68,16 +80,9 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cusparseCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cusparseDestroy(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
        end
 
         cusparseSetStream(new_handle, cuda.stream)
diff --git a/lib/custatevec/Project.toml b/lib/custatevec/Project.toml
index a35310ebe1..cecf4cbf09 100644
--- a/lib/custatevec/Project.toml
+++ b/lib/custatevec/Project.toml
@@ -11,7 +11,7 @@ cuQuantum_jll = "b75408ef-6fdf-5d74-b65a-7df000ad18e6"
 
 [compat]
 CEnum = "0.2, 0.3, 0.4, 0.5"
-CUDA = "~5.3, ~5.4"
+CUDA = "~5.4"
 CUDA_Runtime_Discovery = "0.2"
 cuQuantum_jll = "~24.03"
 julia = "1.8"
diff --git a/lib/custatevec/src/cuStateVec.jl b/lib/custatevec/src/cuStateVec.jl
index 0ec38e4ec0..8dfca844a3 100644
--- a/lib/custatevec/src/cuStateVec.jl
+++ b/lib/custatevec/src/cuStateVec.jl
@@ -1,7 +1,8 @@
 module cuStateVec
 
 using CUDA
-using CUDA: CUstream, cudaDataType, cudaEvent_t, @checked, HandleCache, with_workspace, libraryPropertyType
+using CUDA.APIUtils
+using CUDA: CUstream, cudaDataType, cudaEvent_t, libraryPropertyType
 using CUDA: unsafe_free!, retry_reclaim, initialize_context, isdebug
 
 using CEnum: @cenum
@@ -31,18 +32,34 @@ const cudaDataType_t = cudaDataType
 
 # core library
 include("libcustatevec.jl")
 
+# low-level wrappers
 include("error.jl")
 include("types.jl")
+include("wrappers.jl")
 include("statevec.jl")
 
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        custatevecCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        custatevecDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,custatevecHandle_t}(handle_ctor, handle_dtor)
+
+# fat handle, includes a cache
 struct cuStateVecHandle
     handle::custatevecHandle_t
     cache::CuVector{UInt8}
 end
-Base.unsafe_convert(::Type{Ptr{custatevecContext}}, handle::cuStateVecHandle) = handle.handle
-
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,custatevecHandle_t}()
+Base.unsafe_convert(::Type{Ptr{custatevecContext}}, handle::cuStateVecHandle) =
+    handle.handle
 
 function handle()
     cuda = CUDA.active_state()
@@ -55,24 +72,14 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            handle = Ref{custatevecHandle_t}()
-            custatevecCreate(handle)
-            handle[]
-        end
+        new_handle = pop!(idle_handles, cuda.context)
 
         cache = CuVector{UInt8}(undef, 0)
         fat_handle = cuStateVecHandle(new_handle, cache)
 
         finalizer(current_task()) do task
-            # wipe the cache when storing the handle
-            resize!(cache, 0)
-
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    custatevecDestroy(new_handle)
-                end
-            end
+            CUDA.unsafe_free!(cache)
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         custatevecSetStream(new_handle, cuda.stream)
@@ -95,14 +102,6 @@ function handle()
     return state.handle
 end
 
-function version()
-    ver = custatevecGetVersion()
-    major, ver = divrem(ver, 1000)
-    minor, patch = divrem(ver, 100)
-
-    VersionNumber(major, minor, patch)
-end
-
 
 ## logging
diff --git a/lib/custatevec/src/wrappers.jl b/lib/custatevec/src/wrappers.jl
new file mode 100644
index 0000000000..37b2a1b439
--- /dev/null
+++ b/lib/custatevec/src/wrappers.jl
@@ -0,0 +1,13 @@
+function custatevecCreate()
+    handle = Ref{custatevecHandle_t}()
+    custatevecCreate(handle)
+    return handle[]
+end
+
+function version()
+    ver = custatevecGetVersion()
+    major, ver = divrem(ver, 1000)
+    minor, patch = divrem(ver, 100)
+
+    VersionNumber(major, minor, patch)
+end
diff --git a/lib/cutensor/src/cuTENSOR.jl b/lib/cutensor/src/cuTENSOR.jl
index 89c5ec1ad9..c7751e8aa6 100644
--- a/lib/cutensor/src/cuTENSOR.jl
+++ b/lib/cutensor/src/cuTENSOR.jl
@@ -36,8 +36,20 @@ include("operations.jl")
 
 # high-level integrations
 include("interfaces.jl")
 
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cutensorHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cutensorCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cutensorDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cutensorHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -50,14 +62,9 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cutensorCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                cutensorDestroy(new_handle)
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         (; handle=new_handle)
diff --git a/lib/cutensornet/Project.toml b/lib/cutensornet/Project.toml
index b1f1e03055..ac738392af 100644
--- a/lib/cutensornet/Project.toml
+++ b/lib/cutensornet/Project.toml
@@ -13,9 +13,9 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 
 [compat]
 CEnum = "0.2, 0.3, 0.4, 0.5"
-CUDA = "~5.1, ~5.2, ~5.3, ~5.4"
+CUDA = "~5.4"
 CUDA_Runtime_Discovery = "0.2"
 cuQuantum_jll = "~24.3"
-cuTENSOR = "~2.0"
+cuTENSOR = "2"
 julia = "1.8"
 LinearAlgebra = "1"
diff --git a/lib/cutensornet/src/cuTensorNet.jl b/lib/cutensornet/src/cuTensorNet.jl
index d07219d5f4..4abfbd0495 100644
--- a/lib/cutensornet/src/cuTensorNet.jl
+++ b/lib/cutensornet/src/cuTensorNet.jl
@@ -2,7 +2,8 @@ module cuTensorNet
 
 using LinearAlgebra
 
 using CUDA
-using CUDA: CUstream, cudaDataType, @checked, HandleCache, with_workspace
+using CUDA.APIUtils
+using CUDA: CUstream, cudaDataType
 using CUDA: retry_reclaim, initialize_context, isdebug, cuDoubleComplex
 
 using cuTENSOR
@@ -35,12 +36,26 @@ const cudaDataType_t = cudaDataType
 
 # core library
 include("libcutensornet.jl")
 
+# low-level wrappers
 include("error.jl")
 include("types.jl")
+include("wrappers.jl")
 include("tensornet.jl")
 
-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cutensornetHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cutensornetCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cutensornetDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cutensornetHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -53,18 +68,9 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            handle = Ref{cutensornetHandle_t}()
-            cutensornetCreate(handle)
-            handle[]
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cutensornetDestroy(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         (; handle=new_handle)
@@ -76,22 +82,6 @@ function handle()
     return state.handle
 end
 
-function version()
-    ver = cutensornetGetVersion()
-    major, ver = divrem(ver, 10000)
-    minor, patch = divrem(ver, 100)
-
-    VersionNumber(major, minor, patch)
-end
-
-function cuda_version()
-    ver = cutensornetGetCudartVersion()
-    major, ver = divrem(ver, 1000)
-    minor, patch = divrem(ver, 10)
-
-    VersionNumber(major, minor, patch)
-end
-
 
 ## logging
diff --git a/lib/cutensornet/src/wrappers.jl b/lib/cutensornet/src/wrappers.jl
new file mode 100644
index 0000000000..1ef9c0f347
--- /dev/null
+++ b/lib/cutensornet/src/wrappers.jl
@@ -0,0 +1,21 @@
+function cutensornetCreate()
+    handle = Ref{cutensornetHandle_t}()
+    cutensornetCreate(handle)
+    handle[]
+end
+
+function version()
+    ver = cutensornetGetVersion()
+    major, ver = divrem(ver, 10000)
+    minor, patch = divrem(ver, 100)
+
+    VersionNumber(major, minor, patch)
+end
+
+function cuda_version()
+    ver = cutensornetGetCudartVersion()
+    major, ver = divrem(ver, 1000)
+    minor, patch = divrem(ver, 10)
+
+    VersionNumber(major, minor, patch)
+end
diff --git a/lib/utils/cache.jl b/lib/utils/cache.jl
index c74f4d3878..474427e89a 100644
--- a/lib/utils/cache.jl
+++ b/lib/utils/cache.jl
@@ -1,27 +1,32 @@
 # a cache for library handles
 
-# TODO:
-# - keep track of the (estimated?) size of cache contents
-# - clean the caches when memory is needed. this will require registering the destructor
-#   upfront, so that it can set the environment (e.g. switch to the appropriate context).
-#   alternatively, register the `unsafe_free!` methods with the pool instead of the cache.
-
 export HandleCache
 
 struct HandleCache{K,V}
-    active_handles::Set{Pair{K,V}}  # for debugging, and to prevent handle finalization
+    ctor
+    dtor
+
+    active_handles::Set{Pair{K,V}}
     idle_handles::Dict{K,Vector{V}}
-    lock::ReentrantLock
+    lock::Base.ThreadSynchronizer
+    # XXX: we use a thread-safe spinlock because the handle cache is used from finalizers;
+    #      once finalizers run on their own thread, use a regular ReentrantLock
 
     max_entries::Int
 
-    function HandleCache{K,V}(max_entries::Int=32) where {K,V}
-        return new{K,V}(Set{Pair{K,V}}(), Dict{K,Vector{V}}(), ReentrantLock(), max_entries)
+    function HandleCache{K,V}(ctor, dtor; max_entries::Int=32) where {K,V}
+        obj = new{K,V}(ctor, dtor, Set{Pair{K,V}}(), Dict{K,Vector{V}}(),
+                       Base.ThreadSynchronizer(), max_entries)
+
+        # register a hook to wipe the current context's cache when under memory pressure
+        push!(CUDA.reclaim_hooks, ()->empty!(obj))
+
+        return obj
     end
 end
 
 # remove a handle from the cache, or create a new one
-function Base.pop!(ctor::Function, cache::HandleCache{K,V}, key::K) where {K,V}
+function Base.pop!(cache::HandleCache{K,V}, key::K) where {K,V}
     # check the cache
     handle = @lock cache.lock begin
         if !haskey(cache.idle_handles, key) || isempty(cache.idle_handles[key])
@@ -35,7 +40,8 @@ function Base.pop!(ctor::Function, cache::HandleCache{K,V}, key::K) where {K,V}
     # we could (and used to) run `GC.gc(false)` here to free up old handles,
     # but that can be expensive when using lots of short-lived tasks.
     if handle === nothing
-        handle = ctor()
+        CUDA.maybe_collect()
+        handle = cache.ctor(key)
     end
 
     # add the handle to the active set
@@ -47,7 +53,7 @@ function Base.pop!(ctor::Function, cache::HandleCache{K,V}, key::K) where {K,V}
 end
 
 # put a handle in the cache, or destroy it if it doesn't fit
-function Base.push!(dtor::Function, cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
+function Base.push!(cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
     saved = @lock cache.lock begin
         delete!(cache.active_handles, key=>handle)
@@ -65,12 +71,12 @@ function Base.push!(dtor::Function, cache::HandleCache{K,V}, key::K, handle::V)
     end
 
     if !saved
-        dtor()
+        cache.dtor(key, handle)
     end
 end
 
 # shorthand version to put a handle back without having to remember the key
-function Base.push!(dtor::Function, cache::HandleCache{K,V}, handle::V) where {K,V}
+function Base.push!(cache::HandleCache{K,V}, handle::V) where {K,V}
     key = @lock cache.lock begin
         key = nothing
         for entry in cache.active_handles
@@ -85,5 +91,23 @@ function Base.push!(dtor::Function, cache::HandleCache{K,V}, handle::V) where {K
         key
     end
 
-    push!(dtor, cache, key, handle)
+    push!(cache, key, handle)
 end
+
+# empty the cache
+# XXX: often we only need to empty the handles for a single context; however, we don't
+# know for sure that the key is a context (see e.g. cuFFT), so we wipe everything
+function Base.empty!(cache::HandleCache{K,V}) where {K,V}
+    handles = @lock cache.lock begin
+        all_handles = Pair{K,V}[]
+        for (key, handles) in cache.idle_handles, handle in handles
+            push!(all_handles, key=>handle)
+        end
+        empty!(cache.idle_handles)
+        all_handles
+    end
+
+    for (key,handle) in handles
+        cache.dtor(key, handle)
+    end
+end
diff --git a/src/memory.jl b/src/memory.jl
index e961b59bab..ad97f922a2 100644
--- a/src/memory.jl
+++ b/src/memory.jl
@@ -419,6 +419,8 @@ function Base.showerror(io::IO, err::OutOfGPUMemoryError)
     end
 end
 
+const reclaim_hooks = Any[]
+
 """
     retry_reclaim(retry_if) do
         # code that may fail due to insufficient GPU memory
@@ -462,6 +464,10 @@ end
         elseif phase == 5
             # in case we had a release threshold configured
             trim(pool_create(state.device))
+        elseif phase == 6
+            for hook in reclaim_hooks
+                hook()
+            end
         else
             break
         end
@@ -470,6 +476,10 @@ end
             GC.gc(false)
         elseif phase == 2
            GC.gc(true)
+        elseif phase == 3
+            for hook in reclaim_hooks
+                hook()
+            end
         else
             break
         end
@@ -715,6 +725,9 @@ actually reclaimed.
 """
 function reclaim(sz::Int=typemax(Int))
     dev = device()
+    for hook in reclaim_hooks
+        hook()
+    end
     if stream_ordered(dev)
         device_synchronize()
         synchronize(context())
diff --git a/src/random.jl b/src/random.jl
index 867bbb5a97..f5a05ae3e1 100644
--- a/src/random.jl
+++ b/src/random.jl
@@ -190,7 +190,19 @@ end
 
 # we keep this for the GPUArrays.jl tests
-const idle_gpuarray_rngs = HandleCache{CuContext,GPUArrays.RNG}()
+function gpuarrays_rng_ctor(ctx)
+    context!(ctx) do
+        N = attribute(device(), DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
+        buf = CuArray{NTuple{4, UInt32}}(undef, N)
+        GPUArrays.RNG(buf)
+    end
+end
+function gpuarrays_rng_dtor(ctx, rng)
+    context!(ctx; skip_destroyed=true) do
+        # no need to do anything, as the RNG is collected by its finalizer
+    end
+end
+const idle_gpuarray_rngs = HandleCache{CuContext,GPUArrays.RNG}(gpuarrays_rng_ctor, gpuarrays_rng_dtor)
 
 function GPUArrays.default_rng(::Type{<:CuArray})
     cuda = CUDA.active_state()
@@ -203,19 +215,13 @@ function GPUArrays.default_rng(::Type{<:CuArray})
 
     # get library state
     @noinline function new_state(cuda)
-        new_rng = pop!(idle_gpuarray_rngs, cuda.context) do
-            N = attribute(cuda.device, DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
-            buf = CuArray{NTuple{4, UInt32}}(undef, N)
-            GPUArrays.RNG(buf)
-        end
-
+        new_rng = pop!(idle_gpuarray_rngs, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_gpuarray_rngs, cuda.context, new_rng) do
-                # no need to do anything, as the RNG is collected by its finalizer
-            end
+            push!(idle_gpuarray_rngs, cuda.context, new_rng)
        end
 
         Random.seed!(new_rng)
+
         (; rng=new_rng)
     end
 
     state = get!(states, cuda.context) do
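
The same refactoring recurs in every library above: `HandleCache` now receives its constructor and destructor once, at construction time, rather than as `do`-block closures at each `pop!`/`push!` call site. Registering the destructor upfront is what enables the new `Base.empty!` method and the `CUDA.reclaim_hooks` integration, since the cache can now destroy idle handles on its own when memory is tight. A minimal sketch of the resulting API against a hypothetical `libfoo` (`fooCreate`/`fooDestroy` stand in for real pairs such as `cublasCreate`/`cublasDestroy_v2`, and `with_handle` is an illustrative helper, not part of this patch):

## handles

# creation always runs in the context that owns the handle
function handle_ctor(ctx)
    context!(ctx) do
        fooCreate()
    end
end
# destruction may be triggered long after the context has been torn down,
# hence skip_destroyed=true
function handle_dtor(ctx, handle)
    context!(ctx; skip_destroyed=true) do
        fooDestroy(handle)
    end
end
const idle_handles = HandleCache{CuContext,Ptr{Cvoid}}(handle_ctor, handle_dtor)

function with_handle(f)
    ctx = context()
    # reuses an idle handle for this context, or invokes handle_ctor(ctx)
    handle = pop!(idle_handles, ctx)
    try
        f(handle)
    finally
        # returns the handle to the cache, or invokes handle_dtor(ctx, handle)
        # once more than max_entries handles are already idle
        push!(idle_handles, ctx, handle)
    end
end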