diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index eb100020..63145841 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -554,6 +554,8 @@ def _preprocess_function( "max_target_length": max_target_length, }, ) + if shuffle: + return mapped_dataset.shuffle() return mapped_dataset @staticmethod