⚰️ Remove breakpoints and improve comments
Signed-off-by: gkumbhat <[email protected]>
gkumbhat committed Sep 18, 2023
1 parent b11e4cb commit 77ab9bd
Showing 3 changed files with 18 additions and 14 deletions.
27 changes: 16 additions & 11 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -439,7 +439,6 @@ def train(
         log.debug("Bootstrapping base resource [%s]", base_model)
         base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
 
-        breakpoint()
         error.type_check("<NLP65714919E>", PretrainedModelBase, base_model=base_model)
 
         # Validate if tuned output model type is compatible with base model or not
@@ -545,7 +544,7 @@ def train(
         # transformers model) to the right underlying type.
         device = cls._get_device(device)
         # cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
-        breakpoint()
+
         cls._execute_train_loop(
             peft_model,
             num_epochs,
@@ -557,6 +556,7 @@ def train(
             tokenizer=base_model.tokenizer,
             accumulate_steps=accumulate_steps,
             silence_progress_bars=silence_progress_bars,
+            torch_dtype=torch_dtype,
         )
 
         # Get config of the base model
@@ -1096,6 +1096,7 @@ def _execute_train_loop(
         tokenizer: Union[AutoTokenizer, None] = None,
         accumulate_steps: int = 1,
         silence_progress_bars: bool = True,
+        torch_dtype: "torch.dtype" = torch.float32,
     ) -> None:
         """Execute the core training logic for training the prompt vectors on the frozen model.
         Note that this is done by reference.
@@ -1124,6 +1125,8 @@
                 Number of steps to use for gradient accumulation. Default: 1.
             silence_progress_bars: bool
                 Silences TQDM progress bars. Default: True
+            torch_dtype: torch.dtype
+                Dtype to be used for training. Default: torch.float32
         """
         optimizer = AdamW(params=model.parameters(), lr=learning_rate)
         lr_scheduler = get_linear_schedule_with_warmup(
@@ -1135,17 +1138,24 @@
         # Enable gradient checkpointing
         model.gradient_checkpointing_enable()
 
+        if torch_dtype == torch.float16:
+            mixed_precision = "fp16"
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_bf16_supported():
+            mixed_precision = "bf16"
+        else:
+            mixed_precision = "no"
+
         accelerator = Accelerator(
             gradient_accumulation_steps=accumulate_steps,
             device_placement=True,
-            mixed_precision='bf16',
+            mixed_precision=mixed_precision,
         )
 
         # Disable cache for training
         model.config.use_cache=False
 
-        # model.to(accelerator.device, torch.bfloat16)
-        breakpoint()
+        # Below would send all the data and model to
+        # configured device and convert them to required dtypes
         model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             model, optimizer, train_dataloader, lr_scheduler,
         )
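
For context, a minimal sketch (not part of the commit) of how the dtype-to-mixed-precision mapping added above feeds Hugging Face Accelerate; the helper name resolve_mixed_precision and the literal argument values are illustrative assumptions:

    import torch
    from accelerate import Accelerator

    def resolve_mixed_precision(torch_dtype: torch.dtype) -> str:
        # Mirrors the new logic above: fp16 for float16, bf16 only when the
        # CUDA device reports bf16 support, full precision otherwise.
        if torch_dtype == torch.float16:
            return "fp16"
        if torch_dtype == torch.bfloat16 and torch.cuda.is_bf16_supported():
            return "bf16"
        return "no"

    accelerator = Accelerator(
        gradient_accumulation_steps=4,  # accumulate_steps in the module
        device_placement=True,
        mixed_precision=resolve_mixed_precision(torch.bfloat16),
    )

Falling back to "no" keeps training runnable on hardware without bf16 support, which the previous hard-coded mixed_precision='bf16' did not.
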
@@ -1159,21 +1169,16 @@
                 tqdm_loader.set_description("Epoch: {}".format(epoch))
 
-                # TODO Can this dict comprehension always replace "batch.to(device)" for us?
-                # TODO: Try catching the error using torch.cuda.OutOfMemoryError
-                # batch = {k: v.to(device) for k, v in batch.items()}
-                breakpoint()
                 try:
 
                     with accelerator.accumulate(model):
-                        breakpoint()
                         outputs = model(**batch)
                         loss = outputs.loss
                         total_loss += loss.detach().float()
                         accelerator.backward(loss)
                         optimizer.step()
                         lr_scheduler.step()
                         optimizer.zero_grad()
-                except torch.cuda.OutOfMemoryError as ex:
+                except torch.cuda.OutOfMemoryError:
                     error("<NLP07175292E>", MemoryError("Not enough memory available for training!"))
 
             log.info("epoch %s: %s", epoch, loss)
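
torch.cuda.OutOfMemoryError (available since PyTorch 1.13) subclasses RuntimeError, so the narrowed except above catches only CUDA OOM and lets other runtime failures propagate; a small sketch of that behavior, with the simulated raise as an assumption:

    import torch

    # OutOfMemoryError is a RuntimeError subclass, so the narrow except in the
    # diff does not shadow other runtime failures in the training step.
    assert issubclass(torch.cuda.OutOfMemoryError, RuntimeError)

    try:
        # Simulated OOM; real ones come from CUDA allocations in forward/backward
        raise torch.cuda.OutOfMemoryError("CUDA out of memory (simulated)")
    except torch.cuda.OutOfMemoryError:
        # The diff surfaces this as a MemoryError via the caikit error handler
        print("caught CUDA OOM")
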
3 changes: 1 addition & 2 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -194,8 +194,7 @@ def bootstrap(
         model = cls.MODEL_TYPE.from_pretrained(
             model_name,
             local_files_only=not get_config().allow_downloads,
-            # FIXME: Seems like torch.dtype isn't working correctly
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             **kwargs,
         )
         log.debug4("Model Details: %s", model)
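
With the FIXME override gone, bootstrap now honors the caller's dtype; a minimal sketch of the resulting from_pretrained behavior, assuming a standard transformers checkpoint (the model name is an illustrative placeholder):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m",      # placeholder checkpoint, not from the diff
        torch_dtype=torch.float16,    # forwarded as-is instead of pinned to bfloat16
    )
    assert model.dtype == torch.float16
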
2 changes: 1 addition & 1 deletion examples/run_peft_tuning.py
Expand Up @@ -407,7 +407,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
verbalizer=dataset_info.verbalizer,
silence_progress_bars=not args.verbose,
accumulate_steps=args.accumulate_steps,
torch_dtype="bfloat16"
torch_dtype="float16"
)
model.save(args.output_dir, save_base_model=not args.prompt_only)
print_colored("[Training Complete]")
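
With the string dtype switched to "float16", the updated call site looks roughly like the sketch below; the import path is assumed, arguments elided in the diff stay elided, and train() is assumed to resolve string dtypes to torch dtypes:

    from caikit_nlp.modules.text_generation import PeftPromptTuning

    model = PeftPromptTuning.train(
        ...,  # base model, training stream, and tuning parameters elided as above
        verbalizer=dataset_info.verbalizer,
        silence_progress_bars=not args.verbose,
        accumulate_steps=args.accumulate_steps,
        torch_dtype="float16",  # previously "bfloat16"
    )
    model.save(args.output_dir, save_base_model=not args.prompt_only)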
