Skip to content

Commit

Permalink
✅ Fix unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: gkumbhat <[email protected]>
  • Loading branch information
gkumbhat authored and alex-jw-brooks committed Sep 12, 2023
1 parent 4bc9a16 commit 729f66a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
25 changes: 18 additions & 7 deletions caikit_nlp/modules/text_generation/text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,17 @@ def train(
"weight_decay": 0.01,
"save_total_limit": 3,
"push_to_hub": False,
"no_cuda": False, # Default
"no_cuda": not torch.cuda.is_available(), # Default
"remove_unused_columns": True,
"dataloader_pin_memory": False,
"gradient_accumulation_steps": accumulate_steps,
"eval_accumulation_steps": accumulate_steps,

"gradient_checkpointing": True,
"full_determinism": True,
# NOTE: This is explicitly set to false since it will
# negatively impact the performance
"full_determinism": False,
# Required for iterable dataset
"max_steps": 50,
"max_steps": 1,
# Some interesting parameters:
"auto_find_batch_size": True,
# NOTE: following can override above arguments in order
Expand All @@ -366,9 +368,15 @@ def train(
get_config().master_port,
)

torch.distributed.launcher.api.elastic_launch(
launch_config, cls._launch_training
)(base_model, training_dataset, training_args, checkpoint_dir)
if torch.cuda.is_available():
# NOTE: torch distributed can hang when run on CPUs.
# To avoid that, especially in unit tests, we only
# run the launcher below when GPUs are available
torch.distributed.launcher.api.elastic_launch(
launch_config, cls._launch_training
)(base_model, training_dataset, training_args, checkpoint_dir)
else:
cls._launch_training(base_model, training_dataset, training_args, checkpoint_dir)

# In case this program is started via torchrun, the code below might not work as is
# because, in the case of multiple devices, this whole program gets run
Expand Down Expand Up @@ -568,6 +576,9 @@ def _launch_training(
trainer.save_state()
trainer.save_model(checkpoint_dir)

# save tokenizer explicitly
base_model.tokenizer.save_pretrained(checkpoint_dir)


def get(train_stream):
for data in train_stream:
Expand Down
11 changes: 9 additions & 2 deletions caikit_nlp/resources/pretrained_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,15 @@ def bootstrap(
padding_side=padding_side,
use_fast=False,
)
# Set up the pad token if needed; note that this will mutate
# the tokenizer that is passed as an argument if one is provided.

# Load the tokenizer and set up the pad token if needed
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
local_files_only=not get_config().allow_downloads,
padding_side=padding_side,
# We can't disable use_fast; otherwise the unit tests fail
# use_fast=False,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

Expand Down
3 changes: 2 additions & 1 deletion tests/modules/text_generation/test_text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SEQ2SEQ_LM_MODEL,
StubTGISBackend,
StubTGISClient,
set_cpu_device,
)

SAMPLE_TEXT = "Hello stub"
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_save_model_can_run():


@pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
def test_local_train_load_tgis():
def test_local_train_load_tgis(set_cpu_device):
"""Check if the model trained in local module is able to
be loaded in TGIS module / backend
"""
Expand Down

0 comments on commit 729f66a

Please sign in to comment.