
Commit

#540 renaming classes. Fixed missed updates for InputDataSerializationHelper.
nicolay-r committed Dec 30, 2023
1 parent dae8f91 commit f90520c
Showing 8 changed files with 16 additions and 19 deletions.
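Beyond the rename, these diffs change the calling convention: a pipeline is no longer an object that is run directly (see the removed self.__pipeline.run(ctx) call in opinions.py below); instead, a plain list of pipeline items is handed to a static launcher. A minimal sketch of the new convention, mirroring the updated test at the bottom of this commit; the sample sentence is an illustrative assumption, not taken from the repository:

from arekit.common.pipeline.base import BasePipelineLauncher
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.pipelines.items.text.entities_default import TextEntitiesParser

# The pipeline is now a bare list of items; the launcher is a static utility.
ctx = PipelineContext({"result": "sample text with entities".split()})  # illustrative input
BasePipelineLauncher.run(pipeline=[TextEntitiesParser()], pipeline_ctx=ctx)
parsed_text = ctx.provide("result")
print(parsed_text)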
8 changes: 4 additions & 4 deletions arekit/common/docs/parser.py
@@ -1,7 +1,7 @@
from arekit.common.docs.base import Document
from arekit.common.docs.parsed.base import ParsedDocument
-from arekit.common.pipeline.base import BasePipeline
-from arekit.common.pipeline.batching import BatchingPipeline
+from arekit.common.pipeline.base import BasePipelineLauncher
+from arekit.common.pipeline.batching import BatchingPipelineLauncher
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.utils import BatchIterator
from arekit.common.text.parsed import BaseParsedText
@@ -25,7 +25,7 @@ def parse(doc, pipeline_items, parent_ppl_ctx=None, src_key="input"):
ctx = PipelineContext({src_key: doc.get_sentence(sent_ind)}, parent_ctx=parent_ppl_ctx)

# Apply all the operations.
-BasePipeline.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)
+BasePipelineLauncher.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)

# Collecting the result.
parsed_sentences.append(BaseParsedText(terms=ctx.provide("result")))
@@ -49,7 +49,7 @@ def parse_batch(doc, pipeline_items, batch_size, parent_ppl_ctx=None, src_key="i
parent_ctx=parent_ppl_ctx)

# Apply all the operations.
-BatchingPipeline.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)
+BatchingPipelineLauncher.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key=src_key)

# Collecting the result.
parsed_sentences += [BaseParsedText(terms=result) for result in ctx.provide("result")]
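A hedged sketch of the batched variant used by parse_batch above: the value stored under src_key is one batch (a list of sentences), and after the launch "result" holds one term list per sentence. The two sentences here are illustrative assumptions:

from arekit.common.pipeline.batching import BatchingPipelineLauncher
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.pipelines.items.text.entities_default import TextEntitiesParser

# One batch with two illustrative sentences, keyed under src_key as in parse_batch.
batch = ["first sample sentence".split(), "second sample sentence".split()]
ctx = PipelineContext({"input": batch})
BatchingPipelineLauncher.run(pipeline=[TextEntitiesParser()], pipeline_ctx=ctx, src_key="input")
parsed_sentences = list(ctx.provide("result"))  # one parsed term list per sentence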
2 changes: 1 addition & 1 deletion arekit/common/pipeline/base.py
@@ -2,7 +2,7 @@
from arekit.common.pipeline.items.base import BasePipelineItem


-class BasePipeline:
+class BasePipelineLauncher:

@staticmethod
def run(pipeline, pipeline_ctx, src_key=None):
2 changes: 1 addition & 1 deletion arekit/common/pipeline/batching.py
@@ -2,7 +2,7 @@
from arekit.common.pipeline.items.base import BasePipelineItem


-class BatchingPipeline:
+class BatchingPipelineLauncher:

@staticmethod
def run(pipeline, pipeline_ctx, src_key=None):
6 changes: 3 additions & 3 deletions arekit/contrib/utils/data/contents/opinions.py
@@ -2,7 +2,7 @@
from arekit.common.data.input.providers.contents import ContentsProvider
from arekit.common.linkage.base import LinkedDataWrapper
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
-from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.base import BasePipelineLauncher
from arekit.common.pipeline.context import PipelineContext
from arekit.common.text_opinions.base import TextOpinion

@@ -14,7 +14,7 @@ def __init__(self, pipeline):
results in a TextOpinionLinkage instances.
pipeline: id -> ... -> TextOpinionLinkage[]
"""
-assert(isinstance(pipeline, BasePipeline))
+assert(isinstance(pipeline, list))
self.__pipeline = pipeline
self.__current_id = None

@@ -38,7 +38,7 @@ def from_doc_ids(self, doc_ids, idle_mode=False):
})

# Launching pipeline with the passed context
-self.__pipeline.run(ctx)
+BasePipelineLauncher.run(pipeline=self.__pipeline, pipeline_ctx=ctx)

for linkage in ctx.provide("result"):
assert(isinstance(linkage, LinkedDataWrapper))
3 changes: 1 addition & 2 deletions arekit/contrib/utils/pipelines/items/sampling/base.py
@@ -2,7 +2,6 @@
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
-from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.utils.serializer import InputDataSerializationHelper

@@ -31,7 +30,7 @@ def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwarg

def _serialize_iteration(self, data_type, pipeline, data_folding, doc_ids):
assert(isinstance(data_type, DataType))
-assert(isinstance(pipeline, BasePipeline))
+assert(isinstance(pipeline, list))
assert(isinstance(data_folding, dict) or data_folding is None)
assert(isinstance(doc_ids, list) or doc_ids is None)
assert(doc_ids is not None or data_folding is not None)
1 change: 0 additions & 1 deletion arekit/contrib/utils/pipelines/text_opinion/extraction.py
@@ -4,7 +4,6 @@
from arekit.common.docs.parsed.providers.entity_service import EntityServiceProvider
from arekit.common.docs.parsed.service import ParsedDocumentService
from arekit.common.docs.parser import DocumentParsers
-from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.items.flatten import FlattenIterPipelineItem
from arekit.common.pipeline.items.map import MapPipelineItem
from arekit.common.pipeline.items.map_nested import MapNestedPipelineItem
3 changes: 1 addition & 2 deletions arekit/contrib/utils/serializer.py
@@ -7,7 +7,6 @@
from arekit.common.data.input.repositories.base import BaseInputRepository
from arekit.common.data.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.data.storages.base import BaseRowsStorage
-from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.utils.data.contents.opinions import InputTextOpinionProvider

logger = logging.getLogger(__name__)
@@ -28,7 +27,7 @@ def create_samples_repo(keep_labels, rows_provider, storage):

@staticmethod
def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc=""):
-assert(isinstance(pipeline, BasePipeline))
+assert(isinstance(pipeline, list))
assert(isinstance(doc_ids_iter, Iterable))
assert(isinstance(repo, BaseInputRepository))

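Since fill_and_write now asserts a plain list, a caller-side sketch of what the pipeline argument is expected to look like; TextEntitiesParser stands in for whatever items a real sampler builds and is purely an assumption here:

from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.utils.pipelines.items.text.entities_default import TextEntitiesParser

# After this commit the helper receives a bare list of pipeline items,
# not a BasePipeline wrapper.
pipeline = [TextEntitiesParser()]
assert isinstance(pipeline, list)
assert all(isinstance(item, BasePipelineItem) for item in pipeline)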
10 changes: 5 additions & 5 deletions tests/text/test_nested_entities.py
@@ -1,7 +1,7 @@
import unittest

-from arekit.common.pipeline.base import BasePipeline
-from arekit.common.pipeline.batching import BatchingPipeline
+from arekit.common.pipeline.base import BasePipelineLauncher
+from arekit.common.pipeline.batching import BatchingPipelineLauncher
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.pipelines.items.text.entities_default import TextEntitiesParser

@@ -16,16 +16,16 @@ class TestNestedEntities(unittest.TestCase):

def test(self):

-ctx = BasePipeline.run(pipeline=[TextEntitiesParser()],
-pipeline_ctx=PipelineContext({"result": self.s.split()}))
+ctx = BasePipelineLauncher.run(pipeline=[TextEntitiesParser()],
+pipeline_ctx=PipelineContext({"result": self.s.split()}))
parsed_text = ctx.provide("result")

print(parsed_text)

def test_batched(self):

# Compose a single batch with two sentences.
-ctx = BatchingPipeline.run(
+ctx = BatchingPipelineLauncher.run(
pipeline=[TextEntitiesParser()],
pipeline_ctx=PipelineContext({"result": [self.s.split(), self.s.split()]}))
parsed_text = ctx.provide("result")
