Skip to content

Commit

Permalink
✨ Add training data limit check for prompt tuning
Browse files Browse the repository at this point in the history
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat committed Nov 1, 2023
1 parent 5edaa00 commit 23f455a
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import transformers

# First Party
from caikit import get_config
from caikit.core.data_model import DataStream
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
Expand Down Expand Up @@ -353,7 +354,6 @@ def train(

# HACK - These things can't be passed through the train API currently

breakpoint()
metric = kwargs.get("metric")

base_model = resolve_base_model(base_model, cls, torch_dtype)
Expand All @@ -367,6 +367,16 @@ def train(
verbalizer,
)

# Check if data is within limit allowed for this module and model
max_num_examples = get_config().training_data_limit.get(cls.MODULE_ID, {}).get(base_model_name, -1)

if max_num_examples > 0:
error.value_check(
"<NLP77627434E>",
len(train_stream) <= max_num_examples,
"Number of examples larger than maximum number of examples allowed for this model"
)

# Coerce the passed model into a resource; if we have one, this is a noop
# TODO: When splitting up this mono-module, use the configured resource
# type of the concrete class to bootstrap
Expand Down

0 comments on commit 23f455a

Please sign in to comment.