Skip to content

Commit

Permalink
#547 done
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 8, 2024
1 parent cc759ad commit c0d600b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 53 deletions.
48 changes: 6 additions & 42 deletions arekit/contrib/utils/io_utils/samples.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import logging
from os.path import join

from arekit.contrib.utils.data.readers.base import BaseReader
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.contrib.utils.data.writers.base import BaseWriter
from arekit.contrib.utils.io_utils.utils import check_targets_existence

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand All @@ -13,32 +11,25 @@
class SamplesIO(BaseSamplesIO):
""" Samples default IO utils for samples.
Sample is a text part which include pair of attitude participants.
This class allows to provide saver and loader for such entries, bubbed as samples.
This class allows to provide saver and loader for such entries, dubbed as samples.
Samples required for machine learning training/inferring.
"""

def __init__(self, target_dir, writer=None, reader=None, prefix="sample"):
assert(isinstance(target_dir, str))
assert(isinstance(prefix, str))
def __init__(self, create_target_func, writer=None, reader=None):
assert(isinstance(writer, BaseWriter) or writer is None)
assert(isinstance(reader, BaseReader) or reader is None)
self.__target_dir = target_dir
self.__prefix = prefix
assert(callable(create_target_func))

self.__writer = writer
self.__reader = reader
self.__create_target_func = create_target_func

self.__target_extension = None
if writer is not None:
self.__target_extension = writer.extension()
elif reader is not None:
self.__target_extension = reader.extension()

# region public methods

@property
def Prefix(self):
return self.__prefix

@property
def Reader(self):
return self.__reader
Expand All @@ -48,31 +39,4 @@ def Writer(self):
return self.__writer

def create_target(self, data_type):
return self.__get_input_sample_target(data_type)

def check_targets_existed(self, data_types_iter):
for data_type in data_types_iter:

targets = [
self.__get_input_sample_target(data_type=data_type),
]

if not check_targets_existence(targets=targets):
return False
return True

# endregion

def __get_input_sample_target(self, data_type):
return self.__get_filepath(out_dir=self.__target_dir,
template=f"{data_type.name.lower()}",
prefix=self.__prefix,
extension=self.__target_extension)

@staticmethod
def __get_filepath(out_dir, template, prefix, extension):
assert(isinstance(template, str))
assert(isinstance(prefix, str))
assert(isinstance(extension, str))
return join(out_dir, "{prefix}-{template}{extension}".format(
prefix=prefix, template=template, extension=extension))
return self.__create_target_func(data_type) + self.__target_extension
12 changes: 1 addition & 11 deletions arekit/contrib/utils/io_utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
from collections.abc import Iterable
import logging
from os.path import join, exists
from os.path import exists


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def join_dir_with_subfolder_name(subfolder_name, dir):
""" Returns subfolder in in directory
"""
assert(isinstance(subfolder_name, str))
assert(isinstance(dir, str))

target_dir = join(dir, "{}/".format(subfolder_name))
return target_dir


def check_targets_existence(targets):
assert (isinstance(targets, Iterable))

Expand Down

0 comments on commit c0d600b

Please sign in to comment.