Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunwoongko committed Dec 2, 2022
1 parent 7f5ab4a commit 596582e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 44 deletions.
31 changes: 18 additions & 13 deletions datrie/_datrie_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@ class DoubleArrayTrie:
max_length: int
error: int

def __init__(self):
def __init__(self, data: Dict[str, Any] = None):
self.check = None
self.base = None
self.used = None
self.size = 0
self.alloc_size = 0
self.error = 0

if data is not None:
self.build_with_dict(data)

def get_unit_size(self) -> int:
return self.UNIT_SIZE

Expand Down Expand Up @@ -341,21 +344,23 @@ def common_prefix_search_with_value(self, key: str):

return result

def get_value(self, index) -> Any:
def _get_value(self, index) -> Any:
return self.v[index]

def get(self, key) -> Any:
index = self.exact_match_search(key)
if index >= 0:
return self.get_value(index)
def get(self, key, prefix_search=False) -> Any:
if prefix_search:
indices = self.common_prefix_search(key)
results = [self._get_value(i) for i in indices if i >= 0]
if len(results) != 0:
return results
else:
index = self.exact_match_search(key)
if index >= 0:
return self._get_value(index)
return None

def save(self, filename: str):

data = self.__dict__
check = self.check
print(data.keys())
print(len(check))
def __getitem__(self, item: str) -> Any:
return self.get(item)

def save(self, filename: str):
import pickle
Expand All @@ -364,7 +369,7 @@ def save(self, filename: str):
pickle.dump(self.__dict__, fp)

@classmethod
def load(cls, filename: str):
def load(cls, filename: str) -> "DoubleArrayTrie":
import pickle

with open(filename, mode="rb") as fp:
Expand Down
62 changes: 31 additions & 31 deletions datrie/_datrie_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ class SearchState:
next = None


class SinpleDoubleArrayTrie(object):
class SimpleDoubleArrayTrie(object):
def __init__(self, alphabet_length=256):
super().__init__()
self.alphabet_length = alphabet_length
self.free_positions = SortedSet()
self.base = [INITIAL_ROOT_BASE]
self.check = [ROOT_CHECK_VALUE]

def size(self):
def _size(self):
return len(self.base)

def set_base(self, position, value):
def _set_base(self, position, value):
self.base[position] = value
if value == EMPTY_VALUE:
self.free_positions.add(position)
else:
if position in self.free_positions:
self.free_positions.remove(position)

def set_check(self, position, value):
def _set_check(self, position, value):
self.check[position] = value
if value == EMPTY_VALUE:
self.free_positions.add(position)
Expand All @@ -49,7 +49,7 @@ def set_check(self, position, value):

def _next_available_hop(self, for_value):
while self.free_positions.bisect_right(for_value) >= len(self.free_positions):
self._ensure_reachable_index(self.size() + 1)
self._ensure_reachable_index(self._size() + 1)

