🥅 Add exception handling for OOM PT
Signed-off-by: gkumbhat <[email protected]>
gkumbhat committed Sep 18, 2023
1 parent 5af73d1 commit 50bf6d9
Showing 1 changed file with 25 additions and 12 deletions.
caikit_nlp/modules/text_generation/peft_prompt_tuning.py (37 changes: 25 additions & 12 deletions)
@@ -535,10 +535,16 @@ def train(
         base_model.model.config.d_model = 1024
 
         peft_model = get_peft_model(base_model.model, peft_config)
 
+        # FIXME:
+        del base_model._model
+        torch.cuda.empty_cache()
+
+
         # Convert our Peft model (not just the underlying
         # transformers model) to the right underlying type.
         device = cls._get_device(device)
-        cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
+        # cls.convert_peft_model_to_type(device, peft_model, torch_dtype)
+        breakpoint()
         cls._execute_train_loop(
             peft_model,
@@ -1131,6 +1137,10 @@ def _execute_train_loop(
             device_placement=True,
             mixed_precision='bf16',
         )
+
+        # Disable cache for training
+        model.config.use_cache=False
+
         # model.to(accelerator.device, torch.bfloat16)
         breakpoint()
         model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -1143,22 +1153,25 @@ def _execute_train_loop(
             tqdm_loader = tqdm(train_dataloader, disable=silence_progress_bars)
             for batch in tqdm_loader:
 
                 optimizer.zero_grad()
 
                 tqdm_loader.set_description("Epoch: {}".format(epoch))
 
                 # TODO Can this dict comprehension always replace "batch.to(device)" for us?
+                # TODO: Try catching the error using torch.cuda.OutOfMemoryError
                 # batch = {k: v.to(device) for k, v in batch.items()}
                 breakpoint()
-                with accelerator.accumulate(model):
-                    breakpoint()
-                    outputs = model(**batch)
-                    loss = outputs.loss
-                    total_loss += loss.detach().float()
-                    accelerator.backward(loss)
-                    optimizer.step()
-                    lr_scheduler.step()
-                    optimizer.zero_grad()
+                try:
+
+                    with accelerator.accumulate(model):
+                        breakpoint()
+                        outputs = model(**batch)
+                        loss = outputs.loss
+                        total_loss += loss.detach().float()
+                        accelerator.backward(loss)
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad()
+                except torch.cuda.OutOfMemoryError as ex:
+                    error("<NLP07175292E>", MemoryError("Not enough memory available for training!"))
 
             log.info("epoch %s: %s", epoch, loss)
             if eval_dataloader is not None:
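Two notes on the patterns in this change, with minimal sketches that are illustrative only and not part of caikit-nlp.

First, the hunk in train() drops the extra handle on the base transformer once get_peft_model() has wrapped it, then calls torch.cuda.empty_cache() so that cached but unused blocks are returned to the CUDA driver before tuning starts. A rough sketch of that idea for a generic transformers model (the wrap_for_prompt_tuning helper and its arguments are hypothetical, not this module's API):

```python
import gc

import torch
from peft import get_peft_model


def wrap_for_prompt_tuning(base_model, peft_config):
    """Wrap a transformers model for prompt tuning and free the spare reference."""
    peft_model = get_peft_model(base_model, peft_config)

    # The PEFT wrapper keeps its own reference to the underlying transformer,
    # so the local name can be dropped and reclaimed by the garbage collector.
    del base_model
    gc.collect()

    # empty_cache() only returns cached, unused blocks to the CUDA driver;
    # memory still held by live tensors (such as the PEFT model) stays allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return peft_model
```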

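Second, the core change wraps the forward/backward step in a try/except so that a torch.cuda.OutOfMemoryError (available in PyTorch 1.13 and later) surfaces as an actionable MemoryError instead of a raw CUDA traceback. A rough sketch of that shape, using plain exceptions in place of caikit's error()/log helpers and an illustrative train_one_epoch signature:

```python
import torch


def train_one_epoch(model, train_dataloader, optimizer, lr_scheduler, accelerator):
    """Run one epoch, converting a CUDA OOM into a clearer MemoryError."""
    # Generation-time KV caching is not needed during training and adds memory pressure.
    model.config.use_cache = False

    total_loss = 0.0
    for batch in train_dataloader:
        try:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        except torch.cuda.OutOfMemoryError as err:
            # Chaining with `from err` keeps the original CUDA traceback attached.
            raise MemoryError(
                "Not enough GPU memory available for training; "
                "try a smaller batch size or shorter sequences."
            ) from err
    return total_loss
```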