Skip to content

Commit

Permalink
✅ Add test for training data validation
Browse files Browse the repository at this point in the history
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat committed Nov 1, 2023
1 parent 23f455a commit 4cd5d61
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
seq2seq_lm_dummy_model,
seq2seq_lm_train_kwargs,
set_cpu_device,
temp_config,
)
import caikit_nlp

Expand Down Expand Up @@ -399,3 +400,114 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model):
exponential_decay_length_penalty=penalty,
)
assert isinstance(pred, GeneratedTextResult)


def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device):
    """Training must raise when the number of examples exceeds the configured limit."""
    train_stream = caikit.core.data_model.DataStream.from_iterable(
        [
            ClassificationTrainRecord(
                text="@foo what a cute dog!", labels=["no complaint"]
            ),
            ClassificationTrainRecord(
                text="@bar this is the worst idea ever.", labels=["complaint"]
            ),
        ]
    )
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": train_stream,
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    tuning_module = caikit_nlp.modules.text_generation.PeftPromptTuning
    base_model_name = causal_lm_train_kwargs["base_model"]._model_name
    # Configure a limit of 1 while the stream holds 2 records, so the
    # pre-training data validation is expected to reject the request.
    limit_config = {tuning_module.MODULE_ID: {base_model_name: 1}}
    with temp_config(training_data_limit=limit_config):
        with pytest.raises(ValueError):
            tuning_module.train(**causal_lm_train_kwargs)

def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_device):
    """Training must succeed when the training data size is within the configured limit."""
    train_stream = caikit.core.data_model.DataStream.from_iterable(
        [
            ClassificationTrainRecord(
                text="@foo what a cute dog!", labels=["no complaint"]
            ),
            ClassificationTrainRecord(
                text="@bar this is the worst idea ever.", labels=["complaint"]
            ),
        ]
    )
    causal_lm_train_kwargs.update(
        {
            "num_epochs": 1,
            "verbalizer": "Tweet text : {{input}} Label : ",
            "train_stream": train_stream,
            "torch_dtype": torch.bfloat16,
            "device": "cpu",
        }
    )

    tuning_module = caikit_nlp.modules.text_generation.PeftPromptTuning
    base_model_name = causal_lm_train_kwargs["base_model"]._model_name
    # Limit of 2 exactly matches the 2 records in the stream, so
    # validation should pass and training should produce a model.
    limit_config = {tuning_module.MODULE_ID: {base_model_name: 2}}
    with temp_config(training_data_limit=limit_config):
        trained_model = tuning_module.train(**causal_lm_train_kwargs)
    assert trained_model

def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_device):
    """Check if we are able to train successfully if training data limit doesn't exist for particular model.

    A limit is configured for this module, but keyed under a different model
    name ("foo"), so the single-record stream must not be rejected.
    """
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # The limit is deliberately keyed on "foo", which does not match the
    # base model under test, so no validation error should be raised.
    # (The previously-assigned `model_name` local was unused and removed.)
    with temp_config(training_data_limit={module.MODULE_ID: {"foo": 2}}):
        model = module.train(
            **causal_lm_train_kwargs
        )
        assert model


def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
    """Check if we are able to train successfully if no training data limit exists for the prompt tuning module.

    The training_data_limit config is empty, so the module has no limit at
    all and the single-record stream must train without validation errors.
    """
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    # Empty limit config: no entry for this module at all.
    # (The previously-assigned `model_name` local was unused and removed.)
    with temp_config(training_data_limit={}):
        model = module.train(
            **causal_lm_train_kwargs
        )
        assert model

0 comments on commit 4cd5d61

Please sign in to comment.