Skip to content

Commit

Permalink
Updated batch iterator implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 27, 2024
1 parent af1a351 commit 0911d08
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
2 changes: 1 addition & 1 deletion arekit/common/docs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None, src_key="i

parsed_sentences = []

data_it = BatchIterator(lst=list(range(doc.SentencesCount)), batch_size=batch_size)
data_it = BatchIterator(data_iter=iter(range(doc.SentencesCount)), batch_size=batch_size)
progress_it = tqdm(data_it, total=round(doc.SentencesCount / batch_size), disable=not show_progress)

for batch in progress_it:
Expand Down
29 changes: 21 additions & 8 deletions arekit/common/pipeline/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
class BatchIterator:
    """Group the items of an iterable into lists of at most ``batch_size`` items.

    The iterator yields full batches while the source has enough items; the
    last batch may be shorter. Empty batches are never yielded.

    Args:
        data_iter: Source of items. Any iterable is accepted; it is wrapped
            with ``iter()`` so that both one-shot iterators and containers
            such as lists work.
        batch_size: Positive number of items per batch.
        end_value: Optional zero-argument callable. When provided, once the
            source is exhausted every further ``next()`` call returns
            ``end_value()`` instead of raising ``StopIteration``.
            NOTE: in that mode a plain ``for`` loop over this object never
            terminates on its own; the consumer has to detect the end value
            and stop.
    """

    def __init__(self, data_iter, batch_size, end_value=None):
        assert isinstance(batch_size, int) and batch_size > 0
        assert callable(end_value) or end_value is None
        # iter() is a no-op for iterators and lets callers pass containers
        # directly without getting a TypeError from next().
        self.__data_iter = iter(data_iter)
        self.__batch_size = batch_size
        self.__end_value = end_value

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, or the end value / StopIteration when done."""
        batch = []
        # Looping over the stored iterator resumes where the previous batch
        # stopped, because iter() of an iterator returns the iterator itself.
        for item in self.__data_iter:
            batch.append(item)
            if len(batch) == self.__batch_size:
                break

        if batch:
            return batch

        if self.__end_value is None:
            raise StopIteration

        return self.__end_value()
11 changes: 11 additions & 0 deletions tests/common/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import unittest

from arekit.common.pipeline.utils import BatchIterator


class TestBalancing(unittest.TestCase):
    """Checks how BatchIterator splits an iterator into batches."""

    def test(self):
        """Ten items in batches of three: three full batches and a remainder."""
        batch_it = BatchIterator(data_iter=iter(range(10)), batch_size=3)
        # Assert on the produced batches instead of printing them, so the
        # test actually fails when batching is wrong.
        batches = list(batch_it)
        self.assertEqual(batches, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]])

0 comments on commit 0911d08

Please sign in to comment.