From 4bc9a168c87dc52607678ba8feaaaf572ac9483a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 15 Aug 2023 08:12:08 -0500 Subject: [PATCH] Handle substreams for fine tuning data prep Signed-off-by: Alex-Brooks --- .../text_generation/text_generation_local.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index cc5ef81f..d88ed761 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -514,8 +514,6 @@ def _preprocess_function( use_iterable_dataset: bool, ): """Pre-process each example to get it prepared for training.""" - if base_model.REQUIRES_TOKEN_UNWRAPPING: - raise NotImplementedError("Token unwrapping not implemented for fine tuning data prep") if use_iterable_dataset: # Generator based log.debug("Loading data as an iterable dataset") @@ -525,11 +523,18 @@ def _preprocess_function( else: # Convert the train stream to an normal dataset in memory log.debug("Loading data as a normal dataset") + # TODO: Optimize and clean this up! inputs = [] outputs = [] - for datum in train_stream: - inputs.append(datum.input) - outputs.append(datum.output) + if base_model.REQUIRES_TOKEN_UNWRAPPING: + for substream in train_stream: + for data in substream: + inputs.append(data.input) + outputs.append(data.output) + else: + for data in train_stream: + inputs.append(data.input) + outputs.append(data.output) dataset = Dataset.from_dict({"input": inputs, "output": outputs}) # Map our HF datasets; with our tokenizer functions mapped_dataset = dataset.map( @@ -566,4 +571,10 @@ def _launch_training( def get(train_stream): for data in train_stream: - yield {"input": data.input, "output": data.output} + # Handle token unwrapping for causal language modeling + if isinstance(data, DataStream): + for datum in data: + yield {"input": datum.input, "output": datum.output} + # Otherwise assume we directly yield dictionaries + else: + yield {"input": data.input, "output": data.output}