Upgrade to PretrainDataset and FinetuneDataset #73

Merged · 4 commits · Dec 5, 2023
Changes from 1 commit:
add finetuning tests and fix resume
loubbrad committed Dec 5, 2023
commit 9efce78cb5d3abd49c72877c500f9f69b593e7a9
aria/data/datasets.py (6 changes: 3 additions & 3 deletions)

@@ -643,7 +643,7 @@ def _build_epoch(_save_path, _midi_dataset):
             "as intended when building different epochs."
         )
         for idx in range(num_epochs):
-            logger.info(f"Building epoch {idx+1}/{num_epochs}...")
+            logger.info(f"Building epoch {idx}/{num_epochs - 1}...")
             _build_epoch(
                 _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"),
                 _midi_dataset=midi_dataset,
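
The log message is now zero-indexed so that it matches the epoch{idx}.jsonl file written in the same iteration. A minimal sketch of the resulting naming, using the purely illustrative value num_epochs = 3:

# Illustration only: log indices and file names now share zero-based numbering
num_epochs = 3
for idx in range(num_epochs):
    print(f"Building epoch {idx}/{num_epochs - 1}...")  # 0/2, 1/2, 2/2
    print(f"epoch{idx}.jsonl")  # epoch0.jsonl, epoch1.jsonl, epoch2.jsonl
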
@@ -674,7 +674,7 @@ def __len__(self):

     # Do nothing in this case
     def init_epoch(self, idx: int | None = None):
-        self.logger.info(f"Successful initiated epoch {idx} - no changes")
+        self.logger.info(f"Successful initiated epoch {idx}")

     @classmethod
     def build(
@@ -740,7 +740,7 @@ def _build(_midi_dataset):
             }
         )
         logger.info(
-            f"Building tokenized dataset with config: "
+            f"Building FinetuningDataset with config: "
            f"tokenizer_name=tokenizer.name"
            f"max_seq_len={max_seq_len} "
            f"stride_len={stride_len}"

aria/train.py (90 changes: 50 additions & 40 deletions)

@@ -268,14 +268,15 @@ def get_finetune_dataloaders(
     )

     if apply_aug:
-        train_dataset.set_transform(
-            [
-                tokenizer.export_chord_mixup(),
-                tokenizer.export_velocity_aug(1),
-                tokenizer.export_pitch_aug(5),
-                tokenizer.export_tempo_aug(0.2),
-            ]
-        )
+        logging.error("Remember to reenable data aug")
+        # train_dataset.set_transform(
+        #     [
+        #         tokenizer.export_chord_mixup(),
+        #         tokenizer.export_velocity_aug(1),
+        #         tokenizer.export_pitch_aug(5),
+        #         tokenizer.export_tempo_aug(0.2),
+        #     ]
+        # )

     train_dataloader = DataLoader(
         train_dataset,
@@ -325,6 +326,7 @@ def _train(
     scheduler: torch.optim.lr_scheduler.LRScheduler = None,
     steps_per_checkpoint: int | None = None,
     resume_step: int | None = None,
+    resume_epoch: int | None = None,
 ):
     def profile_flops(dataloader: DataLoader):
         def _bench():
@@ -359,22 +361,15 @@ def _bench():
             f"{iters_per_second * total_flops / 1e12} TF/s (not warm)"
         )

-    def make_checkpoint(_accelerator, _epoch: int, _step: int | None = None):
-        if _step:
-            checkpoint_dir = os.path.join(
-                project_dir,
-                "checkpoints",
-                f"epoch{_epoch}_step{_step}",
-            )
-        else:
-            checkpoint_dir = os.path.join(
-                project_dir,
-                "checkpoints",
-                f"epoch{_epoch}",
-            )
+    def make_checkpoint(_accelerator, _epoch: int, _step: int):
+        checkpoint_dir = os.path.join(
+            project_dir,
+            "checkpoints",
+            f"epoch{_epoch}_step{_step}",
+        )

         logger.info(
-            f"EPOCH {_epoch}/{epochs}: Saving checkpoint - {checkpoint_dir}"
+            f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}"
         )
         _accelerator.save_state(checkpoint_dir)
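
With the step argument now required, every checkpoint directory follows the epoch{N}_step{S} pattern; the end-of-epoch checkpoints written by the main loop pass _step=0. A small sketch of the resulting layout, assuming the hypothetical values steps_per_checkpoint = 500 and a project_dir of experiments/run1:

import os

project_dir = "experiments/run1"  # hypothetical path
for epoch, step in [(0, 500), (0, 1000), (1, 0), (1, 500)]:
    print(os.path.join(project_dir, "checkpoints", f"epoch{epoch}_step{step}"))
# experiments/run1/checkpoints/epoch0_step500   <- mid-epoch checkpoint during epoch 0
# experiments/run1/checkpoints/epoch0_step1000
# experiments/run1/checkpoints/epoch1_step0     <- written after epoch 0 finishes
# experiments/run1/checkpoints/epoch1_step500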

@@ -393,13 +388,15 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
         lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])

         model.train()
-        for step, batch in (
+        for __step, batch in (
             pbar := tqdm(
                 enumerate(dataloader),
-                total=len(dataloader),
+                total=len(dataloader) + _resume_step,
+                initial=_resume_step,
                 leave=False,
             )
         ):
+            step = __step + _resume_step + 1
             src, tgt = batch  # (b_sz, s_len), (b_sz, s_len, v_sz)
             logits = model(src)  # (b_sz, s_len, v_sz)
             logits = logits.transpose(1, 2)  # Transpose for CrossEntropyLoss
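
Renaming the loop variable separates two counters during a resumed run: __step counts batches seen in this session (enumerate restarts at 0 after accelerator.skip_first_batches), while step is the global position used for logging, the loss CSV, and checkpoint naming. A worked example with the hypothetical value _resume_step = 100:

_resume_step = 100  # batches 0-99 were already consumed before the restart
for __step in range(3):  # what enumerate(dataloader) yields after skipping
    step = __step + _resume_step + 1
    print(__step, step)  # (0, 101), (1, 102), (2, 103)
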
@@ -410,17 +407,19 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
             if len(loss_buffer) > TRAILING_LOSS_STEPS:
                 loss_buffer.pop(0)
             trailing_loss = sum(loss_buffer) / len(loss_buffer)
-            avg_train_loss = rolling_average(avg_train_loss, loss.item(), step)
+            avg_train_loss = rolling_average(
+                avg_train_loss, loss.item(), __step
+            )

             # Logging
             logger.debug(
-                f"EPOCH {_epoch} STEP {_resume_step + step}: "
+                f"EPOCH {_epoch} STEP {step}: "
                 f"lr={lr_for_print}, "
                 f"loss={round(loss.item(), 4)}, "
                 f"trailing_loss={round(trailing_loss, 4)}, "
                 f"average_loss={round(avg_train_loss, 4)}"
             )
-            loss_writer.writerow([_epoch, _resume_step + step, loss.item()])
+            loss_writer.writerow([_epoch, step, loss.item()])
             pbar.set_postfix_str(
                 f"lr={lr_for_print}, "
                 f"loss={round(loss.item(), 4)}, "
@@ -436,15 +435,15 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
                 lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])

             if steps_per_checkpoint:
-                if step % steps_per_checkpoint == 0 and step != 0:
+                if step % steps_per_checkpoint == 0:
                     make_checkpoint(
                         _accelerator=accelerator,
-                        _epoch=epoch,
+                        _epoch=_epoch,
                         _step=step,
                     )

         logger.info(
-            f"EPOCH {_epoch}/{epochs}: Finished training - "
+            f"EPOCH {_epoch}/{epochs + start_epoch}: Finished training - "
             f"average_loss={round(avg_train_loss, 4)}"
         )

@@ -472,7 +471,7 @@ def val_loop(dataloader, _epoch: int):

         # EPOCH
         logger.info(
-            f"EPOCH {_epoch}/{epochs}: Finished evaluation - "
+            f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation - "
             f"average_loss={round(avg_val_loss, 4)}"
         )

@@ -497,32 +496,42 @@ def val_loop(dataloader, _epoch: int):
     epoch_writer = csv.writer(epoch_csv)
     epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"])

-    if resume_step:
+    if resume_epoch is not None:
+        start_epoch = resume_epoch + 1
+    else:
+        start_epoch = 0
+
+    if resume_step is not None:
+        assert resume_epoch is not None, "Must provide resume epoch"
         logger.info(
-            f"Resuming training from step {resume_step} - logging as EPOCH 0"
+            f"Resuming training from step {resume_step} - logging as EPOCH {resume_epoch}"
         )
         skipped_dataloader = accelerator.skip_first_batches(
             dataloader=train_dataloader,
             num_batches=resume_step,
         )

         avg_train_loss = train_loop(
             dataloader=skipped_dataloader,
-            _epoch=0,
+            _epoch=resume_epoch,
             _resume_step=resume_step,
         )
-        avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=0)
+        avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch)
         epoch_writer.writerow([0, avg_train_loss, avg_val_loss])
         epoch_csv.flush()
+        make_checkpoint(_accelerator=accelerator, _epoch=start_epoch, _step=0)

-    for epoch in range(0, epochs):
+    for epoch in range(start_epoch, epochs + start_epoch):
+        train_dataloader.dataset.init_epoch(epoch)
         avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch)
         avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch)
         epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss])
         epoch_csv.flush()
-        train_dataloader.dataset.init_epoch()
-        make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1)
+        make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0)

     loss_csv.close()
     epoch_csv.close()
     logging.shutdown()


 # NOTE: Any differences observed when resuming training are most likely the
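
A worked example of the new resume bookkeeping, using the purely illustrative values resume_epoch = 2, resume_step = 500, epochs = 4: the partially finished epoch 2 is completed first (its first 500 batches are skipped via accelerator.skip_first_batches and step numbering continues from 501), a checkpoint is written, and then the requested number of full epochs run starting at start_epoch:

resume_epoch, resume_step, epochs = 2, 500, 4  # hypothetical values
start_epoch = resume_epoch + 1  # 3
full_epochs = list(range(start_epoch, epochs + start_epoch))
print(full_epochs)  # [3, 4, 5, 6]
# Epoch 2 itself is re-entered at batch 500 before this range is trained
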
@@ -574,7 +583,7 @@ def resume_train(
         "Please insure that the training config and resume step are set "
         "correctly, the script does not currently check that this is the case. "
         "If the previous checkpoint was saved at step n, then resume_step "
-        "should be n+1. If there is a mismatch between the batch size then the "
+        "should be n. If there is a mismatch between the batch size then the "
         "script will resume at the wrong step."
     )
     logger.info(
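
Concretely (illustrative values only): if the most recent checkpoint on disk is checkpoints/epoch2_step500, the run should be resumed with -repoch 2 -rstep 500; the previous guidance of passing n+1 no longer matches the step arithmetic in train_loop, which already adds one when reconstructing the global step.
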
@@ -664,6 +673,7 @@ def resume_train(
         scheduler=scheduler,
         steps_per_checkpoint=steps_per_checkpoint,
         resume_step=resume_step,
+        resume_epoch=resume_epoch,
     )

@@ -834,7 +844,7 @@ def parse_resume_args():
     argp.add_argument("val_data", help="path to val data")
     argp.add_argument("-cdir", help="checkpoint dir", type=str, required=True)
     argp.add_argument("-rstep", help="resume step", type=int, required=True)
-    argp.add_argument("-repoch", help="resume step", type=int, required=True)
+    argp.add_argument("-repoch", help="resume epoch", type=int, required=True)
     argp.add_argument("-epochs", help="train epochs", type=int, required=True)
     argp.add_argument("-bs", help="batch size", type=int, default=32)
     argp.add_argument("-workers", help="number workers", type=int, default=1)