PyTorch implementation of DiffRoll, a diffusion-based generative automatic music transcription (AMT) model

sony/DiffRoll

Installation

This repo was developed with python==3.8.10, so python>=3.8.10 is recommended.

To install all dependencies:

pip install -r requirements.txt
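Because the repo targets python>=3.8.10, a quick interpreter check before installing can catch a version mismatch early. This is a minimal sketch: `python_is_supported` is a hypothetical helper, not part of the repo, and it only compares `sys.version_info` against the version stated above.

```python
import sys

# Version the README says the repo was developed with.
DEVELOPED_WITH = (3, 8, 10)


def python_is_supported(version_info=None) -> bool:
    """Hypothetical helper: True if the interpreter is at least Python 3.8.10."""
    if version_info is None:
        version_info = sys.version_info
    # Compare only (major, minor, micro) as a plain tuple
    return tuple(version_info[:3]) >= DEVELOPED_WITH


if __name__ == "__main__":
    found = ".".join(map(str, sys.version_info[:3]))
    status = "OK" if python_is_supported() else "older than recommended"
    print(f"Python {found}: {status}")
```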

Training

Supervised training

python train_spec_roll.py gpus=[0] model.args.kernel_size=9 model.args.spec_dropout=0.1 dataset=MAESTRO dataloader.train.num_workers=4 epochs=2500 download=True
  • gpus sets which GPU(s) to use: gpus=[k] means device='cuda:k', while gpus=2 means DistributedDataParallel (DDP) is used with two GPUs.
  • model.args.kernel_size sets the kernel size for the ResNet layers in DiffRoll. model.args.kernel_size=9 performed best in our experiments.
  • model.args.spec_dropout sets the spectrogram dropout rate ($p$ in the paper).
  • dataset sets the training dataset: MAESTRO or MAPS.
  • dataloader.train.num_workers sets the number of workers for the train loader.
  • download should be set to True the first time you run the script, so that the dataset is downloaded and set up automatically. Set it to False if you have already downloaded the dataset.
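To make the two forms of gpus concrete, here is a minimal sketch of how the values described above could be interpreted. The function `describe_gpus` is hypothetical, not the repo's code. It assumes that an integer gpus=n selects the first n device indices (PyTorch Lightning's convention for an integer GPU count) and that a single GPU needs no DDP; the README does not state either point explicitly.

```python
from typing import List, Union


def describe_gpus(gpus: Union[int, List[int]]) -> dict:
    """Hypothetical helper: interpret `gpus` as described in the README.

    gpus=[k] -> one device, 'cuda:k', no DDP
    gpus=n   -> n GPUs; the first n device indices are ASSUMED
                (Lightning-style convention), with DDP when n > 1
    """
    if isinstance(gpus, list):
        if len(gpus) != 1:
            raise ValueError("this sketch only covers the single-device form gpus=[k]")
        return {"devices": [f"cuda:{gpus[0]}"], "strategy": None}
    # bool is a subclass of int, so exclude it explicitly
    if isinstance(gpus, int) and not isinstance(gpus, bool) and gpus > 0:
        devices = [f"cuda:{i}" for i in range(gpus)]
        # ASSUMPTION: a single GPU does not need DistributedDataParallel
        return {"devices": devices, "strategy": "ddp" if gpus > 1 else None}
    raise ValueError(f"unsupported gpus value: {gpus!r}")


print(describe_gpus([0]))  # single GPU, selected by index
print(describe_gpus(2))    # two GPUs with DDP
```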

The checkpoints and training logs are available at outputs/YYYY-MM-DD/HH-MM-SS/.
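If you launch several runs, the newest run folder can be located by parsing the timestamped directory names. This is a sketch assuming the outputs/YYYY-MM-DD/HH-MM-SS/ layout described above; `latest_run_dir` is a hypothetical helper, not part of the repo.

```python
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple


def latest_run_dir(outputs: Path = Path("outputs")) -> Optional[Path]:
    """Return the newest outputs/YYYY-MM-DD/HH-MM-SS/ run directory, or None.

    Hypothetical helper; assumes the timestamped layout described in the README.
    """
    if not outputs.is_dir():
        return None
    runs: List[Tuple[datetime, Path]] = []
    for day in outputs.iterdir():
        if not day.is_dir():
            continue
        for run in day.iterdir():
            if not run.is_dir():
                continue
            try:
                stamp = datetime.strptime(f"{day.name} {run.name}", "%Y-%m-%d %H-%M-%S")
            except ValueError:
                continue  # skip folders that do not follow the timestamp layout
            runs.append((stamp, run))
    if not runs:
        return None
    # Pick the run with the latest parsed timestamp
    return max(runs, key=lambda r: r[0])[1]


if __name__ == "__main__":
    print(latest_run_dir())
```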

To monitor training progress with TensorBoard, use the command below:

tensorboard --logdir='./outputs'

Unsupervised pretraining

Step 1: Pretraining on MAESTRO using only piano rolls

python train_spec_roll.py gpus=[0] model.args.kernel_size=9 model.args.spec_dropout=1 dataset=MAESTRO dataloader.train.num_workers=4 epochs=2500
  • model.args.spec_dropout sets the spectrogram dropout rate ($p$ in the paper). Setting it to 1 means no spectrograms are used (all spectrograms are dropped to -1).
  • Other arguments are the same as in Supervised training.
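The effect of spec_dropout can be sketched as follows: each spectrogram in a batch is replaced by -1 values with probability $p$, so $p=1$ drops every spectrogram and the model is trained on piano rolls alone. This is an illustrative, pure-Python version that assumes the drop decision is made independently for each example; `drop_spectrograms` is hypothetical, and the actual implementation operates on PyTorch tensors and may differ in detail.

```python
import random
from typing import List, Optional


def drop_spectrograms(
    batch: List[List[float]], p: float, rng: Optional[random.Random] = None
) -> List[List[float]]:
    """Hypothetical sketch of spectrogram dropout with rate p.

    Each spectrogram (here a flat list of floats) is independently replaced by
    a same-length list of -1.0 with probability p (ASSUMED per-example choice).
    p=1 drops everything, matching the piano-roll-only pretraining step.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if rng is None:
        rng = random.Random()
    out = []
    for spec in batch:
        # random() is in [0, 1), so p=1 always drops and p=0 never drops
        if rng.random() < p:
            out.append([-1.0] * len(spec))
        else:
            out.append(list(spec))
    return out


batch = [[0.3, 0.7, 0.1], [0.9, 0.2, 0.4]]
print(drop_spectrograms(batch, p=1.0))  # every spectrogram replaced by -1
print(drop_spectrograms(batch, p=0.1, rng=random.Random(0)))
```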

The pretrained checkpoints are available at outputs/YYYY-MM-DD/HH-MM-SS/ClassifierFreeDiffRoll/version_1/checkpoints.
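To pick up a pretrained checkpoint for Step 2, the checkpoint directory can be listed programmatically. This sketch assumes the ClassifierFreeDiffRoll/version_1/checkpoints layout shown above and that checkpoint files use the .ckpt extension (PyTorch Lightning's default, an assumption here); `list_checkpoints` is a hypothetical helper.

```python
from pathlib import Path
from typing import List


def list_checkpoints(run_dir: Path, version: int = 1) -> List[Path]:
    """Hypothetical helper: checkpoint files of one pretraining run, newest first.

    Assumes the layout <run_dir>/ClassifierFreeDiffRoll/version_<version>/checkpoints
    and the '.ckpt' extension (PyTorch Lightning's default; an assumption here).
    """
    ckpt_dir = run_dir / "ClassifierFreeDiffRoll" / f"version_{version}" / "checkpoints"
    if not ckpt_dir.is_dir():
        return []
    files = [f for f in ckpt_dir.iterdir() if f.is_file() and f.suffix == ".ckpt"]
    # Sort by modification time rather than by filename
    return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    run = Path("outputs/YYYY-MM-DD/HH-MM-SS")  # replace with your actual run folder
    for ckpt in list_checkpoints(run):
        print(ckpt)
```

Sorting by modification time avoids relying on filename order, since names such as epoch=9 and epoch=10 (hypothetical) would not sort chronologically as plain strings.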

After this, choose one of the options below (2A, 2B, or 2C) to continue training.

Step 2

Choose one of the options below (A, B, or C).

Option A: pre-DiffRoll (p=0.1)