Merge pull request caikit#255 from gkumbhat/add_data_limitation
Add data limitation
gkumbhat committed Nov 2, 2023
2 parents e8d176e + 7834fae commit e512728
Showing 4 changed files with 128 additions and 2 deletions.
5 changes: 5 additions & 0 deletions caikit_nlp/config/config.yml
@@ -30,5 +30,10 @@ unload_tgis_prompt_artifacts: false
master_addr: localhost
master_port: 29550

training_data_limit:
  # Configuration for PeftPromptTuning module
  6655831b-960a-4dc5-8df4-867026e2cd41:
    add_model_name_here: 10000

runtime:
library: caikit_nlp
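
The UUID key is the MODULE_ID of the PeftPromptTuning module, and add_model_name_here is a placeholder: each entry beneath a module ID maps a base model name to the maximum number of training examples accepted for that model. Any module or model without an entry is unlimited, since the lookup in train() below falls back to -1.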
15 changes: 15 additions & 0 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -41,6 +41,7 @@
import transformers

# First Party
from caikit import get_config
from caikit.core.data_model import DataStream
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
@@ -366,6 +367,20 @@ def train(
            verbalizer,
        )

        # Check if data is within limit allowed for this module and model
        max_num_examples = (
            get_config()
            .training_data_limit.get(cls.MODULE_ID, {})
            .get(base_model_name, -1)
        )

        if max_num_examples > 0:
            error.value_check(
                "<NLP77627434E>",
                len(train_stream) <= max_num_examples,
                "Number of examples larger than maximum number of examples allowed for this model",
            )

        # Coerce the passed model into a resource; if we have one, this is a noop
        # TODO: When splitting up this mono-module, use the configured resource
        # type of the concrete class to bootstrap
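
The lookup degrades gracefully: if either the module ID or the base model name has no entry under training_data_limit, the limit resolves to -1 and the size check is skipped. A minimal sketch of that resolution logic, assuming a config shaped like the config.yml hunk above ("my-base-model" is a hypothetical model name used only for illustration):

# Sketch of the limit resolution above; the dict mirrors the
# training_data_limit section of config.yml.
training_data_limit = {
    "6655831b-960a-4dc5-8df4-867026e2cd41": {"my-base-model": 10000},
}

def resolve_limit(module_id: str, base_model_name: str) -> int:
    # A missing module ID or model name falls back to -1, which disables the check.
    return training_data_limit.get(module_id, {}).get(base_model_name, -1)

assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "my-base-model") == 10000
assert resolve_limit("6655831b-960a-4dc5-8df4-867026e2cd41", "other-model") == -1
assert resolve_limit("unknown-module-id", "my-base-model") == -1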
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -14,7 +14,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.22.0,<0.23.0",
"caikit[runtime-grpc,runtime-http]>=0.23.2,<0.25.0",
"caikit-tgis-backend>=0.1.17,<0.2.0",
# TODO: loosen dependencies
"accelerate>=0.22.0",
@@ -32,7 +32,7 @@ dependencies = [
# which broke caikit-nlp build. peft hasn't released a newer version yet, so to get
# the build fix, we're pulling peft from a main branch commit. In future, we will pull PEFT from
# pypi
"peft@git+https://github.com/huggingface/peft.git#8c17d556a8fe9522e10d73d7bd3fad46a6ecae14"
"peft@git+https://github.com/huggingface/peft.git@8c17d556a8fe9522e10d73d7bd3fad46a6ecae14"
]

[tool.setuptools.packages.find]
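
The peft change swaps the URL fragment marker # for @: in a pip VCS requirement, @<commit> after the repository URL pins the install to that revision, while a #... fragment is not interpreted as a revision, so the previous spelling never actually pinned the commit.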
106 changes: 106 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -33,6 +33,7 @@
    seq2seq_lm_dummy_model,
    seq2seq_lm_train_kwargs,
    set_cpu_device,
    temp_config,
)
import caikit_nlp

@@ -399,3 +400,108 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model):
        exponential_decay_length_penalty=penalty,
    )
    assert isinstance(pred, GeneratedTextResult)


def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device):
    """Check that an error is raised when the number of examples exceeds the configured limit"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                ),
                ClassificationTrainRecord(
                    text="@bar this is the worst idea ever.", labels=["complaint"]
                ),
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    with temp_config(training_data_limit={module.MODULE_ID: {model_name: 1}}):
        with pytest.raises(ValueError):
            module.train(**causal_lm_train_kwargs)
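
With the limit set to 1 and two ClassificationTrainRecords in the stream, len(train_stream) exceeds max_num_examples, so error.value_check raises the expected ValueError.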


def test_train_with_data_validation_success(causal_lm_train_kwargs, set_cpu_device):
    """Check that training succeeds when the training data is within the configured limit"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                ),
                ClassificationTrainRecord(
                    text="@bar this is the worst idea ever.", labels=["complaint"]
                ),
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    with temp_config(training_data_limit={module.MODULE_ID: {model_name: 2}}):
        model = module.train(**causal_lm_train_kwargs)
        assert model


def test_train_with_non_existent_limit_success(causal_lm_train_kwargs, set_cpu_device):
    """Check that training succeeds when no training data limit exists for the particular model"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    with temp_config(training_data_limit={module.MODULE_ID: {"foo": 2}}):
        model = module.train(**causal_lm_train_kwargs)
        assert model


def test_train_with_no_limit_for_module(causal_lm_train_kwargs, set_cpu_device):
    """Check that training succeeds when no training data limit exists for the prompt tuning module"""
    patch_kwargs = {
        "num_epochs": 1,
        "verbalizer": "Tweet text : {{input}} Label : ",
        "train_stream": caikit.core.data_model.DataStream.from_iterable(
            [
                ClassificationTrainRecord(
                    text="@foo what a cute dog!", labels=["no complaint"]
                )
            ]
        ),
        "torch_dtype": torch.bfloat16,
        "device": "cpu",
    }
    causal_lm_train_kwargs.update(patch_kwargs)

    model_name = causal_lm_train_kwargs["base_model"]._model_name
    module = caikit_nlp.modules.text_generation.PeftPromptTuning
    with temp_config(training_data_limit={}):
        model = module.train(**causal_lm_train_kwargs)
        assert model
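
temp_config, imported from the shared test fixtures, appears to scope each training_data_limit override to its with block, so the global caikit configuration is untouched across tests. The last two tests exercise the -1 fallback: a limit keyed to a different model name ("foo"), or an empty training_data_limit, leaves training unrestricted.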
