diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index 5b18f250..afe2033d 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -41,6 +41,7 @@
 import transformers
 
 # First Party
+from caikit import get_config
 from caikit.core.data_model import DataStream
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -353,7 +354,6 @@ def train(
 
 
         # HACK - These things can't be passed through the train API currently
-        breakpoint()
         metric = kwargs.get("metric")
 
         base_model = resolve_base_model(base_model, cls, torch_dtype)
@@ -367,6 +367,20 @@ def train(
             verbalizer,
         )
 
+        # Check if data is within limit allowed for this module and model
+        max_num_examples = (
+            get_config()
+            .training_data_limit.get(cls.MODULE_ID, {})
+            .get(base_model_name, -1)
+        )
+
+        if max_num_examples > 0:
+            error.value_check(
+                "<NLP77627434E>",
+                len(train_stream) <= max_num_examples,
+                "Number of examples larger than maximum number of examples allowed for this model",
+            )
+
         # Coerce the passed model into a resource; if we have one, this is a noop
         # TODO: When splitting up this mono-module, use the configured resource
         # type of the concrete class to bootstrap