From 77ab9bdc10955eced70b2f60f87486c8caa21330 Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Mon, 18 Sep 2023 18:51:57 -0500
Subject: [PATCH] :coffin: Remove breakpoints and improve comments

Signed-off-by: gkumbhat
---
 .../text_generation/peft_prompt_tuning.py     | 27 +++++++++++--------
 caikit_nlp/resources/pretrained_model/base.py |  3 +--
 examples/run_peft_tuning.py                   |  2 +-
 3 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
index a4a5700e..61d131b3 100644
--- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
+++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -439,7 +439,6 @@ def train(
             log.debug("Bootstrapping base resource [%s]", base_model)
             base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
-        breakpoint()
         error.type_check("", PretrainedModelBase, base_model=base_model)
 
         # Validate if tuned output model type is compatible with base model or not
@@ -545,7 +544,7 @@ def train(
         # transformers model) to the right underlying type.
         device = cls._get_device(device)
         # cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
-        breakpoint()
+
         cls._execute_train_loop(
             peft_model,
             num_epochs,
@@ -557,6 +556,7 @@ def train(
             tokenizer=base_model.tokenizer,
             accumulate_steps=accumulate_steps,
             silence_progress_bars=silence_progress_bars,
+            torch_dtype=torch_dtype,
         )
 
         # Get config of the base model
@@ -1096,6 +1096,7 @@ def _execute_train_loop(
         tokenizer: Union[AutoTokenizer, None] = None,
         accumulate_steps: int = 1,
         silence_progress_bars: bool = True,
+        torch_dtype: "torch.dtype" = torch.float32,
     ) -> None:
         """Execute the core training logic for training the prompt vectors on the
         frozen model. Note that this is done by reference.
@@ -1124,6 +1125,8 @@ def _execute_train_loop(
                 Number of steps to use for gradient accumulation. Default: 1.
             silence_progress_bars: bool
                 Silences TQDM progress bars. Default: True
+            torch_dtype: torch.dtype
+                Dtype to be used for training. Default: torch.float32
         """
         optimizer = AdamW(params=model.parameters(), lr=learning_rate)
         lr_scheduler = get_linear_schedule_with_warmup(
@@ -1135,17 +1138,24 @@ def _execute_train_loop(
         # Enable gradient checkpointing
         model.gradient_checkpointing_enable()
 
+        if torch_dtype == torch.float16:
+            mixed_precision = "fp16"
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_bf16_supported():
+            mixed_precision = "bf16"
+        else:
+            mixed_precision = "no"
+
         accelerator = Accelerator(
             gradient_accumulation_steps=accumulate_steps,
             device_placement=True,
-            mixed_precision='bf16',
+            mixed_precision=mixed_precision,
         )
 
         # Disable cache for training
         model.config.use_cache=False
 
-        # model.to(accelerator.device, torch.bfloat16)
-        breakpoint()
+        # The prepare call below sends the model and data to the
+        # configured device and converts them to the required dtypes.
         model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             model, optimizer, train_dataloader, lr_scheduler,
         )
@@ -1159,13 +1169,8 @@ def _execute_train_loop(
                 tqdm_loader.set_description("Epoch: {}".format(epoch))
 
                 # TODO Can this dict comprehension always replace "batch.to(device)" for us?
-                # TODO: Try catching the error using torch.cuda.OutOfMemoryError
-                # batch = {k: v.to(device) for k, v in batch.items()}
-                breakpoint()
                 try:
-                    with accelerator.accumulate(model):
-                        breakpoint()
                         outputs = model(**batch)
                         loss = outputs.loss
                         total_loss += loss.detach().float()
                         accelerator.backward(loss)
                         optimizer.step()
                         lr_scheduler.step()
                         optimizer.zero_grad()
-                except torch.cuda.OutOfMemoryError as ex:
+                except torch.cuda.OutOfMemoryError:
                     error("", MemoryError("Not enough memory available for training!"))
 
             log.info("epoch %s: %s", epoch, loss)
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index 1303ba14..3f6aa498 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -194,8 +194,7 @@ def bootstrap(
         model = cls.MODEL_TYPE.from_pretrained(
             model_name,
             local_files_only=not get_config().allow_downloads,
-            # FIXME: Seems like torch.dtype isn't working correctly
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             **kwargs,
         )
         log.debug4("Model Details: %s", model)
diff --git a/examples/run_peft_tuning.py b/examples/run_peft_tuning.py
index 88b5b7dd..9de9fc2e 100644
--- a/examples/run_peft_tuning.py
+++ b/examples/run_peft_tuning.py
@@ -407,7 +407,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
         verbalizer=dataset_info.verbalizer,
         silence_progress_bars=not args.verbose,
         accumulate_steps=args.accumulate_steps,
-        torch_dtype="float16"
+        torch_dtype="float16"
     )
     model.save(args.output_dir, save_base_model=not args.prompt_only)
     print_colored("[Training Complete]")
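
Note on the mixed-precision selection added to _execute_train_loop: the
dtype-to-flag mapping is small enough to check in isolation. Below is a
minimal, self-contained sketch of the same logic; the helper name
select_mixed_precision is hypothetical (the patch inlines the branches), and
the extra torch.cuda.is_available() guard is an assumption added so the
snippet also runs on CPU-only machines.

    import torch
    from accelerate import Accelerator

    def select_mixed_precision(torch_dtype: torch.dtype) -> str:
        """Map a requested torch dtype onto an Accelerate mixed-precision flag."""
        if torch_dtype == torch.float16:
            return "fp16"
        # bf16 is only requested when a CUDA device that supports it is present;
        # the is_available() check is an added guard for CPU-only machines.
        if (
            torch_dtype == torch.bfloat16
            and torch.cuda.is_available()
            and torch.cuda.is_bf16_supported()
        ):
            return "bf16"
        # Anything else (e.g. torch.float32) trains in full precision.
        return "no"

    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        device_placement=True,
        mixed_precision=select_mixed_precision(torch.bfloat16),
    )
    print(accelerator.mixed_precision)  # "bf16" on supported GPUs, "no" otherwise

Resolving the flag up front, rather than hard-coding 'bf16', keeps
full-precision training as the fallback whenever the requested dtype cannot
be honored on the current hardware.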
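Note on the base.py change: the hard-coded torch.bfloat16 (and the FIXME above
it) masked whatever torch_dtype the caller requested; forwarding the argument
lets from_pretrained honor it. A quick way to verify, assuming any Hugging
Face causal LM checkpoint (the model id below is only an example):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m",    # example checkpoint; any causal LM id works
        torch_dtype=torch.float16,  # now honored instead of being overridden
    )
    assert model.dtype == torch.float16

This also lines up with the example script, which now requests
torch_dtype="float16" when kicking off tuning.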