Skip to content

Commit

Permalink
More initial changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen committed Nov 27, 2019
1 parent eddc0a0 commit 66ad99a
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 131 deletions.
131 changes: 0 additions & 131 deletions src/SmallBitSet.jl

This file was deleted.

18 changes: 18 additions & 0 deletions src/StackCollections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module t

# StackSet, no params, 0-63
# StackBitSet, no params, contain StackSet and offset
# StackVector, length param, as now implemented

struct Unsafe end
const unsafe = Unsafe()

abstract type AbstractStackSet <: AbstractSet{Int} end

include("stackarray.jl")
include("stackset.jl")
include("stackbitset.jl")

export StackArray, StackBitSet, setindex, StackSet, push, delete, complement, isdisjoint

end
84 changes: 84 additions & 0 deletions src/stackarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
struct StackVector{L} <: AbstractVector{Bool}
x::UInt

function StackVector{L}(x::UInt, ::Unsafe) where L
L isa Int || throw(TypeError(:StackVector, "", Int, typeof(L)))
((L Sys.WORD_SIZE) & (L > -1)) || throw(DomainError(M, "L must be 0:$(Sys.WORD_SIZE)"))
new(x)
end
end

mask(L) = UInt(1) << L - 1
Base.size(s::StackVector) = (length(s),)
Base.length(s::StackVector{L}) where L = L

StackVector{L}() where L = StackVector{L}(UInt(0), unsafe)
StackVector() = StackVector{Sys.WORD_SIZE}()

function Base.getindex(s::StackVector, i::Int)
@boundscheck checkbounds(s, i)
return (s.x >>> unsigned(i-1)) & 1 == 1
end

function setindex(s::StackVector, v::Bool, i::Int)
@boundscheck checkbounds(s, i)
u = 1 << unsigned(i-1)
typeof(s)(ifelse(v, s.x | u, s.x & ~u), unsafe)
end

function Base.iterate(s::StackVector, i::Int=1)
i > length(s) && return nothing
@inbounds (s[i], i+1)
end

Base.in(v::Bool, s::StackVector{L}) where L = !iszero(ifelse(v, s.x, s.x mask(L)))

Base.isempty(s::StackVector{L}) where L = iszero(L)

function Base.minimum(s::StackVector{L}) where L
isempty(s) && throw(ArgumentError("cannot take minimum of empty collection"))
ifelse(s.x == mask(L), true, false)
end

function Base.maximum(s::StackVector{L}) where L
isempty(s) && throw(ArgumentError("cannot take maximum of empty collection"))
ifelse(iszero(s.x), false, true)
end

Base.sum(s::StackVector) = count_ones(s.x)

function Base.convert(::Type{BitVector}, s::StackVector)
b = trues(length(s))
!isempty(s) && @inbounds b.chunks[1] = s.x
b
end

Base.:!(s::StackVector) = typeof(s)(s.x mask, unsafe)

function Base.filter(f::Function, s::StackVector{L}) where L
ft::Bool = f(true)
ff::Bool = f(false)
ft & ff && return typeof(s)(mask(L), unsafe)
!(ft | ff) && return typeof(s)(zero(UInt), unsafe)
ff && return !s
return s
end

# Need to test this TODO
function Base.reverse(s::StackVector)
x = s.x
x = ((x & 0xaaaaaaaaaaaaaaaa) >>> 1) | ((x & 0x5555555555555555) << 1)
x = ((x & 0xcccccccccccccccc) >>> 2) | ((x & 0x3333333333333333) << 2)
x = ((x & 0xf0f0f0f0f0f0f0f0) >>> 4) | ((x & 0x0f0f0f0f0f0f0f0f) << 4)
x = ((x & 0xff00ff00ff00ff00) >>> 8) | ((x & 0x00ff00ff00ff00ff) << 8)
x = ((x & 0xffff0000ffff0000) >>> 16) | ((x & 0x0000ffff0000ffff) << 16)
x = ((x & 0xffffffff00000000) >>> 32) | ((x & 0x00000000ffffffff) << 32)
x >>>= sizeof(UInt) << 3 - length(s)
typeof(s)(x, unsafe)
end

# 1 rotate k right, & with mask, save as x
# rotate 64-L+k right, & with ((1<<k-1) << (L-k)), save as y
# result is y | x
# alternatively, use naive rot
#Base.circshift(s::StackVector)
119 changes: 119 additions & 0 deletions src/stackbitset.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
struct StackBitSet <: AbstractStackSet
set::StackSet
offset::Int # lowest value in set, 0 if empty elements

