Skip to content

Commit

Permalink
🎨 Fix formatting
Browse files Browse the repository at this point in the history
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat committed Nov 1, 2023
1 parent 4cd5d61 commit c7a0447
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
8 changes: 6 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,13 +368,17 @@ def train(
)

# Check if data is within limit allowed for this module and model
max_num_examples = get_config().training_data_limit.get(cls.MODULE_ID, {}).get(base_model_name, -1)
max_num_examples = (
get_config()
.training_data_limit.get(cls.MODULE_ID, {})
.get(base_model_name, -1)
)

if max_num_examples > 0:
error.value_check(
"<NLP77627434E>",
len(train_stream) <= max_num_examples,
"Number of examples larger than maximum number of examples allowed for this model"
"Number of examples larger than maximum number of examples allowed for this model",
)

# Coerce the passed model into a resource; if we have one, this is a noop
Expand Down
20 changes: 7 additions & 13 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,8 @@ def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_devic
module = caikit_nlp.modules.text_generation.PeftPromptTuning
with temp_config(training_data_limit={module.MODULE_ID: {model_name: 1}}):
with pytest.raises(ValueError):
module.train(
**causal_lm_train_kwargs
)
module.train(**causal_lm_train_kwargs)


def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_device):
"""Check if we are able to train successfully if training data is within limits"""
Expand All @@ -454,11 +453,10 @@ def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_devi
module = caikit_nlp.modules.text_generation.PeftPromptTuning
with temp_config(training_data_limit={module.MODULE_ID: {model_name: 2}}):

model = module.train(
**causal_lm_train_kwargs
)
model = module.train(**causal_lm_train_kwargs)
assert model


def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_device):
"""Check if we are able to train successfully if training data limit doesn't exist for particular model"""
patch_kwargs = {
Expand All @@ -480,9 +478,7 @@ def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_d
module = caikit_nlp.modules.text_generation.PeftPromptTuning
with temp_config(training_data_limit={module.MODULE_ID: {"foo": 2}}):

model = module.train(
**causal_lm_train_kwargs
)
model = module.train(**causal_lm_train_kwargs)
assert model


Expand All @@ -507,7 +503,5 @@ def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
module = caikit_nlp.modules.text_generation.PeftPromptTuning
with temp_config(training_data_limit={}):

model = module.train(
**causal_lm_train_kwargs
)
assert model
model = module.train(**causal_lm_train_kwargs)
assert model

0 comments on commit c7a0447

Please sign in to comment.