result = (
self.free_positions[self.free_positions.bisect_right(for_value)] - for_value
Expand All @@ -59,10 +59,10 @@ def _next_available_hop(self, for_value):
return result

def _ensure_reachable_index(self, limit):
while self.size() <= limit:
while self._size() <= limit:
self.base.append(EMPTY_VALUE)
self.check.append(EMPTY_VALUE)
self.free_positions.add(self.size() - 1)
self.free_positions.add(self._size() - 1)

def _find_consecutive_free(self, amount: int):
assert amount >= 0
Expand Down Expand Up @@ -103,8 +103,8 @@ def _next_available_move(self, values: SortedSet):
if possible - min_value >= 0:
return possible - min_value

self._ensure_reachable_index(self.size() + needed_positions)
return self.size() - needed_positions - min_value
self._ensure_reachable_index(self._size() + needed_positions)
return self._size() - needed_positions - min_value

def _add_to_trie(self, inputs):
changed = False
Expand All @@ -118,7 +118,7 @@ def _add_to_trie(self, inputs):
state_base = self.base[state]

if i > 0 and state_base == LEAF_BASE_VALUE:
self.set_base(transition, self._next_available_hop(c))
self._set_base(transition, self._next_available_hop(c))
changed = True
else:
assert self.base[state] >= 0
Expand All @@ -128,12 +128,12 @@ def _add_to_trie(self, inputs):

self._ensure_reachable_index(transition)
if self.check[transition] == EMPTY_VALUE:
self.set_check(transition, state)
self._set_check(transition, state)
if i == len(inputs) - 1:
self.set_base(transition, LEAF_BASE_VALUE)
self._set_base(transition, LEAF_BASE_VALUE)
changed = True
else:
self.set_base(transition, self._next_available_hop(inputs[i + 1]))
self._set_base(transition, self._next_available_hop(inputs[i + 1]))
changed = True
else:
if self.check[transition] != state:
Expand All @@ -151,7 +151,7 @@ def _resolve_conflict(self, s, new_value):

for c in range(self.alphabet_length):
temp_next = self._walk(s, c)
if 0 <= temp_next < self.size() and self.check[temp_next] == s:
if 0 <= temp_next < self._size() and self.check[temp_next] == s:
values.add(c)

new_location = self._next_available_move(values)
Expand All @@ -160,39 +160,39 @@ def _resolve_conflict(self, s, new_value):
for i in range(len(values)):
c = values[i]
temp_next = self._walk(s, c)
assert temp_next < self.size()
assert temp_next < self._size()
assert self.check[temp_next] == s
assert self.check[new_location + c] == EMPTY_VALUE
self.set_check(new_location + c, s)
self._set_check(new_location + c, s)

assert self.base[new_location + c] == EMPTY_VALUE
self.set_base(new_location + c, self.base[self._walk(s, c)])
self._set_base(new_location + c, self.base[self._walk(s, c)])
self._update_child_move(s, c)

if self.base[self._walk(s, c)] != LEAF_BASE_VALUE:
for d in range(self.alphabet_length):
temp = self._walk(s, c)
temp_next_child = self._walk(temp, d)
if temp_next_child < self.size() and self.check[
if temp_next_child < self._size() and self.check[
temp_next_child
] == self._walk(s, c):
temp = self._walk(s, c)
self.set_check(self._walk(temp, d), new_location + c)
self._set_check(self._walk(temp, d), new_location + c)

elif temp_next >= self.size():
elif temp_next >= self._size():
break

self.set_base(self._walk(s, c), EMPTY_VALUE)
self.set_check(self._walk(s, c), EMPTY_VALUE)
self._set_base(self._walk(s, c), EMPTY_VALUE)
self._set_check(self._walk(s, c), EMPTY_VALUE)

self.set_base(s, new_location)
self._set_base(s, new_location)

def _is_walkable(self, state: int, c: int):
if not (state < self.size()):
if not (state < self._size()):
return False

transition = self._walk(state, c)
return transition < self.size() and self.check[transition] == state
return transition < self._size() and self.check[transition] == state

def _walk(self, state: int, c: int):
return self.base[state] + c
Expand Down Expand Up @@ -260,19 +260,19 @@ def _remove_from_trie(self, inputs):
for i in range(delete_from_index, len(inputs)):
c = inputs[i]
transition = self._walk(state, c)
self.set_base(state, EMPTY_VALUE)
self.set_base(state, EMPTY_VALUE)
self._set_base(state, EMPTY_VALUE)
self._set_base(state, EMPTY_VALUE)
state = transition
return True
else:
return False

@staticmethod
def _create_unicode(inputs, values=None):
return [ord(s) - ord("a") for s in inputs]
def _create_unicode(inputs):
return [s for s in inputs.encode("utf-8")]

def put(self, inputs, values):
self._add_to_trie(self._create_unicode(inputs, values))
def put(self, inputs):
self._add_to_trie(self._create_unicode(inputs))

def find(self, inputs):
return self._contains_prefix(self._create_unicode(inputs))
Expand Down
Empty file added tests/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions tests/_test_datrie_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from datrie._datrie_impl import DoubleArrayTrie

if __name__ == "__main__":
trie = DoubleArrayTrie(
{
"아버지": "NNG",
"가": "JKS",
"방": "NNG",
"에": "JKB",
"들어오": "VV",
"신다": "EP+EF",
".": "SP",
}
)

print(trie["아버지"]) # NNG
print(trie["신다"]) # EP+EF

filename = "file.dat"
trie.save(filename)

trie2 = DoubleArrayTrie.load(filename)
print(trie2["아버지"]) # NNG
print(trie2["신다"]) # EP+EF
15 changes: 15 additions & 0 deletions tests/_test_datrie_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datrie._datrie_simple import SimpleDoubleArrayTrie

if __name__ == "__main__":
trie = SimpleDoubleArrayTrie()
trie.put("ABC")
trie.put("DEF")

print(f"AB: {trie.find('AB')}") # PARTIAL_MATCH
print(f"ABC: {trie.find('ABC')}") # PERFECT_MATCH
print(f"DE: {trie.find('DE')}") # PARTIAL_MATCH
print(f"DEF: {trie.find('DEF')}") # PERFECT_MATCH
print(f"EF: {trie.find('EF')}") # NOT_FOUND

trie.remove("ABC")
print(f"ABC: {trie.find('ABC')}") # NOT_FOUND

0 comments on commit 596582e

Please sign in to comment.