Commit

Merge pull request #2387 from JuliaGPU/tb/tweaks
Tweaks to prevent context construction on some operations
maleadt committed May 17, 2024
2 parents c2d444b + 049bc9f commit e076034
Showing 3 changed files with 19 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/array.jl
@@ -76,7 +76,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}
finalizer(unsafe_free!, obj)
end

-function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N}; stream=CUDA.stream(),
+function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N};
maxsize::Int=prod(dims) * sizeof(T), offset::Int=0) where {T,N,M}
check_eltype(T)
obj = new{T,N,M}(data, maxsize, offset, dims)
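
The `stream` keyword appears unused in the constructor body shown, and in Julia a keyword default expression is evaluated on every call that omits the keyword, so `stream=CUDA.stream()` would query the task-local stream (which requires an active context) even though the value is never needed. A minimal sketch of that behavior, using hypothetical names:

init_context() = (println("context constructed"); :ctx)

with_default(x; stream=init_context()) = x   # the default runs even though `stream` is unused
without_default(x) = x                       # no keyword, no side effect

with_default(1)      # prints "context constructed"
without_default(1)   # silent
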
8 changes: 4 additions & 4 deletions src/compiler/compilation.jl
@@ -209,22 +209,22 @@ end
# current device, but if LLVM doesn't support it, we can target an older capability
# and pass a different `-arch` to `ptxas`.
ptx_support = ptx_compat(cuda_ptx)
-requested_cap = something(cap, min(capability(dev), maximum(ptx_support.cap)))
+requested_cap = @something(cap, min(capability(dev), maximum(ptx_support.cap)))
llvm_caps = filter(<=(requested_cap), llvm_support.cap)
-cuda_caps = filter(<=(capability(dev)), cuda_support.cap)
if cap !== nothing
# the user requested a specific compute capability.
## use the highest capability supported by LLVM
isempty(llvm_caps) &&
error("Requested compute capability $cap is not supported by LLVM $(LLVM.version())")
llvm_cap = maximum(llvm_caps)
## use the capability as-is to invoke CUDA
cuda_cap = cap
else
# try to do the best thing (i.e., use the highest compute capability)
## use the highest capability supported by LLVM
isempty(llvm_caps) &&
error("Compute capability $(requested_cap) is not supported by LLVM $(LLVM.version())")
llvm_cap = maximum(llvm_caps)
## use the highest capability supported by CUDA
+cuda_caps = filter(<=(capability(dev)), cuda_support.cap)
isempty(cuda_caps) &&
error("Compute capability $(requested_cap) is not supported by CUDA driver $(driver_version()) / runtime $(runtime_version())")
cuda_cap = maximum(cuda_caps)
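
Two things change here, both deferring device queries to the code path that actually needs them: `something` is replaced by the `@something` macro (available in Base since Julia 1.7), which short-circuits instead of evaluating every argument, and the `cuda_caps` filter appears to move into the `else` branch so `capability(dev)` is not consulted when the user already specified a capability. A minimal sketch of the `something` vs `@something` difference, with a hypothetical `query_device` standing in for `capability(dev)`:

query_device() = (println("queried the device"); v"8.6")

something(v"7.5", query_device())    # prints "queried the device", then returns v"7.5"
@something(v"7.5", query_device())   # returns v"7.5" without ever calling query_device()
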
17 changes: 14 additions & 3 deletions src/memory.jl
@@ -528,9 +528,14 @@ function maybe_synchronize(managed::Managed)
end

function Base.convert(::Type{CuPtr{T}}, managed::Managed{M}) where {T,M}
-state = active_state()
+# let null pointers pass through as-is
+ptr = convert(CuPtr{T}, managed.mem)
+if ptr == CU_NULL
+return ptr
+end

# accessing memory during stream capture: taint the memory so that we always synchronize
+state = active_state()
if is_capturing(state.stream)
managed.captured = true
end
@@ -564,10 +569,16 @@ end
end

managed.dirty = true
-convert(CuPtr{T}, managed.mem)
+return ptr
end

function Base.convert(::Type{Ptr{T}}, managed::Managed{M}) where {T,M}
+# let null pointers pass through as-is
+ptr = convert(Ptr{T}, managed.mem)
+if ptr == C_NULL
+return ptr
+end

# accessing memory on the CPU: only allowed for host or unified allocations
if M == DeviceMemory
throw(ArgumentError(
@@ -583,7 +594,7 @@ function Base.convert(::Type{Ptr{T}}, managed::Managed{M}) where {T,M}

# make sure any work on the memory has finished.
maybe_synchronize(managed)
-convert(Ptr{T}, managed.mem)
+return ptr
end


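
Both conversion methods now start by computing the raw pointer and returning null pointers as-is, so the remaining work (active-state lookup and stream-capture tainting for device pointers, host-access checks and synchronization for CPU pointers) only runs for real allocations. A self-contained sketch of the early-return pattern with stubbed helpers (not the actual CUDA.jl API):

raw_pointer(obj) = obj.ptr                          # cheap field access, no driver calls
expensive_state() = (println("state constructed"); :state)

function to_pointer(obj)
    ptr = raw_pointer(obj)
    ptr == C_NULL && return ptr    # null pointers pass through as-is
    expensive_state()              # only real allocations pay for this
    return ptr
end

to_pointer((; ptr = C_NULL))                      # no "state constructed" output
to_pointer((; ptr = Ptr{Cvoid}(UInt(0x1234))))    # prints "state constructed"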

