diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml
index 8b64959e..92a72d9b 100644
--- a/caikit_nlp/config/config.yml
+++ b/caikit_nlp/config/config.yml
@@ -31,6 +31,7 @@
 master_addr: localhost
 master_port: 29550
 training_data_limit:
+  __default__: -1
   # Configuration for PeftPromptTuning module
   6655831b-960a-4dc5-8df4-867026e2cd41:
     add_model_name_here: 10000
diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index c10a58eb..307ff6cc 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -41,7 +41,6 @@
 import transformers
 
 # First Party
-from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -73,6 +72,7 @@
     generate_text_func,
     generate_text_func_stream,
 )
+from ...toolkit.trainer_utils import validate_training_data
 from ...toolkit.verbalizer_utils import render_verbalizer
 from .peft_config import TuningType, get_peft_config, resolve_base_model
 
@@ -368,19 +368,12 @@ def train(
         )
 
         # Check if data is within limit allowed for this module and model
-        max_num_examples = (
-            get_config()
-            .training_data_limit.get(cls.MODULE_ID, {})
-            .get(base_model_name, -1)
+        validate_training_data(
+            train_stream,
+            base_model_name,
+            cls.MODULE_ID,
         )
-        if max_num_examples > 0:
-            error.value_check(
-                "",
-                len(train_stream) <= max_num_examples,
-                "Number of examples larger than maximum number of examples allowed for this model",
-            )
-
         # Coerce the passed model into a resource; if we have one, this is a noop
         # TODO: When splitting up this mono-module, use the configured resource
         # type of the concrete class to bootstrap
diff --git a/caikit_nlp/toolkit/trainer_utils.py b/caikit_nlp/toolkit/trainer_utils.py
index 736e9d90..3ce41d58 100644
--- a/caikit_nlp/toolkit/trainer_utils.py
+++ b/caikit_nlp/toolkit/trainer_utils.py
@@ -19,9 +19,36 @@
 import torch
 
 # First Party
+from caikit import get_config
+from caikit.core.data_model import DataStream
+from caikit.core.exceptions import error_handler
 import alog
 
 log = alog.use_channel("TRNR_UTILS")
+error = error_handler.get(log)
+
+
+def validate_training_data(train_stream: DataStream, model_name: str, module_id: str):
+
+    global_default = get_config().training_data_limit.__default__
+    module_default = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get("__default__", global_default)
+    )
+
+    max_num_examples = (
+        get_config()
+        .training_data_limit.get(module_id, {})
+        .get(model_name, module_default)
+    )
+
+    if max_num_examples > 0:
+        error.value_check(
+            "",
+            len(train_stream) <= max_num_examples,
+            "Number of examples larger than maximum number of examples allowed for this model",
+        )
 
 
 def log_step(state, logs):
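Note for reviewers: the lookup order implemented by validate_training_data above is model-specific limit first, then the module's __default__, then the global __default__; a non-positive result disables the check. A minimal standalone sketch of that resolution using plain dicts in place of caikit's config object (resolve_limit and the sample values are illustrative, not part of this change):

    # Illustrative mirror of the lookup in validate_training_data
    def resolve_limit(training_data_limit: dict, module_id: str, model_name: str) -> int:
        global_default = training_data_limit.get("__default__", -1)
        module_config = training_data_limit.get(module_id, {})
        module_default = module_config.get("__default__", global_default)
        return module_config.get(model_name, module_default)

    limits = {"__default__": -1, "mod-id": {"__default__": 50, "model-a": 200}}
    assert resolve_limit(limits, "mod-id", "model-a") == 200   # model-specific wins
    assert resolve_limit(limits, "mod-id", "model-b") == 50    # module default applies
    assert resolve_limit(limits, "other-id", "model-a") == -1  # global default: no limit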
diff --git a/pyproject.toml b/pyproject.toml
index 347ae392..e99ce0b0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http]>=0.23.2,<0.25.0",
+    "caikit[runtime-grpc,runtime-http]>=0.24.0,<0.25.0",
     "caikit-tgis-backend>=0.1.17,<0.2.0",
     # TODO: loosen dependencies
     "accelerate>=0.22.0",
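The tests below exercise both layers of the new defaulting through the temp_config fixture already used in this file. For reference, a combined training_data_limit covering all three levels at once would look like the following (the limits and model name are hypothetical; the UUID is PeftPromptTuning's MODULE_ID from the config.yml change above):

    # Hypothetical combined configuration, shown as the dict temp_config would receive
    training_data_limit = {
        "__default__": 10000,                      # global fallback for any module
        "6655831b-960a-4dc5-8df4-867026e2cd41": {  # PeftPromptTuning module
            "__default__": 1000,                   # fallback for any model in this module
            "some-base-model": 20000,              # model-specific override (hypothetical)
        },
    }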
diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py
index 23da08e0..2ade5a88 100644
--- a/tests/modules/text_generation/test_peft_prompt_tuning.py
+++ b/tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -505,3 +505,135 @@ def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
 
     model = module.train(**causal_lm_train_kwargs)
     assert model
+
+
+def test_train_module_level_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises with a module-level default limit when the
+    training data exceeds it and no limit is set for the model under test
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, "foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_module_level_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that training succeeds when a model-specific limit covers the
+    training data, overriding the stricter module-level default
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={module.MODULE_ID: {"__default__": 1, model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
+
+
+def test_train_global_default_data_validation_raises(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that train raises via the global default limit when neither a
+    model-specific nor a module-level default limit applies
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {"foo": 2}}
+    ):
+        with pytest.raises(ValueError):
+            module.train(**causal_lm_train_kwargs)
+
+
+def test_train_global_default_data_validation_success(
+    causal_lm_train_kwargs, set_cpu_device
+):
+    """Check that training succeeds when a model-specific limit covers the
+    training data, overriding the stricter global default
+    """
+    patch_kwargs = {
+        "num_epochs": 1,
+        "verbalizer": "Tweet text : {{input}} Label : ",
+        "train_stream": caikit.core.data_model.DataStream.from_iterable(
+            [
+                ClassificationTrainRecord(
+                    text="@foo what a cute dog!", labels=["no complaint"]
+                ),
+                ClassificationTrainRecord(
+                    text="@bar this is the worst idea ever.", labels=["complaint"]
+                ),
+            ]
+        ),
+        "torch_dtype": torch.bfloat16,
+        "device": "cpu",
+    }
+    causal_lm_train_kwargs.update(patch_kwargs)
+
+    model_name = causal_lm_train_kwargs["base_model"]._model_name
+    module = caikit_nlp.modules.text_generation.PeftPromptTuning
+    with temp_config(
+        training_data_limit={"__default__": 1, module.MODULE_ID: {model_name: 2}}
+    ):
+
+        model = module.train(**causal_lm_train_kwargs)
+        assert model
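One case implied by the max_num_examples > 0 guard but not covered by the new tests: a non-positive limit, such as the shipped global default of -1, disables validation entirely. A sketch of such a test in the same style as the ones above (the test name and shape are a suggestion, not part of this diff):

    def test_train_global_default_no_limit(causal_lm_train_kwargs, set_cpu_device):
        """Hypothetical: a global default of -1 disables the data limit check."""
        module = caikit_nlp.modules.text_generation.PeftPromptTuning
        with temp_config(training_data_limit={"__default__": -1}):
            model = module.train(**causal_lm_train_kwargs)
            assert model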