StackBitSet(set::StackSet, offset::Int, ::Unsafe) = new(set, offset)
end

function StackBitSet(set::StackSet, offset::Int)
s = StackBitSet(set, offset, unsafe)
if !(isempty(set) | isodd(set.x))
s = normalized(s)
end
return s
end

StackBitSet() = StackBitSet(StackSet(), 0, unsafe)
Base.empty(x::StackBitSet) = StackBitSet()
Base.isempty(x::StackBitSet) = isempty(x.set)

function StackBitSet(itr)
d = StackBitSet()
for i in itr
d = push(d, convert(Int, i))
end
d
end

@noinline function throw_StackBitSet_range_err()
throw(ArgumentError("StackSet can not contain values differing" *
"by more than $(Sys.WORD_SIZE-1)"))
end

function Base.iterate(s::StackBitSet, i::Int=0)
it = iterate(s.set, i)
it === nothing && return nothing
return it[1]+s.offset, it[2]
end

function push(s::StackBitSet, x::Int, ::Unsafe)
# Fake old offset for empty s, so we can increment s.offset
oldoffset = ifelse(isempty(s), x, s.offset)
newoffset = min(oldoffset, x)
lshift = unsigned(s.offset - newoffset) & 63
newset = push(StackSet(s.set.x << lshift), x-newoffset, unsafe)
return StackBitSet(newset, newoffset, unsafe)
end

function push(s::StackBitSet, x::Int)
!isempty(s) & (abs(x-s.offset) > 63) && throw_StackBitSet_range_err()
return push(s, x, unsafe)
end

Base.length(x::StackSet) = length(x.set)
Base.maximum(x::StackBitSet) = maximum(x.set + x.offset)
Base.in(x::Int, s::StackBitSet) = in(x-s.offset, s.set)

function Base.filter(pred, s::StackBitSet)
r = StackSet()
for i in s.set
pred(i+s.offset) && (r = push(r, i, unsafe))
end
normalized(StackBitSet(r, s.offset, unsafe))
end

pop(s::StackBitSet, v::Int) = in(s, v) ? delete(s, v, unsafe) : throw(KeyError(v))

function delete(s::StackBitSet, v::Int, ::Unsafe)
normalized(delete(s.set, v-s.offset, unsafe), s.offset)
end
function delete(s::StackBitSet, v::Int)
normalized(delete(s.set, v-s.offset), s.offset)
end

function Base.intersect(x::StackBitSet, y::StackBitSet)
new_x_set = trunc_offset_stackset(x, y)
normalized(StackBitSet(intersect(new_x_set, y.set), y.offset, unsafe))
end

function Base.setdiff(x::StackBitSet, y::StackBitSet)
new_y_set = trunc_offset_stackset(y, x)
normalized(StackBitSet(setdiff(x.set, new_y_set), x.offset, unsafe))
end
################
function trunc_offset_stackset(from::StackBitSet, to::StackBitSet)
return StackSet(from.set.x >>> (to.offset - from.offset))
end

function normalized(s::StackBitSet)
# Update offset and bitshift
rshift = trailing_zeros(s.set.x) & 63
s2 = StackBitSet(StackSet(s.set.x >>> unsigned(rshift)), s.offset + rshift, unsafe)
return ifelse(isempty(s), StackBitSet(), s2)
end

# Return sets with lower offset, or non-empty's offset
function offset_to_lower(smaller::StackBitSet, bigger::StackBitSet)
if (isempty(smaller) | isempty(bigger))
return smaller.set, bigger.set
end
lshift = unsigned(bigger.offset - smaller.offset)
leading_zeros(bigger.set.x) < lshift && throw_StackBitSet_range_err()
shifted = bigger.set.x << (lshift & 63)
return smaller.set, StackSet(shifted)
end

function Base.union(x::StackBitSet, y::StackBitSet)
(smaller, bigger) = ifelse(x.offset < y.offset, (x, y), (y, x))
sm_set, bg_set = offset_to_lower(smaller, bigger)
return StackBitSet(union(sm_set, bg_set), smaller.offset, unsafe)
end

function Base.symdiff(x::StackBitSet, y::StackBitSet)
(smaller, bigger) = ifelse(x.offset < y.offset, (x, y), (y, x))
sm_set, bg_set = offset_to_lower(smaller, bigger)
sbs = StackBitSet(symdiff(sm_set, bg_set), smaller.offset, unsafe)
return normalized(sbs)
end

##########################
Loading

0 comments on commit 66ad99a

Please sign in to comment.