Skip to content

Commit

Permalink
#539 input feeding support
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 29, 2023
1 parent acba0a6 commit 901c2ae
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
12 changes: 6 additions & 6 deletions arekit/common/docs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class DocumentParsers(object):

@staticmethod
def parse(doc, pipeline_items, parent_ppl_ctx=None):
def parse(doc, pipeline_items, parent_ppl_ctx=None, src_key="input"):
""" This document parser is based on single text parts (sentences)
that pass sequentially through the pipeline of transformations.
"""
Expand All @@ -24,18 +24,18 @@ def parse(doc, pipeline_items, parent_ppl_ctx=None):
for sent_ind in range(doc.SentencesCount):

# Composing the context from a single sentence.
ctx = PipelineContext({"input": doc.get_sentence(sent_ind)}, parent_ctx=parent_ppl_ctx)
ctx = PipelineContext({src_key: doc.get_sentence(sent_ind)}, parent_ctx=parent_ppl_ctx)

# Apply all the operations.
pipeline.run(ctx)
pipeline.run(ctx, src_key=src_key)

# Collecting the result.
parsed_sentences.append(BaseParsedText(terms=ctx.provide("result")))

return ParsedDocument(doc_id=doc.ID, parsed_sentences=parsed_sentences)

@staticmethod
def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None):
def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None, src_key="input"):
""" This document parser is based on batch of sentences.
"""
assert(isinstance(batch_size, int) and batch_size > 0)
Expand All @@ -49,11 +49,11 @@ def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None):
for batch in BatchIterator(lst=list(range(doc.SentencesCount)), batch_size=batch_size):

# Composing the context from a batch of sentences.
ctx = PipelineContext({"input": [doc.get_sentence(s_ind) for s_ind in batch]},
ctx = PipelineContext({src_key: [doc.get_sentence(s_ind) for s_ind in batch]},
parent_ctx=parent_ppl_ctx)

# Apply all the operations.
pipeline.run(ctx)
pipeline.run(ctx, src_key=src_key)

# Collecting the result.
parsed_sentences += [BaseParsedText(terms=result) for result in ctx.provide("result")]
Expand Down
8 changes: 5 additions & 3 deletions arekit/common/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ def __init__(self, pipeline):
assert(isinstance(pipeline, list))
self._pipeline = pipeline

def run(self, pipeline_ctx):
def run(self, pipeline_ctx, src_key=None):
assert(isinstance(pipeline_ctx, PipelineContext))
assert(isinstance(src_key, str) or src_key is None)

for item in filter(lambda itm: itm is not None, self._pipeline):
for ind, item in enumerate(filter(lambda itm: itm is not None, self._pipeline)):
assert(isinstance(item, BasePipelineItem))
item_result = item.apply(input_data=item.get_source(pipeline_ctx), pipeline_ctx=pipeline_ctx)
input_data = item.get_source(pipeline_ctx, force_key=src_key if src_key is not None and ind == 0 else None)
item_result = item.apply(input_data=input_data, pipeline_ctx=pipeline_ctx)
pipeline_ctx.update(param=item.ResultKey, value=item_result, is_new_key=False)

return pipeline_ctx
Expand Down
11 changes: 7 additions & 4 deletions arekit/common/pipeline/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

class BatchingPipeline(BasePipeline):

def run(self, pipeline_ctx):
def run(self, pipeline_ctx, src_key=None):
assert(isinstance(pipeline_ctx, PipelineContext))
assert(isinstance(src_key, str) or src_key is None)

for item in filter(lambda itm: itm is not None, self._pipeline):
for ind, item in enumerate(filter(lambda itm: itm is not None, self._pipeline)):
assert (isinstance(item, BasePipelineItem))

# Handle the content of the batch or batch itself.
content = item.get_source(pipeline_ctx, call_func=item.SupportBatching,
force_key=src_key if ind == 0 else None)

if item.SupportBatching:
handled_batch = item.get_source(pipeline_ctx)
handled_batch = content
else:
content = item.get_source(pipeline_ctx, call_func=False)
handled_batch = [item._src_func(i) if item._src_func is not None else i for i in content]

# At present, each batch represents a list of contents.
Expand Down
4 changes: 2 additions & 2 deletions arekit/common/pipeline/items/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def SupportBatching(self):
"""
return False

def get_source(self, src_ctx, call_func=True):
def get_source(self, src_ctx, call_func=True, force_key=None):
""" Extract input element for processing.
"""
assert(isinstance(src_ctx, PipelineContext))
Expand All @@ -32,7 +32,7 @@ def get_source(self, src_ctx, call_func=True):
return None

# Extracting actual source.
src_data = src_ctx.provide(self.__src_key)
src_data = src_ctx.provide(self.__src_key if force_key is None else force_key)
if self._src_func is not None and call_func:
src_data = self._src_func(src_data)

Expand Down

0 comments on commit 901c2ae

Please sign in to comment.