#539 done. Pipelines are now much more flexible: the source can be customized.
Related to #535.
nicolay-r committed Dec 27, 2023
1 parent 0e4cc7f commit 958084c
Showing 46 changed files with 251 additions and 191 deletions.
3 changes: 2 additions & 1 deletion arekit/common/docs/entities_grouping.py
@@ -4,8 +4,9 @@

 class EntitiesGroupingPipelineItem(BasePipelineItem):
 
-    def __init__(self, value_to_group_id_func):
+    def __init__(self, value_to_group_id_func, **kwargs):
         assert(callable(value_to_group_id_func))
+        super(EntitiesGroupingPipelineItem, self).__init__(**kwargs)
         self.__value_to_group_id_func = value_to_group_id_func
 
     def apply_core(self, input_data, pipeline_ctx):
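Since `**kwargs` are now forwarded to `BasePipelineItem`, every concrete item can be wired to a custom source key at construction time. A minimal sketch (the grouping function and the key are illustrative, not from the commit):

from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem

# Hypothetical setup: read entities from the "input" key rather than
# the default "result" key; the grouping function is illustrative.
grouping_item = EntitiesGroupingPipelineItem(
    value_to_group_id_func=lambda value: hash(value) % 10,
    src_key="input")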
13 changes: 6 additions & 7 deletions arekit/common/docs/objects_parser.py
@@ -1,29 +1,28 @@
 from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.common.text.partitioning.base import BasePartitioning
-from arekit.common.pipeline.context import PipelineContext
 
 
 class SentenceObjectsParserPipelineItem(BasePipelineItem):
 
-    def __init__(self, partitioning):
+    def __init__(self, partitioning, **kwargs):
         assert(isinstance(partitioning, BasePartitioning))
+        super(SentenceObjectsParserPipelineItem, self).__init__(**kwargs)
         self.__partitioning = partitioning
 
     # region protected
 
-    def _get_text(self, pipeline_ctx):
+    def _get_text(self, sentence):
         return None
 
-    def _get_parts_provider_func(self, input_data, pipeline_ctx):
+    def _get_parts_provider_func(self, sentence):
         raise NotImplementedError()
 
     # endregion
 
     def apply_core(self, input_data, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
-        external_input = self._get_text(pipeline_ctx)
+        external_input = self._get_text(input_data)
         actual_input = input_data if external_input is None else external_input
-        parts_it = self._get_parts_provider_func(input_data=actual_input, pipeline_ctx=pipeline_ctx)
+        parts_it = self._get_parts_provider_func(input_data)
         return self.__partitioning.provide(text=actual_input, parts_it=parts_it)
 
     # region base
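Subclasses now receive the sentence object directly instead of fetching it from the pipeline context. A hedged sketch of the new overriding contract (the `Text` attribute mirrors the BratSentence API used further below):

from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
from arekit.common.text.partitioning.str import StringPartitioning


class MySentenceParser(SentenceObjectsParserPipelineItem):
    """ Toy subclass for illustration purposes only. """

    def __init__(self, **kwargs):
        super(MySentenceParser, self).__init__(StringPartitioning(), **kwargs)

    def _get_text(self, sentence):
        return sentence.Text

    def _get_parts_provider_func(self, sentence):
        return iter([])  # no object bounds in this toy example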
12 changes: 1 addition & 11 deletions arekit/common/docs/parser.py
@@ -16,19 +16,9 @@ def parse(doc, text_parser, parent_ppl_ctx=None):
         assert(isinstance(text_parser, BaseTextParser))
         assert(isinstance(parent_ppl_ctx, PipelineContext) or parent_ppl_ctx is None)
 
-        parsed_sentences = [text_parser.run(input_data=DocumentParser.__get_sent(doc, sent_ind).Text,
-                                            params_dict=DocumentParser.__create_ppl_params(doc=doc, sent_ind=sent_ind),
+        parsed_sentences = [text_parser.run(params_dict={"input": DocumentParser.__get_sent(doc, sent_ind)},
                                             parent_ctx=parent_ppl_ctx)
                             for sent_ind in range(doc.SentencesCount)]
 
         return ParsedDocument(doc_id=doc.ID,
                               parsed_sentences=parsed_sentences)
-
-    @staticmethod
-    def __create_ppl_params(doc, sent_ind):
-        assert(isinstance(doc, Document))
-        return {
-            "s_ind": sent_ind,  # sentence index. (as Metadata)
-            "doc_id": doc.ID,  # document index. (as Metadata)
-            "sentence": DocumentParser.__get_sent(doc, sent_ind),  # Required for special sources.
-        }
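With the metadata dictionary gone, a caller only hands the pipeline each sentence object under the "input" key. A hedged usage sketch (`doc` and `text_parser` are assumed to be built elsewhere):

from arekit.common.docs.parser import DocumentParser

# `doc` is any Document instance; `text_parser` is a BaseTextParser whose
# first item is constructed with src_key="input" (see text/parser.py below).
parsed_doc = DocumentParser.parse(doc=doc, text_parser=text_parser)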
12 changes: 5 additions & 7 deletions arekit/common/pipeline/base.py
@@ -8,17 +8,15 @@ def __init__(self, pipeline):
         assert(isinstance(pipeline, list))
         self.__pipeline = pipeline
 
-    def run(self, input_data, params_dict=None, parent_ctx=None):
-        assert(isinstance(params_dict, dict) or params_dict is None)
-
-        pipeline_ctx = PipelineContext(d=params_dict if params_dict is not None else dict(),
-                                       parent_ctx=parent_ctx)
+    def run(self, pipeline_ctx):
+        assert(isinstance(pipeline_ctx, PipelineContext))
 
         for item in filter(lambda itm: itm is not None, self.__pipeline):
             assert(isinstance(item, BasePipelineItem))
-            input_data = item.apply(input_data=input_data, pipeline_ctx=pipeline_ctx)
+            item_result = item.apply(input_data=item.get_source(pipeline_ctx), pipeline_ctx=pipeline_ctx)
+            pipeline_ctx.update(param=item.ResultKey, value=item_result, is_new_key=False)
 
-        return input_data
+        return pipeline_ctx
 
     def append(self, item):
         assert(isinstance(item, BasePipelineItem))
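Under the new contract the caller owns the context: inputs go in under named keys, each item reads its source via `get_source`, and its output is written back under its `ResultKey`. A minimal sketch, assuming `MapPipelineItem` maps `map_func` over an iterable input:

from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.map import MapPipelineItem

pipeline = BasePipeline([
    MapPipelineItem(map_func=lambda t: t.strip()),   # reads the default "result" key
    MapPipelineItem(map_func=lambda t: t.lower()),   # chains through "result"
])

ctx = PipelineContext(d={"result": [" Foo ", " BAR "]})
pipeline.run(ctx)
print(list(ctx.provide("result")))  # -> ["foo", "bar"]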
6 changes: 5 additions & 1 deletion arekit/common/pipeline/context.py
@@ -13,6 +13,8 @@ def __init__(self, d, parent_ctx=None):
         self._d[PARENT_CTX] = parent_ctx
 
     def __provide(self, param):
+        if param not in self._d:
+            raise Exception(f"Key `{param}` is not in dictionary.\n{self._d}")
         return self._d[param]
 
     # region public
 
@@ -23,7 +25,9 @@ def provide(self, param):
     def provide_or_none(self, param):
         return self.__provide(param) if param in self._d else None
 
-    def update(self, param, value):
+    def update(self, param, value, is_new_key=False):
+        if is_new_key and param in self._d:
+            raise Exception(f"Key `{param}` is already presented in pipeline context dictionary.")
         self._d[param] = value
 
     # endregion
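The context now fails fast: `provide` raises on a missing key, while `update(..., is_new_key=True)` guards against silently overwriting an existing one. A quick sketch:

from arekit.common.pipeline.context import PipelineContext

ctx = PipelineContext(d={"input": "raw text"})
ctx.update(param="result", value=1)    # plain write; overwriting stays allowed
ctx.update(param="result", value=2)    # fine
# ctx.update(param="result", value=3, is_new_key=True)  # would raise: key exists
# ctx.provide("missing")                                # would raise: key absent
print(ctx.provide_or_none("missing"))  # -> None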
30 changes: 30 additions & 0 deletions arekit/common/pipeline/items/base.py
@@ -1,7 +1,37 @@
+from arekit.common.pipeline.context import PipelineContext
+
+
 class BasePipelineItem(object):
     """ Single pipeline item that might be instatiated and embedded into pipeline.
     """
 
+    def __init__(self, src_key="result", result_key="result", src_func=None):
+        assert(isinstance(src_key, str) or src_key is None)
+        assert(callable(src_func) or src_func is None)
+        self.__src_key = src_key
+        self.__src_func = src_func
+        self.__result_key = result_key
+
+    @property
+    def ResultKey(self):
+        return self.__result_key
+
+    def get_source(self, src_ctx):
+        """ Extract input element for processing.
+        """
+        assert(isinstance(src_ctx, PipelineContext))
+
+        # If there is no information about key, then we consider absence of the source.
+        if self.__src_key is None:
+            return None
+
+        # Extracting actual source.
+        src_data = src_ctx.provide(self.__src_key)
+        if self.__src_func is not None:
+            src_data = self.__src_func(src_data)
+
+        return src_data
+
     def apply_core(self, input_data, pipeline_ctx):
         raise NotImplementedError()
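Taken together, `src_key`, `src_func` and `result_key` decouple an item from its neighbours: where it reads, how the source is pre-processed, and where it writes are all constructor arguments. A small sketch with a toy item (names and keys are illustrative):

from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem


class UppercaseItem(BasePipelineItem):
    """ Toy item for illustration purposes only. """

    def apply_core(self, input_data, pipeline_ctx):
        return input_data.upper()


item = UppercaseItem(src_key="input",
                     src_func=lambda s: s.strip(),  # applied to the source before the item sees it
                     result_key="upper")

ctx = PipelineContext(d={"input": "  hello  "})
result = item.apply_core(input_data=item.get_source(ctx), pipeline_ctx=ctx)
ctx.update(param=item.ResultKey, value=result)  # ctx now maps "upper" to "HELLO"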
6 changes: 5 additions & 1 deletion arekit/common/pipeline/items/flatten.py
@@ -5,10 +5,14 @@ class FlattenIterPipelineItem(BasePipelineItem):
     """ Considered to flat iterations of items that represent iterations.
     """
 
+    def __init__(self, **kwargs):
+        super(FlattenIterPipelineItem, self).__init__(**kwargs)
+        pass
+
     def __flat_iter(self, iter_data):
         for iter_item in iter_data:
             for item in iter_item:
                 yield item
 
     def apply_core(self, input_data, pipeline_ctx):
-        return self.__flat_iter(input_data)
+        return self.__flat_iter(input_data)
3 changes: 2 additions & 1 deletion arekit/common/pipeline/items/handle.py
@@ -3,8 +3,9 @@

 class HandleIterPipelineItem(BasePipelineItem):
 
-    def __init__(self, handle_func=None):
+    def __init__(self, handle_func=None, **kwargs):
         assert(callable(handle_func))
+        super(HandleIterPipelineItem, self).__init__(**kwargs)
         self.__handle_func = handle_func
 
     def __updated_data(self, items_iter):
3 changes: 2 additions & 1 deletion arekit/common/pipeline/items/iter.py
@@ -3,8 +3,9 @@

 class FilterPipelineItem(BasePipelineItem):
 
-    def __init__(self, filter_func=None):
+    def __init__(self, filter_func=None, **kwargs):
         assert(callable(filter_func))
+        super(FilterPipelineItem, self).__init__(**kwargs)
         self.__filter_func = filter_func
 
     def apply_core(self, input_data, pipeline_ctx):
3 changes: 2 additions & 1 deletion arekit/common/pipeline/items/map.py
@@ -3,8 +3,9 @@

 class MapPipelineItem(BasePipelineItem):
 
-    def __init__(self, map_func=None):
+    def __init__(self, map_func=None, **kwargs):
         assert(callable(map_func))
+        super(MapPipelineItem, self).__init__(**kwargs)
         self._map_func = map_func
 
     def apply_core(self, input_data, pipeline_ctx):
4 changes: 4 additions & 0 deletions arekit/common/pipeline/items/map_nested.py
@@ -9,5 +9,9 @@ class MapNestedPipelineItem(MapPipelineItem):
         suppose to be mapped with the passed pipeline context.
     """
 
+    def __init__(self, **kwargs):
+        super(MapNestedPipelineItem, self).__init__(**kwargs)
+        pass
+
     def apply_core(self, input_data, pipeline_ctx):
         return map(lambda item: self._map_func(item, pipeline_ctx), input_data)
11 changes: 5 additions & 6 deletions arekit/common/text/parser.py
@@ -1,12 +1,11 @@
 from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text.parsed import BaseParsedText
 
 
 class BaseTextParser(BasePipeline):
 
-    def run(self, input_data, params_dict=None, parent_ctx=None):
-        output_data = super(BaseTextParser, self).run(input_data=input_data,
-                                                      params_dict=params_dict,
-                                                      parent_ctx=parent_ctx)
-
-        return BaseParsedText(terms=output_data)
+    def run(self, params_dict, parent_ctx=None):
+        assert(isinstance(params_dict, dict))
+        ctx = super(BaseTextParser, self).run(pipeline_ctx=PipelineContext(params_dict, parent_ctx=parent_ctx))
+        return BaseParsedText(terms=ctx.provide("result"))
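End to end, the text parser now expects the raw sentence under the "input" key and reads the parsed terms back from "result". A hedged sketch using the BRAT entities parser from below (`brat_sentence` is assumed: a BratSentence produced by a concrete source):

from arekit.common.text.parser import BaseTextParser
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser

# The first item reads the sentence object from the "input" key;
# its output lands under the default "result" key.
parser = BaseTextParser(pipeline=[
    BratTextEntitiesParser(partitioning="string", src_key="input"),
])

parsed_text = parser.run(params_dict={"input": brat_sentence})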
18 changes: 4 additions & 14 deletions arekit/contrib/source/brat/entities/parser.py
@@ -1,14 +1,11 @@
 from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
-from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text.partitioning.str import StringPartitioning
 from arekit.common.text.partitioning.terms import TermsPartitioning
 from arekit.contrib.source.brat.sentence import BratSentence
 
 
 class BratTextEntitiesParser(SentenceObjectsParserPipelineItem):
 
-    KEY = "sentence"
-
     ################################
     # NOTE: Supported partitionings.
     ################################
@@ -22,29 +19,22 @@ class BratTextEntitiesParser(SentenceObjectsParserPipelineItem):
         "terms": TermsPartitioning()
     }
 
-    def __init__(self, partitioning="string"):
+    def __init__(self, partitioning="string", **kwargs):
         assert(isinstance(partitioning, str))
-        super(BratTextEntitiesParser, self).__init__(self.__supported_partitionings[partitioning])
+        super(BratTextEntitiesParser, self).__init__(self.__supported_partitionings[partitioning], **kwargs)
 
     # region protected methods
 
-    def _get_text(self, pipeline_ctx):
-        sentence = self.__get_sentence(pipeline_ctx)
+    def _get_text(self, sentence):
         return sentence.Text
 
-    def _get_parts_provider_func(self, input_data, pipeline_ctx):
-        sentence = self.__get_sentence(pipeline_ctx)
+    def _get_parts_provider_func(self, sentence):
         return self.__iter_subs_values_with_bounds(sentence)
 
     # endregion
 
     # region private methods
 
-    def __get_sentence(self, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
-        assert(self.KEY in pipeline_ctx)
-        return pipeline_ctx.provide(self.KEY)
-
     @staticmethod
     def __iter_subs_values_with_bounds(sentence):
         assert(isinstance(sentence, BratSentence))
4 changes: 2 additions & 2 deletions arekit/contrib/source/ruattitudes/entity/parser.py
@@ -3,5 +3,5 @@

 class RuAttitudesTextEntitiesParser(BratTextEntitiesParser):
 
-    def __init__(self):
-        super(RuAttitudesTextEntitiesParser, self).__init__(partitioning="terms")
+    def __init__(self, **kwargs):
+        super(RuAttitudesTextEntitiesParser, self).__init__(partitioning="terms", **kwargs)
12 changes: 11 additions & 1 deletion arekit/contrib/utils/data/contents/opinions.py
@@ -3,6 +3,7 @@
 from arekit.common.linkage.base import LinkedDataWrapper
 from arekit.common.linkage.text_opinions import TextOpinionsLinkage
 from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.context import PipelineContext
 from arekit.common.text_opinions.base import TextOpinion


@@ -30,7 +31,16 @@ def __assign_ids(self, linkage):

     def from_doc_ids(self, doc_ids, idle_mode=False):
         self.__current_id = 0
-        for linkage in self.__pipeline.run(doc_ids, params_dict={IDLE_MODE: idle_mode}):
+
+        ctx = PipelineContext(d={
+            "result": doc_ids,
+            IDLE_MODE: idle_mode
+        })
+
+        # Launching pipeline with the passed context
+        self.__pipeline.run(ctx)
+
+        for linkage in ctx.provide("result"):
             assert(isinstance(linkage, LinkedDataWrapper))
             if isinstance(linkage, TextOpinionsLinkage):
                 self.__assign_ids(linkage)
15 changes: 6 additions & 9 deletions arekit/contrib/utils/pipelines/items/sampling/base.py
@@ -10,7 +10,7 @@

 class BaseSerializerPipelineItem(BasePipelineItem):
 
-    def __init__(self, rows_provider, samples_io, save_labels_func, storage):
+    def __init__(self, rows_provider, samples_io, save_labels_func, storage, **kwargs):
         """ sample_rows_formatter:
                 how we format input texts for a BERT model, for example:
                 - single text
@@ -23,6 +23,7 @@ def __init__(self, rows_provider, samples_io, save_labels_func, storage):
         assert(isinstance(samples_io, BaseSamplesIO))
         assert(callable(save_labels_func))
         assert(isinstance(storage, BaseRowsStorage))
+        super(BaseSerializerPipelineItem, self).__init__(**kwargs)
 
         self._rows_provider = rows_provider
         self._samples_io = samples_io
@@ -89,11 +90,7 @@ def apply_core(self, input_data, pipeline_ctx):
             doc_ids: optional
                 this parameter allows to limit amount of documents considered for sampling
         """
-        assert(isinstance(input_data, PipelineContext))
-        assert("data_type_pipelines" in input_data)
-
-        data_folding = input_data.provide_or_none("data_folding")
-
-        self._handle_iteration(data_type_pipelines=input_data.provide("data_type_pipelines"),
-                               doc_ids=input_data.provide_or_none("doc_ids"),
-                               data_folding=data_folding)
+        assert("data_type_pipelines" in pipeline_ctx)
+        self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
+                               doc_ids=pipeline_ctx.provide_or_none("doc_ids"),
+                               data_folding=pipeline_ctx.provide_or_none("data_folding"))
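As a regular pipeline item, the serializer now pulls its parameters from the shared context instead of a context passed as `input_data`. A hedged invocation sketch (`serializer_item`, `data_type_pipelines` and `doc_ids` are all assumed to be prepared elsewhere):

from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext

ctx = PipelineContext(d={
    "result": None,  # nothing flows in through the default source key
    "data_type_pipelines": data_type_pipelines,
    "doc_ids": doc_ids,
})
BasePipeline([serializer_item]).run(ctx)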
5 changes: 3 additions & 2 deletions arekit/contrib/utils/pipelines/items/sampling/networks.py
@@ -8,7 +8,7 @@

 class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):
 
-    def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage, save_embedding=True):
+    def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage, save_embedding=True, **kwargs):
         """ This pipeline item allows to perform a data preparation for neural network models.
 
             considering a list of the whole data_types with the related pipelines,
@@ -23,7 +23,8 @@ def __init__(self, save_labels_func, rows_provider, samples_io, emb_io, storage,
             rows_provider=rows_provider,
             samples_io=samples_io,
             save_labels_func=save_labels_func,
-            storage=storage)
+            storage=storage,
+            **kwargs)
 
         self.__emb_io = emb_io
         self.__save_embedding = save_embedding
4 changes: 2 additions & 2 deletions arekit/contrib/utils/pipelines/items/text/entities_default.py
@@ -4,8 +4,8 @@

 class TextEntitiesParser(BasePipelineItem):
 
-    def __init__(self):
-        super(TextEntitiesParser, self).__init__()
+    def __init__(self, **kwargs):
+        super(TextEntitiesParser, self).__init__(**kwargs)
 
     @staticmethod
     def __process_word(word):
5 changes: 2 additions & 3 deletions arekit/contrib/utils/pipelines/items/text/frames.py
@@ -6,11 +6,10 @@

 class FrameVariantsParser(BasePipelineItem):
 
-    def __init__(self, frame_variants):
+    def __init__(self, frame_variants, **kwargs):
         assert(isinstance(frame_variants, FrameVariantsCollection))
         assert(len(frame_variants) > 0)
-
-        super(FrameVariantsParser, self).__init__()
+        super(FrameVariantsParser, self).__init__(**kwargs)
 
         self.__frame_variants = frame_variants
         self.__max_variant_len = max([len(variant) for _, variant in frame_variants.iter_variants()])
@@ -5,10 +5,10 @@

 class LemmasBasedFrameVariantsParser(FrameVariantsParser):
 
-    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False):
+    def __init__(self, frame_variants, stemmer, locale_mods=RussianLanguageMods, save_lemmas=False, **kwargs):
         assert(isinstance(stemmer, Stemmer))
         assert(isinstance(save_lemmas, bool))
-        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants)
+        super(LemmasBasedFrameVariantsParser, self).__init__(frame_variants=frame_variants, **kwargs)
 
         self.__frame_variants = frame_variants
         self.__stemmer = stemmer
(Diffs for the remaining changed files did not load and are not shown here.)
