Skip to content

Commit

Permalink
Add == for handle-based objects.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Nov 4, 2016
1 parent 19e8338 commit 9124f30
Show file tree
Hide file tree
Showing 8 changed files with 9 additions and 4 deletions.
2 changes: 0 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
* Merge with pointers in CUDArt.jl

* == for all handle-based types?
4 changes: 3 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ end
(::Type{CuArray{T}}){T,N}(shape::NTuple{N,Int}) = CuArray{T,N}(shape)
(::Type{CuArray{T}}){T}(len::Int) = CuArray{T,1}((len,))

Base.:(==)(a::CuArray, b::CuArray) = a.handle == b.handle
function Base.:(==)(a::CuArray, b::CuArray)
return a.ctx == b.ctx && pointer(a) == pointer(b)
end
Base.unsafe_convert{T}(::Type{DevicePtr{T}}, a::CuArray{T}) = a.devptr
Base.pointer(a::CuArray) = a.devptr

Expand Down
1 change: 1 addition & 0 deletions src/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ immutable CuDevice
end

Base.convert(::Type{CuDevice_t}, dev::CuDevice) = dev.handle
Base.:(==)(a::CuDevice, b::CuDevice) = a.handle == b.handle

"Get the name of a CUDA device"
function name(dev::CuDevice)
Expand Down
1 change: 1 addition & 0 deletions src/events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ immutable CuEvent
end

Base.unsafe_convert(::Type{CuEvent_t}, e::CuEvent) = e.handle
Base.:(==)(a::CuEvent, b::CuEvent) = a.handle == b.handle

record(e::CuEvent, stream::CuStream=CuDefaultStream()) =
@apicall(:cuEventRecord, (CuEvent_t, CuStream_t), e, stream)
Expand Down
1 change: 1 addition & 0 deletions src/module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ immutable CuModule
end

Base.unsafe_convert(::Type{CuModule_t}, mod::CuModule) = mod.handle
Base.:(==)(a::CuModule, b::CuModule) = a.handle == b.handle

"""
Unload a CUDA module.
Expand Down
1 change: 1 addition & 0 deletions src/module/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ immutable CuFunction
end

Base.unsafe_convert(::Type{CuFunction_t}, fun::CuFunction) = fun.handle
Base.:(==)(a::CuFunction, b::CuFunction) = a.handle == b.handle
2 changes: 1 addition & 1 deletion src/module/linker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ immutable CuLink
end

Base.unsafe_convert(::Type{CuLinkState_t}, link::CuLink) = link.handle
Base.show(io::IO,link::CuLink) = print(io, typeof(link), "(", link.handle, ")")
Base.:(==)(a::CuLink, b::CuLink) = a.handle == b.handle

"""
Complete a pending linker invocation, returning an output image.
Expand Down
1 change: 1 addition & 0 deletions src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ immutable CuStream
end

Base.unsafe_convert(::Type{CuStream_t}, s::CuStream) = s.handle
Base.:(==)(a::CuStream, b::CuStream) = a.handle == b.handle

function CuStream(flags::Integer=0)
handle_ref = Ref{CuStream_t}()
Expand Down

0 comments on commit 9124f30

Please sign in to comment.