From 958084c0a11b31f16295d5970df854ed20479410 Mon Sep 17 00:00:00 2001
From: Nicolay Rusnachenko
Date: Wed, 27 Dec 2023 20:30:54 +0000
Subject: [PATCH] #539 done. Pipelines are now far more flexible: the input
 source of every pipeline item can be customized. Related to #535.
---
 arekit/common/docs/entities_grouping.py       |  3 +-
 arekit/common/docs/objects_parser.py          | 13 ++--
 arekit/common/docs/parser.py                  | 12 +---
 arekit/common/pipeline/base.py                | 12 ++--
 arekit/common/pipeline/context.py             |  6 +-
 arekit/common/pipeline/items/base.py          | 30 ++++++++
 arekit/common/pipeline/items/flatten.py       |  5 +-
 arekit/common/pipeline/items/handle.py        |  3 +-
 arekit/common/pipeline/items/iter.py          |  3 +-
 arekit/common/pipeline/items/map.py           |  3 +-
 arekit/common/pipeline/items/map_nested.py    |  3 ++
 arekit/common/text/parser.py                  | 11 ++-
 arekit/contrib/source/brat/entities/parser.py | 18 ++---
 .../source/ruattitudes/entity/parser.py       |  4 +-
 .../contrib/utils/data/contents/opinions.py   | 12 +++-
 .../utils/pipelines/items/sampling/base.py    | 15 ++--
 .../pipelines/items/sampling/networks.py      |  5 +-
 .../pipelines/items/text/entities_default.py  |  4 +-
 .../utils/pipelines/items/text/frames.py      |  5 +-
 .../pipelines/items/text/frames_lemmatized.py |  4 +-
 .../pipelines/items/text/frames_negation.py   |  3 +-
 .../pipelines/items/text/terms_splitter.py    |  3 ++
 .../utils/pipelines/items/text/tokenizer.py   |  5 +-
 .../utils/pipelines/items/text/translator.py  |  3 +-
 tests/README.md                               |  2 +
 tests/contrib/networks/doc.py                 |  2 +-
 tests/contrib/networks/indices_feature.py     |  2 +-
 tests/contrib/networks/test_input_features.py | 12 ++--
 tests/contrib/networks/utils.py               | 68 +++++++++++++++++++
 tests/contrib/source/doc.py                   |  2 +-
 tests/contrib/source/test_labels.py           |  2 +-
 tests/contrib/source/test_ruattitudes.py      | 14 ++--
 tests/contrib/source/test_rusentiframes.py    |  2 +-
 .../contrib/source/test_rusentiframes_stat.py |  2 +-
 tests/contrib/source/test_show_frames_stat.py |  2 +-
 tests/contrib/utils/test_csv_stream_write.py  |  2 +-
 tests/contrib/utils/test_frames_annotation.py |  4 +-
 tests/contrib/utils/test_text_parser.py       | 22 +++---
 tests/text/linked_opinions.py                 | 43 ------------
 tests/text/test_nested_entities.py            |  6 +-
 tests/text/utils.py                           | 18 -----
 .../test_tutorial_pipeline_sampling_bert.py   | 12 ++--
 ...test_tutorial_pipeline_sampling_network.py |  9 +--
 .../test_tutorial_pipeline_sampling_prompt.py | 12 ++--
 ...torial_pipeline_text_opinion_annotation.py | 15 +++-
 .../test_tutorial_pipeline_text_parser.py     |  2 +-
 46 files changed, 250 insertions(+), 190 deletions(-)
 delete mode 100644 tests/text/linked_opinions.py

diff --git a/arekit/common/docs/entities_grouping.py b/arekit/common/docs/entities_grouping.py
index 320e4a78..466a5c6d 100644
--- a/arekit/common/docs/entities_grouping.py
+++ b/arekit/common/docs/entities_grouping.py
@@ -4,8 +4,9 @@
 
 class EntitiesGroupingPipelineItem(BasePipelineItem):
 
-    def __init__(self, value_to_group_id_func):
+    def __init__(self, value_to_group_id_func, **kwargs):
         assert(callable(value_to_group_id_func))
+        super(EntitiesGroupingPipelineItem, self).__init__(**kwargs)
         self.__value_to_group_id_func = value_to_group_id_func
 
     def apply_core(self, input_data, pipeline_ctx):
diff --git a/arekit/common/docs/objects_parser.py b/arekit/common/docs/objects_parser.py
index 612b0393..935c5b62 100644
--- a/arekit/common/docs/objects_parser.py
+++ b/arekit/common/docs/objects_parser.py
@@ -1,29 +1,28 @@
 from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.common.text.partitioning.base import BasePartitioning
-from arekit.common.pipeline.context import PipelineContext
 
 
 class SentenceObjectsParserPipelineItem(BasePipelineItem):
 
-    def __init__(self, partitioning):
+    def __init__(self, partitioning, **kwargs):
         assert(isinstance(partitioning, BasePartitioning))
+        super(SentenceObjectsParserPipelineItem, self).__init__(**kwargs)
         self.__partitioning = partitioning
 
     # region protected
 
-    def _get_text(self, pipeline_ctx):
+    def _get_text(self, sentence):
         return None
 
-    def _get_parts_provider_func(self, input_data, pipeline_ctx):
+    def _get_parts_provider_func(self, sentence):
         raise NotImplementedError()
 
     # endregion
 
     def apply_core(self, input_data, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
-        external_input = self._get_text(pipeline_ctx)
+        external_input = self._get_text(input_data)
         actual_input = input_data if external_input is None else external_input
-        parts_it = self._get_parts_provider_func(input_data=actual_input, pipeline_ctx=pipeline_ctx)
+        parts_it = self._get_parts_provider_func(input_data)
         return self.__partitioning.provide(text=actual_input, parts_it=parts_it)
 
     # region base
diff --git a/arekit/common/docs/parser.py b/arekit/common/docs/parser.py
index 28633ad4..06883d3c 100644
--- a/arekit/common/docs/parser.py
+++ b/arekit/common/docs/parser.py
@@ -16,19 +16,9 @@ def parse(doc, text_parser, parent_ppl_ctx=None):
         assert(isinstance(text_parser, BaseTextParser))
         assert(isinstance(parent_ppl_ctx, PipelineContext) or parent_ppl_ctx is None)
 
-        parsed_sentences = [text_parser.run(input_data=DocumentParser.__get_sent(doc, sent_ind).Text,
-                                            params_dict=DocumentParser.__create_ppl_params(doc=doc, sent_ind=sent_ind),
+        parsed_sentences = [text_parser.run(params_dict={"input": DocumentParser.__get_sent(doc, sent_ind)},
                                             parent_ctx=parent_ppl_ctx)
                             for sent_ind in range(doc.SentencesCount)]
 
         return ParsedDocument(doc_id=doc.ID,
                               parsed_sentences=parsed_sentences)
-
-    @staticmethod
-    def __create_ppl_params(doc, sent_ind):
-        assert(isinstance(doc, Document))
-        return {
-            "s_ind": sent_ind,                                     # sentence index. (as Metadata)
-            "doc_id": doc.ID,                                      # document index. (as Metadata)
-            "sentence": DocumentParser.__get_sent(doc, sent_ind),  # Required for special sources.
-        }
diff --git a/arekit/common/pipeline/base.py b/arekit/common/pipeline/base.py
index 56c597f5..9a2502ec 100644
--- a/arekit/common/pipeline/base.py
+++ b/arekit/common/pipeline/base.py
@@ -8,17 +8,15 @@ def __init__(self, pipeline):
         assert(isinstance(pipeline, list))
         self.__pipeline = pipeline
 
-    def run(self, input_data, params_dict=None, parent_ctx=None):
-        assert(isinstance(params_dict, dict) or params_dict is None)
-
-        pipeline_ctx = PipelineContext(d=params_dict if params_dict is not None else dict(),
-                                       parent_ctx=parent_ctx)
+    def run(self, pipeline_ctx):
+        assert(isinstance(pipeline_ctx, PipelineContext))
 
         for item in filter(lambda itm: itm is not None, self.__pipeline):
             assert(isinstance(item, BasePipelineItem))
-            input_data = item.apply(input_data=input_data, pipeline_ctx=pipeline_ctx)
+            item_result = item.apply(input_data=item.get_source(pipeline_ctx), pipeline_ctx=pipeline_ctx)
+            pipeline_ctx.update(param=item.ResultKey, value=item_result, is_new_key=False)
 
-        return input_data
+        return pipeline_ctx
 
     def append(self, item):
         assert(isinstance(item, BasePipelineItem))
diff --git a/arekit/common/pipeline/context.py b/arekit/common/pipeline/context.py
index eadbc943..26316f8f 100644
--- a/arekit/common/pipeline/context.py
+++ b/arekit/common/pipeline/context.py
@@ -13,6 +13,8 @@ def __init__(self, d, parent_ctx=None):
         self._d[PARENT_CTX] = parent_ctx
 
     def __provide(self, param):
+        if param not in self._d:
+            raise Exception(f"Key `{param}` is not in the dictionary.\n{self._d}")
         return self._d[param]
 
     # region public
@@ -23,7 +25,9 @@ def provide(self, param):
     def provide_or_none(self, param):
         return self.__provide(param) if param in self._d else None
 
-    def update(self, param, value):
+    def update(self, param, value, is_new_key=False):
+        if is_new_key and param in self._d:
+            raise Exception(f"Key `{param}` is already present in the pipeline context dictionary.")
         self._d[param] = value
 
     # endregion
diff --git a/arekit/common/pipeline/items/base.py b/arekit/common/pipeline/items/base.py
index f0e26045..207cecf3 100644
--- a/arekit/common/pipeline/items/base.py
+++ b/arekit/common/pipeline/items/base.py
@@ -1,7 +1,37 @@
+from arekit.common.pipeline.context import PipelineContext
+
+
 class BasePipelineItem(object):
     """ Single pipeline item that might be instatiated and embedded into pipeline.
     """
 
+    def __init__(self, src_key="result", result_key="result", src_func=None):
+        assert(isinstance(src_key, str) or src_key is None)
+        assert(callable(src_func) or src_func is None)
+        self.__src_key = src_key
+        self.__src_func = src_func
+        self.__result_key = result_key
+
+    @property
+    def ResultKey(self):
+        return self.__result_key
+
+    def get_source(self, src_ctx):
+        """ Extract input element for processing.
+        """
+        assert(isinstance(src_ctx, PipelineContext))
+
+        # If no source key is provided, then the source is considered to be absent.
+        if self.__src_key is None:
+            return None
+
+        # Extracting the actual source.
+        src_data = src_ctx.provide(self.__src_key)
+        if self.__src_func is not None:
+            src_data = self.__src_func(src_data)
+
+        return src_data
+
     def apply_core(self, input_data, pipeline_ctx):
         raise NotImplementedError()
diff --git a/arekit/common/pipeline/items/flatten.py b/arekit/common/pipeline/items/flatten.py
index 423e1019..fcdd5a0b 100644
--- a/arekit/common/pipeline/items/flatten.py
+++ b/arekit/common/pipeline/items/flatten.py
@@ -5,10 +5,13 @@ class FlattenIterPipelineItem(BasePipelineItem):
     """ Considered to flat iterations of items that represent iterations.
     """
 
+    def __init__(self, **kwargs):
+        super(FlattenIterPipelineItem, self).__init__(**kwargs)
+
     def __flat_iter(self, iter_data):
         for iter_item in iter_data:
             for item in iter_item:
                 yield item
 
     def apply_core(self, input_data, pipeline_ctx):
-        return self.__flat_iter(input_data)
\ No newline at end of file
+        return self.__flat_iter(input_data)
diff --git a/arekit/common/pipeline/items/handle.py b/arekit/common/pipeline/items/handle.py
index 721690a4..024d5cd5 100644
--- a/arekit/common/pipeline/items/handle.py
+++ b/arekit/common/pipeline/items/handle.py
@@ -3,8 +3,9 @@
 
 class HandleIterPipelineItem(BasePipelineItem):
 
-    def __init__(self, handle_func=None):
+    def __init__(self, handle_func=None, **kwargs):
         assert(callable(handle_func))
+        super(HandleIterPipelineItem, self).__init__(**kwargs)
         self.__handle_func = handle_func
 
     def __updated_data(self, items_iter):
diff --git a/arekit/common/pipeline/items/iter.py b/arekit/common/pipeline/items/iter.py
index d7687a96..93d01f83 100644
--- a/arekit/common/pipeline/items/iter.py
+++ b/arekit/common/pipeline/items/iter.py
@@ -3,8 +3,9 @@
 
 class FilterPipelineItem(BasePipelineItem):
 
-    def __init__(self, filter_func=None):
+    def __init__(self, filter_func=None, **kwargs):
         assert(callable(filter_func))
+        super(FilterPipelineItem, self).__init__(**kwargs)
         self.__filter_func = filter_func
 
     def apply_core(self, input_data, pipeline_ctx):
diff --git a/arekit/common/pipeline/items/map.py b/arekit/common/pipeline/items/map.py
index d669645c..5ffb3fde 100644
--- a/arekit/common/pipeline/items/map.py
+++ b/arekit/common/pipeline/items/map.py
@@ -3,8 +3,9 @@
 
 class MapPipelineItem(BasePipelineItem):
 
-    def __init__(self, map_func=None):
+    def __init__(self, map_func=None, **kwargs):
         assert(callable(map_func))
+        super(MapPipelineItem, self).__init__(**kwargs)
         self._map_func = map_func
 
     def apply_core(self, input_data, pipeline_ctx):
diff --git a/arekit/common/pipeline/items/map_nested.py b/arekit/common/pipeline/items/map_nested.py
index 6e2c48e0..a0b008e1 100644
--- a/arekit/common/pipeline/items/map_nested.py
+++ b/arekit/common/pipeline/items/map_nested.py
@@ -9,5 +9,8 @@ class MapNestedPipelineItem(MapPipelineItem):
         suppose to be mapped with the passed pipeline context.
     """
 
+    def __init__(self, **kwargs):
+        super(MapNestedPipelineItem, self).__init__(**kwargs)
+
     def apply_core(self, input_data, pipeline_ctx):
         return map(lambda item: self._map_func(item, pipeline_ctx), input_data)
diff --git a/arekit/common/text/parser.py b/arekit/common/text/parser.py
index b0f9d593..1fb5ff2f 100644
--- a/arekit/common/text/parser.py
+++ b/arekit/common/text/parser.py
@@ -1,12 +1,11 @@
 from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text.parsed import BaseParsedText
 
 
 class BaseTextParser(BasePipeline):
 
-    def run(self, input_data, params_dict=None, parent_ctx=None):
-        output_data = super(BaseTextParser, self).run(input_data=input_data,
-                                                      params_dict=params_dict,
-                                                      parent_ctx=parent_ctx)
-
-        return BaseParsedText(terms=output_data)
+    def run(self, params_dict, parent_ctx=None):
+        assert(isinstance(params_dict, dict))
+        ctx = super(BaseTextParser, self).run(pipeline_ctx=PipelineContext(params_dict, parent_ctx=parent_ctx))
+        return BaseParsedText(terms=ctx.provide("result"))
diff --git a/arekit/contrib/source/brat/entities/parser.py b/arekit/contrib/source/brat/entities/parser.py
index 35ab5e0e..5dfb9d13 100644
--- a/arekit/contrib/source/brat/entities/parser.py
+++ b/arekit/contrib/source/brat/entities/parser.py
@@ -1,5 +1,4 @@
 from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
-from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text.partitioning.str import StringPartitioning
 from arekit.common.text.partitioning.terms import TermsPartitioning
 from arekit.contrib.source.brat.sentence import BratSentence
@@ -7,8 +6,6 @@
 
 class BratTextEntitiesParser(SentenceObjectsParserPipelineItem):
 
-    KEY = "sentence"
-
     ################################
     # NOTE: Supported partitionings.
     ################################
@@ -22,29 +19,22 @@ class BratTextEntitiesParser(SentenceObjectsParserPipelineItem):
         "terms": TermsPartitioning()
     }
 
-    def __init__(self, partitioning="string"):
+    def __init__(self, partitioning="string", **kwargs):
         assert(isinstance(partitioning, str))
-        super(BratTextEntitiesParser, self).__init__(self.__supported_partitionings[partitioning])
+        super(BratTextEntitiesParser, self).__init__(self.__supported_partitionings[partitioning], **kwargs)
 
     # region protected methods
 
-    def _get_text(self, pipeline_ctx):
-        sentence = self.__get_sentence(pipeline_ctx)
+    def _get_text(self, sentence):
         return sentence.Text
 
-    def _get_parts_provider_func(self, input_data, pipeline_ctx):
-        sentence = self.__get_sentence(pipeline_ctx)
+    def _get_parts_provider_func(self, sentence):
         return self.__iter_subs_values_with_bounds(sentence)
 
     # endregion
 
     # region private methods
 
-    def __get_sentence(self, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
-        assert(self.KEY in pipeline_ctx)
-        return pipeline_ctx.provide(self.KEY)
-
     @staticmethod
     def __iter_subs_values_with_bounds(sentence):
         assert(isinstance(sentence, BratSentence))
diff --git a/arekit/contrib/source/ruattitudes/entity/parser.py b/arekit/contrib/source/ruattitudes/entity/parser.py
index dc09b78e..9d5c8343 100644
--- a/arekit/contrib/source/ruattitudes/entity/parser.py
+++ b/arekit/contrib/source/ruattitudes/entity/parser.py
@@ -3,5 +3,5 @@
 
 class RuAttitudesTextEntitiesParser(BratTextEntitiesParser):
 
-    def __init__(self):
-        super(RuAttitudesTextEntitiesParser, self).__init__(partitioning="terms")
+    def __init__(self, **kwargs):
+        super(RuAttitudesTextEntitiesParser, self).__init__(partitioning="terms", **kwargs)
diff --git a/arekit/contrib/utils/data/contents/opinions.py b/arekit/contrib/utils/data/contents/opinions.py
index e3032d4a..0966ff62 100644
--- a/arekit/contrib/utils/data/contents/opinions.py
+++ b/arekit/contrib/utils/data/contents/opinions.py
@@ -3,6 +3,7 @@
 from arekit.common.linkage.base import LinkedDataWrapper
 from arekit.common.linkage.text_opinions import TextOpinionsLinkage
 from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text_opinions.base import TextOpinion
 
 
@@ -30,7 +31,16 @@ def __assign_ids(self, linkage):
 
     def from_doc_ids(self, doc_ids, idle_mode=False):
         self.__current_id = 0
-        for linkage in self.__pipeline.run(doc_ids, params_dict={IDLE_MODE: idle_mode}):
+
+        ctx = PipelineContext(d={
+            "result": doc_ids,
+            IDLE_MODE: idle_mode
+        })
+
+        # Launching pipeline with the passed context
+        self.__pipeline.run(ctx)
+
+        for linkage in ctx.provide("result"):
             assert(isinstance(linkage, LinkedDataWrapper))
             if isinstance(linkage, TextOpinionsLinkage):
                 self.__assign_ids(linkage)
diff --git a/arekit/contrib/utils/pipelines/items/sampling/base.py b/arekit/contrib/utils/pipelines/items/sampling/base.py
index db656cfa..33cb11da 100644
--- a/arekit/contrib/utils/pipelines/items/sampling/base.py
+++ b/arekit/contrib/utils/pipelines/items/sampling/base.py
@@ -10,7 +10,7 @@
 
 class BaseSerializerPipelineItem(BasePipelineItem):
 
-    def __init__(self, rows_provider, samples_io, save_labels_func, storage):
+    def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwargs):
         """ sample_rows_formatter:
                 how we format input texts for a BERT model, for example:
                     - single text
@@ -23,6 +23,7 @@ def __init__(self, rows_provider, samples_io, save_labels_func, storage):
         assert(isinstance(samples_io, BaseSamplesIO))
         assert(callable(save_labels_func))
         assert(isinstance(storage, BaseRowsStorage))
+        super(BaseSerializerPipelineItem, self).__init__(**kwargs)
 
         self._rows_provider = rows_provider
         self._samples_io = samples_io
@@ -89,11 +90,7 @@ def apply_core(self, input_data, pipeline_ctx):
             doc_ids: optional
                 this parameter allows to limit amount of documents considered for sampling
         """
-        assert(isinstance(input_data, PipelineContext))
-        assert("data_type_pipelines" in input_data)
-
-        data_folding = input_data.provide_or_none("data_folding")
-
-        self._handle_iteration(data_type_pipelines=input_data.provide("data_type_pipelines"),
-                               doc_ids=input_data.provide_or_none("doc_ids"),
-                               data_folding=data_folding)
+        assert("data_type_pipelines" in pipeline_ctx)
+        self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
+                               doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
+                               data_folding=pipeline_ctx.provide_or_none("data_folding"))
diff --git a/arekit/contrib/utils/pipelines/items/sampling/networks.py b/arekit/contrib/utils/pipelines/items/sampling/networks.py
index 38cf152c..cfcd1650 100644
--- a/arekit/contrib/utils/pipelines/items/sampling/networks.py
+++ b/arekit/contrib/utils/pipelines/items/sampling/networks.py
@@ -8,7 +8,7 @@
 
 class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):
 
-    def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage, save_embedding=True):
+    def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage, save_embedding=True, **kwargs):
         """ This pipeline item allows to perform a data preparation for neural network models.
 
             considering a list of the whole data_types with the related pipelines,
@@ -23,7 +23,8 @@ def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage,
                                 rows_provider=rows_provider,
                                 samples_io=samples_io,
                                 save_labels_func=save_labels_func,
-                                storage=storage)
+                                storage=storage,
+                                **kwargs)
 
         self.__emb_io = emb_io
         self.__save_embedding = save_embedding
diff --git a/arekit/contrib/utils/pipelines/items/text/entities_default.py b/arekit/contrib/utils/pipelines/items/text/entities_default.py
index 1490506c..65ab20ea 100644
--- a/arekit/contrib/utils/pipelines/items/text/entities_default.py
+++ b/arekit/contrib/utils/pipelines/items/text/entities_default.py
@@ -4,8 +4,8 @@
 
 class TextEntitiesParser(BasePipelineItem):
 
-    def __init__(self):
-        super(TextEntitiesParser, self).__init__()
+    def __init__(self, **kwargs):
+        super(TextEntitiesParser, self).__init__(**kwargs)
 
     @staticmethod
     def __process_word(word):
diff --git a/arekit/contrib/utils/pipelines/items/text/frames.py b/arekit/contrib/utils/pipelines/items/text/frames.py
index db4cf8ad..7b1ce2f7 100644
--- a/arekit/contrib/utils/pipelines/items/text/frames.py
+++ b/arekit/contrib/utils/pipelines/items/text/frames.py
@@ -6,11 +6,10 @@
 
 class FrameVariantsParser(BasePipelineItem):
 
-    def __init__(self, frame_variants):
+    def __init__(self, frame_variants, **kwargs):
         assert(isinstance(frame_variants, FrameVariantsCollection))
         assert(len(frame_variants) > 0)
-
-        super(FrameVariantsParser, self).__init__()
+        super(FrameVariantsParser, self).__init__(**kwargs)
 
         self.__frame_variants = frame_variants
         self.__max_variant_len = max([len(variant) for _, variant in frame_variants.iter_variants()])
diff --git a/arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py b/arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py
index 75a8412e..9b87ba0c 100644
--- a/arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py
+++ b/arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py
@@ -5,10 +5,10 @@
 
 class LemmasBasedFrameVariantsParser(FrameVariantsParser):
 
-    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False):
+    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False, **kwargs):
         assert(isinstance(stemmer, Stemmer))
         assert(isinstance(save_lemmas, bool))
-        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants)
+        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants, **kwargs)
 
         self.__frame_variants = frame_variants
         self.__stemmer = stemmer
diff --git a/arekit/contrib/utils/pipelines/items/text/frames_negation.py b/arekit/contrib/utils/pipelines/items/text/frames_negation.py
index 6b9421a4..ff7d6181 100644
--- a/arekit/contrib/utils/pipelines/items/text/frames_negation.py
+++ b/arekit/contrib/utils/pipelines/items/text/frames_negation.py
@@ -7,8 +7,9 @@
 
 class FrameVariantsSentimentNegation(BasePipelineItem):
 
-    def __init__(self, locale_mods=RussianLanguageMods):
+    def __init__(self, locale_mods=RussianLanguageMods, **kwargs):
         assert(issubclass(locale_mods, BaseLanguageMods))
+        super(FrameVariantsSentimentNegation, self).__init__(**kwargs)
         self._locale_mods = locale_mods
 
     @staticmethod
diff --git a/arekit/contrib/utils/pipelines/items/text/terms_splitter.py b/arekit/contrib/utils/pipelines/items/text/terms_splitter.py
index 542b3381..19df29f5 100644
--- a/arekit/contrib/utils/pipelines/items/text/terms_splitter.py
+++ b/arekit/contrib/utils/pipelines/items/text/terms_splitter.py
@@ -5,6 +5,9 @@
 
 class TermsSplitterParser(BasePipelineItem):
 
+    def __init__(self, **kwargs):
+        super(TermsSplitterParser, self).__init__(**kwargs)
+
     def apply_core(self, input_data, pipeline_ctx):
         assert(isinstance(pipeline_ctx, PipelineContext))
         return split_by_whitespaces(input_data)
diff --git a/arekit/contrib/utils/pipelines/items/text/tokenizer.py b/arekit/contrib/utils/pipelines/items/text/tokenizer.py
index d5b9d3ed..6d98ffaf 100644
--- a/arekit/contrib/utils/pipelines/items/text/tokenizer.py
+++ b/arekit/contrib/utils/pipelines/items/text/tokenizer.py
@@ -14,14 +14,13 @@
 class DefaultTextTokenizer(BasePipelineItem):
     """ Default parser implementation.
     """
 
-    def __init__(self, keep_tokens=True):
-        super(DefaultTextTokenizer, self).__init__()
+    def __init__(self, keep_tokens=True, **kwargs):
+        super(DefaultTextTokenizer, self).__init__(**kwargs)
         self.__keep_tokens = keep_tokens
 
     # region protected methods
 
     def apply_core(self, input_data, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
         output_data = self.__process_parts(input_data)
         if not self.__keep_tokens:
             output_data = [word for word in output_data if not isinstance(word, Token)]
diff --git a/arekit/contrib/utils/pipelines/items/text/translator.py b/arekit/contrib/utils/pipelines/items/text/translator.py
index ae74ec28..3bb5478f 100644
--- a/arekit/contrib/utils/pipelines/items/text/translator.py
+++ b/arekit/contrib/utils/pipelines/items/text/translator.py
@@ -9,10 +9,11 @@
 
 class MLTextTranslatorPipelineItem(BasePipelineItem):
     """ Machine learning based translator pipeline item.
     """
 
-    def __init__(self, batch_translate_model, do_translate_entity=True):
+    def __init__(self, batch_translate_model, do_translate_entity=True, **kwargs):
         """ Model, which is based on translation of the text, represented as a list of words.
""" + super(MLTextTranslatorPipelineItem, self).__init__(**kwargs) self.__do_translate_entity = do_translate_entity self.__translate = batch_translate_model diff --git a/tests/README.md b/tests/README.md index e92181ba..e2fc9786 100644 --- a/tests/README.md +++ b/tests/README.md @@ -12,6 +12,8 @@ pip install -e ../ --no-deps ``` Using `pytest` to run all the test and gather report into `pytest_report.html` document. + +**NOTE: Launch from the root project folder** ```bash python -m pytest --html=pytest_report.html --self-contained-html --continue-on-collection-errors . ``` diff --git a/tests/contrib/networks/doc.py b/tests/contrib/networks/doc.py index 5380c385..f987c900 100644 --- a/tests/contrib/networks/doc.py +++ b/tests/contrib/networks/doc.py @@ -5,7 +5,7 @@ from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter from arekit.contrib.source.rusentrel.docs_reader import RuSentRelDocumentsReader from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinions -from tests.contrib.networks.labels import TestPositiveLabel, TestNegativeLabel +from labels import TestPositiveLabel, TestNegativeLabel def init_rusentrel_doc(doc_id, text_parser, synonyms): diff --git a/tests/contrib/networks/indices_feature.py b/tests/contrib/networks/indices_feature.py index 6097feaf..30d02341 100644 --- a/tests/contrib/networks/indices_feature.py +++ b/tests/contrib/networks/indices_feature.py @@ -1,6 +1,6 @@ import numpy as np -from tests.contrib.networks import utils +import utils class IndicesFeature: diff --git a/tests/contrib/networks/test_input_features.py b/tests/contrib/networks/test_input_features.py index 589ac38e..0a28703a 100644 --- a/tests/contrib/networks/test_input_features.py +++ b/tests/contrib/networks/test_input_features.py @@ -5,10 +5,9 @@ import numpy as np + sys.path.append('../') -from tests.text.linked_opinions import iter_same_sentence_linked_text_opinions -from tests.text.utils import terms_to_str from arekit.common.text.stemmer import Stemmer from arekit.common.docs.parsed.providers.entity_service import EntityServiceProvider from arekit.common.docs.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider @@ -29,9 +28,10 @@ from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.entities.formatters.str_display import StringEntitiesDisplayValueFormatter -from tests.contrib.networks.doc import init_rusentrel_doc -from tests.contrib.networks.indices_feature import IndicesFeature -from tests.contrib.networks.labels import TestPositiveLabel, TestNegativeLabel +from doc import init_rusentrel_doc +from indices_feature import IndicesFeature +from labels import TestPositiveLabel, TestNegativeLabel +from utils import iter_same_sentence_linked_text_opinions, terms_to_str class RuSentRelSynonymsCollectionProvider(object): @@ -72,7 +72,7 @@ def test(self): logger.setLevel(logging.INFO) logging.basicConfig(level=logging.DEBUG) - text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), + text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(src_key="input"), DefaultTextTokenizer(keep_tokens=True), LemmasBasedFrameVariantsParser( frame_variants=self.unique_frame_variants, diff --git a/tests/contrib/networks/utils.py b/tests/contrib/networks/utils.py index 9e1937a0..90761409 100644 --- a/tests/contrib/networks/utils.py +++ b/tests/contrib/networks/utils.py @@ -1,3 +1,15 @@ +from 
arekit.common.context.token import Token +from arekit.common.docs.parsed.providers.entity_service import EntityServiceProvider +from arekit.common.docs.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider +from arekit.common.docs.parsed.term_position import TermPositionTypes +from arekit.common.entities.base import Entity +from arekit.common.frames.text_variant import TextFrameVariant +from arekit.common.frames.variants.base import FrameVariant +from arekit.common.linkage.text_opinions import TextOpinionsLinkage +from arekit.common.opinions.collection import OpinionCollection +from arekit.common.text_opinions.base import TextOpinion + + def pad_right_inplace(lst, pad_size, filler): assert(pad_size - len(lst) > 0) @@ -6,3 +18,59 @@ def pad_right_inplace(lst, pad_size, filler): def in_window(window_begin, window_end, ind): return window_begin <= ind < window_end + + +def __is_same_sentence(text_opinion, entity_service): + assert(isinstance(text_opinion, TextOpinion)) + + s_ind = entity_service.get_entity_position(id_in_document=text_opinion.SourceId, + position_type=TermPositionTypes.SentenceIndex) + t_ind = entity_service.get_entity_position(id_in_document=text_opinion.TargetId, + position_type=TermPositionTypes.SentenceIndex) + return s_ind == t_ind + + +def iter_same_sentence_linked_text_opinions(pairs_provider, entity_service, opinions): + assert(isinstance(pairs_provider, TextOpinionPairsProvider)) + assert(isinstance(entity_service, EntityServiceProvider)) + assert(isinstance(opinions, OpinionCollection)) + + for opinion in opinions: + + text_opinions_linkage = TextOpinionsLinkage( + linked_data=pairs_provider.iter_from_opinion(opinion)) + + assert(isinstance(text_opinions_linkage, TextOpinionsLinkage)) + + if len(text_opinions_linkage) == 0: + continue + + text_opinion = text_opinions_linkage.First + assert(isinstance(text_opinion, TextOpinion)) + assert(isinstance(text_opinions_linkage, TextOpinionsLinkage)) + + is_same = __is_same_sentence(text_opinion=text_opinion, entity_service=entity_service) + + if not is_same: + continue + + yield text_opinion + + +def terms_to_str(terms): + r = [] + for t in terms: + if isinstance(t, str): + r.append(t) + elif isinstance(t, Token): + r.append(t.get_token_value()) + elif isinstance(t, Entity): + r.append("[{}]".format(t.Value)) + elif isinstance(t, TextFrameVariant): + r.append("<{}>".format(t.Variant.get_value())) + elif isinstance(t, FrameVariant): + r.append(t.get_value()) + else: + r.append(t) + + return r diff --git a/tests/contrib/source/doc.py b/tests/contrib/source/doc.py index 5b6bcd21..92de209c 100644 --- a/tests/contrib/source/doc.py +++ b/tests/contrib/source/doc.py @@ -7,7 +7,7 @@ from arekit.contrib.source.rusentrel.labels_fmt import RuSentRelLabelsFormatter from arekit.contrib.source.rusentrel.docs_reader import RuSentRelDocumentsReader from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinions -from tests.contrib.source.labels import PositiveLabel, NegativeLabel +from labels import PositiveLabel, NegativeLabel def init_rusentrel_doc(doc_id, text_parser, synonyms): diff --git a/tests/contrib/source/test_labels.py b/tests/contrib/source/test_labels.py index 5b5081be..78894fa4 100644 --- a/tests/contrib/source/test_labels.py +++ b/tests/contrib/source/test_labels.py @@ -4,7 +4,7 @@ sys.path.append('../') from arekit.common.labels.base import Label, NoLabel -from tests.contrib.source.labels import NegativeLabel, PositiveLabel +from labels import NegativeLabel, PositiveLabel class 
TestLabels(unittest.TestCase): diff --git a/tests/contrib/source/test_ruattitudes.py b/tests/contrib/source/test_ruattitudes.py index 0a3590f9..ef07f12d 100755 --- a/tests/contrib/source/test_ruattitudes.py +++ b/tests/contrib/source/test_ruattitudes.py @@ -26,8 +26,8 @@ from arekit.contrib.source.brat.entities.entity import BratEntity from arekit.contrib.source.ruattitudes.doc_brat import RuAttitudesDocumentsConverter -from tests.contrib.source.utils import RuAttitudesSentenceOpinionUtils -from tests.contrib.source.labels import PositiveLabel, NegativeLabel +from utils import RuAttitudesSentenceOpinionUtils +from labels import PositiveLabel, NegativeLabel logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -99,15 +99,15 @@ def __iter_indices(self, ra_version): def __test_parsing(self, ra_version): # Initialize text parser pipeline. - text_parser = BaseTextParser(pipeline=[RuAttitudesTextEntitiesParser(), + text_parser = BaseTextParser(pipeline=[RuAttitudesTextEntitiesParser(src_key="input"), DefaultTextTokenizer(keep_tokens=True)]) # iterating through collection doc_read = 0 doc_it = RuAttitudesCollection.iter_docs(version=ra_version, - get_doc_index_func=lambda _: doc_read, - return_inds_only=False) + get_doc_index_func=lambda _: doc_read, + return_inds_only=False) for doc in tqdm(doc_it): @@ -148,8 +148,8 @@ def __test_reading(self, ra_version, do_printing=True): # iterating through collection doc_read = 0 doc_it = RuAttitudesCollection.iter_docs(version=ra_version, - get_doc_index_func=lambda _: doc_read, - return_inds_only=False) + get_doc_index_func=lambda _: doc_read, + return_inds_only=False) if not do_printing: doc_it = tqdm(doc_it) diff --git a/tests/contrib/source/test_rusentiframes.py b/tests/contrib/source/test_rusentiframes.py index ca97fa4c..7b8a8085 100755 --- a/tests/contrib/source/test_rusentiframes.py +++ b/tests/contrib/source/test_rusentiframes.py @@ -11,7 +11,7 @@ from arekit.contrib.source.rusentiframes.collection import RuSentiFramesCollection from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions -from tests.contrib.source.labels import PositiveLabel, NegativeLabel +from labels import PositiveLabel, NegativeLabel class TestRuSentiFrames(unittest.TestCase): diff --git a/tests/contrib/source/test_rusentiframes_stat.py b/tests/contrib/source/test_rusentiframes_stat.py index fefe6b41..7d63caa6 100644 --- a/tests/contrib/source/test_rusentiframes_stat.py +++ b/tests/contrib/source/test_rusentiframes_stat.py @@ -11,7 +11,7 @@ from arekit.contrib.source.rusentiframes.effect import FrameEffect from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.processing.pos.mystem_wrap import POSMystemWrapper -from tests.contrib.source.labels import PositiveLabel, NegativeLabel +from labels import PositiveLabel, NegativeLabel def __iter_unique_frame_variants(frames_collection, frame_ids): diff --git a/tests/contrib/source/test_show_frames_stat.py b/tests/contrib/source/test_show_frames_stat.py index f38347e3..4fbae011 100644 --- a/tests/contrib/source/test_show_frames_stat.py +++ b/tests/contrib/source/test_show_frames_stat.py @@ -5,7 +5,7 @@ from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions -from tests.contrib.source.test_rusentiframes_stat import about_version +from test_rusentiframes_stat import about_version class TestFramesStat(unittest.TestCase): diff --git a/tests/contrib/utils/test_csv_stream_write.py b/tests/contrib/utils/test_csv_stream_write.py index 
e07ea8b3..9a604c4d 100644 --- a/tests/contrib/utils/test_csv_stream_write.py +++ b/tests/contrib/utils/test_csv_stream_write.py @@ -81,7 +81,7 @@ def __launch(self, writer): text_parser=text_parser) ##### - pipeline.run(input_data=PipelineContext(d={ + pipeline.run(pipeline_ctx=PipelineContext(d={ "data_type_pipelines": {DataType.Train: train_pipeline}, "data_folding": {DataType.Train: [0, 1]} })) diff --git a/tests/contrib/utils/test_frames_annotation.py b/tests/contrib/utils/test_frames_annotation.py index a3be88b1..bb114f7f 100644 --- a/tests/contrib/utils/test_frames_annotation.py +++ b/tests/contrib/utils/test_frames_annotation.py @@ -8,8 +8,8 @@ from arekit.contrib.source.rusentiframes.types import RuSentiFramesVersions from arekit.contrib.utils.pipelines.items.text.frames_lemmatized import LemmasBasedFrameVariantsParser from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper -from tests.contrib.source.labels import PositiveLabel -from tests.contrib.utils.labels import NegativeLabel +from labels import PositiveLabel +from labels import NegativeLabel class TestFramesAnnotation(unittest.TestCase): diff --git a/tests/contrib/utils/test_text_parser.py b/tests/contrib/utils/test_text_parser.py index 6d457700..c0c5f195 100644 --- a/tests/contrib/utils/test_text_parser.py +++ b/tests/contrib/utils/test_text_parser.py @@ -22,9 +22,9 @@ from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection -from tests.contrib.utils.labels import NegativeLabel, PositiveLabel +from labels import NegativeLabel, PositiveLabel -from tests.contrib.utils.text.debug_text import debug_show_doc_terms +from text.debug_text import debug_show_doc_terms class RuSentRelSynonymsCollectionProvider(object): @@ -42,7 +42,9 @@ class TestTextParser(unittest.TestCase): def test_parse_single_string(self): text = "А контроль над этими провинциями — это господство над без малого половиной сирийской территории." 
- parser = BaseTextParser(pipeline=[DefaultTextTokenizer(keep_tokens=True)]) + parser = BaseTextParser(pipeline=[ + DefaultTextTokenizer(keep_tokens=True, src_key="input", src_func=lambda s: s.Text) + ]) doc = Document(doc_id=0, sentences=[BaseDocumentSentence(text.split())]) parsed_doc = DocumentParser.parse(doc=doc, text_parser=parser) debug_show_doc_terms(parsed_doc=parsed_doc) @@ -66,11 +68,13 @@ def test_parse_frame_variants(self): overwrite_existed_variant=True, raise_error_on_existed_variant=False) - parser = BaseTextParser(pipeline=[DefaultTextTokenizer(keep_tokens=True), - FrameVariantsParser(frame_variants=frame_variants), - LemmasBasedFrameVariantsParser(frame_variants=frame_variants, - stemmer=stemmer), - FrameVariantsSentimentNegation()]) + parser = BaseTextParser(pipeline=[ + DefaultTextTokenizer(keep_tokens=True, src_key="input", src_func=lambda s: s.Text), + FrameVariantsParser(frame_variants=frame_variants), + LemmasBasedFrameVariantsParser(frame_variants=frame_variants, + stemmer=stemmer), + FrameVariantsSentimentNegation()]) + doc = Document(doc_id=0, sentences=[BaseDocumentSentence(text.split())]) parsed_doc = DocumentParser.parse(doc=doc, text_parser=parser) debug_show_doc_terms(parsed_doc=parsed_doc) @@ -98,7 +102,7 @@ def test_parsing(self): overwrite_existed_variant=True, raise_error_on_existed_variant=False) - text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), + text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(src_key="input"), DefaultTextTokenizer(keep_tokens=True), LemmasBasedFrameVariantsParser(frame_variants=frame_variants, stemmer=stemmer, diff --git a/tests/text/linked_opinions.py b/tests/text/linked_opinions.py deleted file mode 100644 index b6186735..00000000 --- a/tests/text/linked_opinions.py +++ /dev/null @@ -1,43 +0,0 @@ -from arekit.common.linkage.text_opinions import TextOpinionsLinkage -from arekit.common.docs.parsed.providers.entity_service import EntityServiceProvider -from arekit.common.docs.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider -from arekit.common.docs.parsed.term_position import TermPositionTypes -from arekit.common.opinions.collection import OpinionCollection -from arekit.common.text_opinions.base import TextOpinion - - -def __is_same_sentence(text_opinion, entity_service): - assert(isinstance(text_opinion, TextOpinion)) - - s_ind = entity_service.get_entity_position(id_in_document=text_opinion.SourceId, - position_type=TermPositionTypes.SentenceIndex) - t_ind = entity_service.get_entity_position(id_in_document=text_opinion.TargetId, - position_type=TermPositionTypes.SentenceIndex) - return s_ind == t_ind - - -def iter_same_sentence_linked_text_opinions(pairs_provider, entity_service, opinions): - assert(isinstance(pairs_provider, TextOpinionPairsProvider)) - assert(isinstance(entity_service, EntityServiceProvider)) - assert(isinstance(opinions, OpinionCollection)) - - for opinion in opinions: - - text_opinions_linkage = TextOpinionsLinkage( - linked_data=pairs_provider.iter_from_opinion(opinion)) - - assert(isinstance(text_opinions_linkage, TextOpinionsLinkage)) - - if len(text_opinions_linkage) == 0: - continue - - text_opinion = text_opinions_linkage.First - assert(isinstance(text_opinion, TextOpinion)) - assert(isinstance(text_opinions_linkage, TextOpinionsLinkage)) - - is_same = __is_same_sentence(text_opinion=text_opinion, entity_service=entity_service) - - if not is_same: - continue - - yield text_opinion diff --git a/tests/text/test_nested_entities.py 
index 07215655..ecc34602 100644
--- a/tests/text/test_nested_entities.py
+++ b/tests/text/test_nested_entities.py
@@ -16,11 +16,9 @@ def test(self):
 
         tep = TextEntitiesParser()
 
-        text_parser = BaseTextParser(pipeline=[
-            TextEntitiesParser(),
-        ])
+        text_parser = BaseTextParser(pipeline=[TextEntitiesParser()])
 
-        parsed_text = text_parser.run(s.split())
+        parsed_text = text_parser.run({"result": s.split()})
         assert(isinstance(parsed_text, BaseParsedText))
 
         print(parsed_text._terms)
diff --git a/tests/text/utils.py b/tests/text/utils.py
index 89852449..d3570992 100644
--- a/tests/text/utils.py
+++ b/tests/text/utils.py
@@ -3,21 +3,3 @@
 from arekit.common.frames.text_variant import TextFrameVariant
 from arekit.common.frames.variants.base import FrameVariant
 
-
-def terms_to_str(terms):
-    r = []
-    for t in terms:
-        if isinstance(t, str):
-            r.append(t)
-        elif isinstance(t, Token):
-            r.append(t.get_token_value())
-        elif isinstance(t, Entity):
-            r.append("[{}]".format(t.Value))
-        elif isinstance(t, TextFrameVariant):
-            r.append("<{}>".format(t.Variant.get_value()))
-        elif isinstance(t, FrameVariant):
-            r.append(t.get_value())
-        else:
-            r.append(t)
-
-    return r
diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py
index 23af4b23..88ae2032 100644
--- a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py
+++ b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py
@@ -27,7 +27,7 @@
 from arekit.contrib.utils.pipelines.text_opinion.annot.predefined import PredefinedTextOpinionAnnotator
 from arekit.contrib.utils.pipelines.text_opinion.extraction import text_opinion_extraction_pipeline
 from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
-from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
+from test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
 
 
 class Positive(Label):
@@ -95,7 +95,8 @@ def test(self):
             rows_provider=rows_provider,
             samples_io=samples_io,
             save_labels_func=lambda data_type: True,
-            storage=PandasBasedRowsStorage())
+            storage=PandasBasedRowsStorage(),
+            src_key=None)
 
         pipeline = BasePipeline([
             pipeline_item
@@ -105,7 +106,10 @@ def test(self):
         # Declaring pipeline related context parameters.
         #####
         doc_provider = FooDocumentProvider()
-        text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
+        text_parser = BaseTextParser(pipeline=[
+            BratTextEntitiesParser(src_key="input"),
+            DefaultTextTokenizer(keep_tokens=True)
+        ])
         train_pipeline = text_opinion_extraction_pipeline(
             annotators=[
                 PredefinedTextOpinionAnnotator(
@@ -121,7 +125,7 @@ def test(self):
             text_parser=text_parser)
 
         #####
-        pipeline.run(input_data=PipelineContext(d={
+        pipeline.run(pipeline_ctx=PipelineContext(d={
            "data_type_pipelines": {DataType.Train: train_pipeline},
            "data_folding": {DataType.Train: [0, 1]}
        }))
diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_network.py b/tests/tutorials/test_tutorial_pipeline_sampling_network.py
index 1b1fcdf5..ee2e834a 100644
--- a/tests/tutorials/test_tutorial_pipeline_sampling_network.py
+++ b/tests/tutorials/test_tutorial_pipeline_sampling_network.py
@@ -32,7 +32,7 @@
 from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
 from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper
 from arekit.contrib.utils.processing.pos.mystem_wrap import POSMystemWrapper
-from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
+from test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
 
 
 class Positive(Label):
@@ -96,7 +96,8 @@ def test(self):
             emb_io=NpEmbeddingIO(target_dir=self.__output_dir),
             rows_provider=rows_provider,
             save_labels_func=lambda data_type: data_type != DataType.Test,
-            storage=PandasBasedRowsStorage())
+            storage=PandasBasedRowsStorage(),
+            src_key=None)
 
         pipeline = BasePipeline([
             pipeline_item
@@ -107,7 +108,7 @@ def test(self):
         #####
         doc_provider = FooDocumentProvider()
         text_parser = BaseTextParser(pipeline=[
-            BratTextEntitiesParser(),
+            BratTextEntitiesParser(src_key="input"),
             DefaultTextTokenizer(keep_tokens=True),
             LemmasBasedFrameVariantsParser(frame_variants=frame_variant_collection, stemmer=stemmer)
         ])
@@ -125,7 +126,7 @@ def test(self):
             text_parser=text_parser)
 
         #####
-        pipeline.run(input_data=PipelineContext(d={
+        pipeline.run(pipeline_ctx=PipelineContext(d={
            "data_type_pipelines": {DataType.Train: train_pipeline},
            "data_folding": {DataType.Train: [0, 1]}
        }))
diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py b/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
index 4f43a251..ba94aa46 100644
--- a/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
+++ b/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
@@ -26,7 +26,7 @@
 from arekit.contrib.utils.pipelines.text_opinion.annot.predefined import PredefinedTextOpinionAnnotator
 from arekit.contrib.utils.pipelines.text_opinion.extraction import text_opinion_extraction_pipeline
 from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
-from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
+from test_tutorial_pipeline_text_opinion_annotation import FooDocumentProvider
 
 
 class Positive(Label):
@@ -96,7 +96,8 @@ def test(self):
             rows_provider=rows_provider,
             samples_io=samples_io,
             save_labels_func=lambda data_type: True,
-            storage=PandasBasedRowsStorage())
+            storage=PandasBasedRowsStorage(),
+            src_key=None)
 
         pipeline = BasePipeline([
             pipeline_item
@@ -106,7 +107,10 @@ def test(self):
         # Declaring pipeline related context parameters.
         #####
         doc_provider = FooDocumentProvider()
-        text_parser = BaseTextParser(pipeline=[BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)])
+        text_parser = BaseTextParser(pipeline=[
+            BratTextEntitiesParser(src_key="input"),
+            DefaultTextTokenizer(keep_tokens=True)
+        ])
         train_pipeline = text_opinion_extraction_pipeline(
             annotators=[
                 PredefinedTextOpinionAnnotator(doc_provider,
@@ -122,7 +126,7 @@ def test(self):
             text_parser=text_parser)
 
         #####
-        pipeline.run(input_data=PipelineContext(d={
+        pipeline.run(pipeline_ctx=PipelineContext(d={
            "data_type_pipelines": {DataType.Train: train_pipeline},
            "data_folding": {DataType.Train: [0, 1]}
        }))
diff --git a/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py b/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py
index 39ba4a1f..4cda58fd 100644
--- a/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py
+++ b/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py
@@ -10,6 +10,7 @@
 from arekit.common.docs.parsed.service import ParsedDocumentService
 from arekit.common.opinions.annot.algo.pair_based import PairBasedOpinionAnnotationAlgorithm
 from arekit.common.opinions.collection import OpinionCollection
+from arekit.common.pipeline.context import PipelineContext
 from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
 from arekit.common.text.parser import BaseTextParser
 from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
@@ -20,7 +21,7 @@
 from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
 from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper
 from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection
-from tests.tutorials.test_tutorial_collection_binding import FooDocReader
+from test_tutorial_collection_binding import FooDocReader
 
 
 class PositiveLabel(Label):
@@ -69,7 +70,7 @@ def test(self):
                 synonyms=synonyms, value=value))
 
         text_parser = BaseTextParser([
-            BratTextEntitiesParser(partitioning="string"),
+            BratTextEntitiesParser(partitioning="string", src_key="input"),
             DefaultTextTokenizer(keep_tokens=True),
         ])
 
@@ -85,8 +86,16 @@ def test(self):
             entity_index_func=entity_index_func,
             text_parser=text_parser)
 
+        # Defining the pipeline context.
+        context = PipelineContext(
+            d={"result": [0]}
+        )
+
+        # Launching the pipeline.
+        pipeline.run(pipeline_ctx=context)
+
         # Running the pipeline.
-        for linked in pipeline.run(input_data=[0]):
+        for linked in context.provide("result"):
             assert(isinstance(linked, TextOpinionsLinkage) or isinstance(linked, MetaEmptyLinkedDataWrapper))
 
             if isinstance(linked, MetaEmptyLinkedDataWrapper):
diff --git a/tests/tutorials/test_tutorial_pipeline_text_parser.py b/tests/tutorials/test_tutorial_pipeline_text_parser.py
index aa272486..a4347ad1 100644
--- a/tests/tutorials/test_tutorial_pipeline_text_parser.py
+++ b/tests/tutorials/test_tutorial_pipeline_text_parser.py
@@ -53,7 +53,7 @@ def test(self):
                                       raise_error_on_existed_variant=False)
 
         text_parser = BaseTextParser(pipeline=[
-            TextEntitiesParser(),
+            TextEntitiesParser(src_key="input", src_func=lambda s: s.Text),
             DefaultTextTokenizer(keep_tokens=True),
             LemmasBasedFrameVariantsParser(frame_variants=frame_variant_collection,
                                            stemmer=MystemWrapper()),
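
A minimal usage sketch of the reworked pipeline API: each item now reads its input from the context under `src_key` and publishes its output under `result_key` (both default to `"result"`, which chains items together), and `run()` receives a `PipelineContext` instead of raw input data. `BasePipeline`, `PipelineContext`, `MapPipelineItem` and the key parameters follow the diff above; the two-item pipeline and the `"input"` key below are hypothetical.

```python
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.map import MapPipelineItem

# Hypothetical pipeline: the first item pulls its source from the "input" key;
# the second consumes the previous item's output from the default "result" key.
pipeline = BasePipeline([
    MapPipelineItem(map_func=lambda term: term.lower(), src_key="input"),
    MapPipelineItem(map_func=lambda term: term + "!"),
])

# The context carries both the initial source and every intermediate result.
ctx = PipelineContext(d={"input": ["Hello", "World"]})
pipeline.run(ctx)

# Items return lazy map objects; materialize the final result from the context.
print(list(ctx.provide("result")))  # ["hello!", "world!"]
```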