Skip to content

Commit

Permalink
#540 fixed
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Dec 30, 2023
1 parent 998c2a7 commit dae8f91
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 28 deletions.
8 changes: 2 additions & 6 deletions arekit/common/docs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,14 @@ def parse(doc, pipeline_items, parent_ppl_ctx=None, src_key="input"):
assert(isinstance(pipeline_items, list))
assert(isinstance(parent_ppl_ctx, PipelineContext) or parent_ppl_ctx is None)

pipeline = BasePipeline(pipeline_items)

parsed_sentences = []
for sent_ind in range(doc.SentencesCount):

# Composing the context from a single sentence.
ctx = PipelineContext({src_key: doc.get_sentence(sent_ind)}, parent_ctx=parent_ppl_ctx)

# Apply all the operations.
pipeline.run(ctx, src_key=src_key)
BasePipeline.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)

# Collecting the result.
parsed_sentences.append(BaseParsedText(terms=ctx.provide("result")))
Expand All @@ -43,8 +41,6 @@ def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None, src_key="i
assert(isinstance(pipeline_items, list))
assert(isinstance(parent_ppl_ctx, PipelineContext) or parent_ppl_ctx is None)

pipeline = BatchingPipeline(pipeline_items)

parsed_sentences = []
for batch in BatchIterator(lst=list(range(doc.SentencesCount)), batch_size=batch_size):

Expand All @@ -53,7 +49,7 @@ def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None, src_key="i
parent_ctx=parent_ppl_ctx)

# Apply all the operations.
pipeline.run(ctx, src_key=src_key)
BatchingPipeline.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)

# Collecting the result.
parsed_sentences += [BaseParsedText(terms=result) for result in ctx.provide("result")]
Expand Down
14 changes: 4 additions & 10 deletions arekit/common/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@
from arekit.common.pipeline.items.base import BasePipelineItem


class BasePipeline(object):
class BasePipeline:

def __init__(self, pipeline):
@staticmethod
def run(pipeline, pipeline_ctx, src_key=None):
assert(isinstance(pipeline, list))
self._pipeline = pipeline

def run(self, pipeline_ctx, src_key=None):
assert(isinstance(pipeline_ctx, PipelineContext))
assert(isinstance(src_key, str) or src_key is None)

for ind, item in enumerate(filter(lambda itm: itm is not None, self._pipeline)):
for ind, item in enumerate(filter(lambda itm: itm is not None, pipeline)):
assert(isinstance(item, BasePipelineItem))
input_data = item.get_source(pipeline_ctx, force_key=src_key if src_key is not None and ind == 0 else None)
item_result = item.apply(input_data=input_data, pipeline_ctx=pipeline_ctx)
pipeline_ctx.update(param=item.ResultKey, value=item_result, is_new_key=False)

return pipeline_ctx

def append(self, item):
assert(isinstance(item, BasePipelineItem))
self._pipeline.append(item)
9 changes: 5 additions & 4 deletions arekit/common/pipeline/batching.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem


class BatchingPipeline(BasePipeline):
class BatchingPipeline:

def run(self, pipeline_ctx, src_key=None):
@staticmethod
def run(pipeline, pipeline_ctx, src_key=None):
assert(isinstance(pipeline, list))
assert(isinstance(pipeline_ctx, PipelineContext))
assert(isinstance(src_key, str) or src_key is None)

for ind, item in enumerate(filter(lambda itm: itm is not None, self._pipeline)):
for ind, item in enumerate(filter(lambda itm: itm is not None, pipeline)):
assert (isinstance(item, BasePipelineItem))

# Handle the content of the batch or batch itself.
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/utils/pipelines/text_opinion/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def text_opinion_extraction_pipeline(pipeline_items, get_doc_by_id_func, annotat
extra_filters = [] if text_opinion_filters is None else text_opinion_filters
actual_text_opinion_filters = [FrameworkLimitationsTextOpinionFilter()] + extra_filters

return BasePipeline([
return [
# (doc_id) -> (doc)
MapPipelineItem(map_func=lambda doc_id: get_doc_by_id_func(doc_id)),

Expand All @@ -88,4 +88,4 @@ def text_opinion_extraction_pipeline(pipeline_items, get_doc_by_id_func, annotat

# linkages[] -> linkages
FlattenIterPipelineItem()
])
]
11 changes: 5 additions & 6 deletions tests/text/test_nested_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ class TestNestedEntities(unittest.TestCase):

def test(self):

ppl = BasePipeline(pipeline=[TextEntitiesParser()])

ctx = ppl.run(PipelineContext({"result": self.s.split()}))
ctx = BasePipeline.run(pipeline=[TextEntitiesParser()],
pipeline_ctx=PipelineContext({"result": self.s.split()}))
parsed_text = ctx.provide("result")

print(parsed_text)

def test_batched(self):

ppl = BatchingPipeline(pipeline=[TextEntitiesParser()])

# Compose a single batch with two sentences.
ctx = ppl.run(pipeline_ctx=PipelineContext({"result": [self.s.split(), self.s.split()]}))
ctx = BatchingPipeline.run(
pipeline=[TextEntitiesParser()],
pipeline_ctx=PipelineContext({"result": [self.s.split(), self.s.split()]}))
parsed_text = ctx.provide("result")

print(parsed_text)
Expand Down

0 comments on commit dae8f91

Please sign in to comment.