1 parent 02f50ff, commit caeb95d
Showing 62 changed files with 4,603 additions and 2 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,117 @@ | ||
# MusCaps: Generating Captions for Music Audio | ||
[Ilaria Manco](https://ilariamanco.com/)<sup>1</sup> <sup>2</sup>, | ||
[Emmanouil Benetos](http://www.eecs.qmul.ac.uk/~emmanouilb/)<sup>1</sup>, | |
[Elio Quinton](https://scholar.google.com/citations?user=IaciybgAAAAJ)<sup>2</sup>, | ||
[Gyorgy Fazekas](http://www.eecs.qmul.ac.uk/~gyorgyf/about.html)<sup>1</sup> <br> | |
<sup>1</sup> Queen Mary University of London, <sup>2</sup> Universal Music Group | ||
|
||
This repository hosts the official PyTorch implementation of the IJCNN 2021 paper "MusCaps: Generating Captions for Music Audio". | ||
<p align="center"> | ||
<img src="muscaps.png" width="500"> | ||
</p> | |
|
||
**Code coming soon** | ||
This repository is the official implementation of ["MusCaps: Generating Captions for Music Audio"](https://arxiv.org/abs/2104.11984) (IJCNN 2021). In this work, we propose an encoder-decoder model to generate natural language descriptions of music audio. We provide code to train our model on any dataset of (audio, caption) pairs, together with code to evaluate the generated descriptions on a set of automatic metrics (BLEU, METEOR, ROUGE, CIDEr, SPICE, SPIDEr). | ||
|
||
## Setup | ||
The code was developed in Python 3.7 on Linux CentOS 7 and training was carried out on an RTX 2080 Ti GPU. Other GPUs and platforms have not been fully tested. | ||
|
||
Clone the repo | ||
```bash | ||
git clone https://github.com/ilaria-manco/muscaps | ||
cd muscaps | ||
``` | ||
|
||
You'll need to have the `libsndfile` library installed. All other requirements, including the code package, can be installed with | ||
```bash | ||
pip install -r requirements.txt | ||
pip install -e . | ||
``` | ||
|
||
## Project structure | ||
|
||
``` | ||
root | ||
├─ configs # Config files | ||
│ ├─ datasets | ||
│ ├─ models | ||
│ └─ default.yaml | ||
├─ data # Folder to save data (input data, pretrained model weights, etc.) | ||
│ ├─ audio_encoders | ||
│ ├─ datasets | ||
│ │ └─ dataset_name | ||
│ └─ ... | |
├─ muscaps | ||
│ ├─ caption_evaluation_tools # Translation metrics eval on audio captioning | |
│ ├─ datasets # Dataset classes | ||
│ ├─ models # Model code | ||
│ ├─ modules # Model components | ||
│ ├─ scripts # Python scripts for training, evaluation etc. | ||
│ ├─ trainers # Trainer classes | ||
│ └─ utils # Utils | ||
└─ save # Saved model checkpoints, logs, configs, predictions | ||
   └─ experiments | |
      ├─ experiment_id1 | |
      └─ ... | |
``` | ||
|
||
## Dataset | ||
The dataset used in our experiments is private and cannot be shared, but details on how to prepare an equivalent music captioning dataset are provided in the [data README](data/README.md). | |
|
||
## Pre-trained audio feature extractors | ||
For the audio feature extraction component, MusCaps uses CNN-based audio tagging models like [musicnn](https://github.com/jordipons/musicnn). In our experiments, we use [@minzwon](https://github.com/minzwon)'s implementation and pre-trained models, which you can download from [the official repo](https://github.com/minzwon/sota-music-tagging-models/). For example, to obtain the weights for the [HCNN model](https://ieeexplore.ieee.org/abstract/document/9053669) trained on the [MagnaTagATune dataset](http://mirg.city.ac.uk/codeapps/the-magnatagatune-dataset), run the following commands: | |
|
||
```bash | ||
mkdir data/audio_encoders | ||
cd data/audio_encoders/ | ||
wget https://github.com/minzwon/sota-music-tagging-models/raw/master/models/mtat/hcnn/best_model.pth | ||
mv best_model.pth mtt_hcnn.pth | ||
``` | ||
|
||
## Training | ||
Dataset, model and training configurations are set in the respective `yaml` files in [`configs`](configs). Some of the fields can be overridden by arguments in the CLI (for more details on this, refer to the [training script](muscaps/scripts/train.py)). | ||
|
||
To train the model with the default configs, simply run | ||
|
||
```bash | ||
cd muscaps/scripts/ | ||
python train.py <baseline/attention> --feature_extractor <musicnn/hcnn> --pretrained_model <msd/mtt> --device_num <gpu_number> | ||
``` | ||
|
||
This will generate an `experiment_id` and create a new folder in `save/experiments` where the output will be saved. | ||
|
||
If you wish to resume training from a saved checkpoint, run | ||
|
||
```bash | ||
python train.py <baseline/attention> --experiment_id <experiment_id> --device_num <gpu_number> | ||
``` | ||
|
||
## Evaluation | ||
|
||
To evaluate a model saved under `<experiment_id>` on the captioning task, run | ||
|
||
```bash | ||
cd muscaps/scripts/ | ||
python caption.py <experiment_id> --metrics True | ||
``` | ||
|
||
## Cite | ||
|
||
```bib | ||
@misc{manco2021muscaps, | ||
title={MusCaps: Generating Captions for Music Audio}, | ||
author={Ilaria Manco and Emmanouil Benetos and Elio Quinton and Gyorgy Fazekas}, | ||
year={2021}, | ||
eprint={2104.11984}, | ||
archivePrefix={arXiv} | ||
} | ||
``` | ||
|
||
## Acknowledgements | ||
This repo reuses some code from the following repos: | ||
* [sota-music-tagging-models](https://github.com/minzwon/sota-music-tagging-models) by [@minzwon](https://github.com/minzwon) | ||
* [caption-evaluation-tools](https://github.com/audio-captioning/caption-evaluation-tools) by [@audio-captioning](https://github.com/audio-captioning) | ||
* [mmf](https://github.com/facebookresearch/mmf) by [@facebookresearch](https://github.com/facebookresearch) | ||
* [a-PyTorch-Tutorial-to-Image-Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/) by [@sgrvinod](https://github.com/sgrvinod/) | ||
* [allennlp](https://github.com/allenai/allennlp) by [@allenai](https://github.com/allenai/) | ||
|
||
## Contact | ||
If you have any questions, please get in touch: [[email protected]]([email protected]). |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
dataset_config: | ||
  dataset_name: "audiocaption" | |
  data_dir: ${env.data_root}/datasets/${dataset_config.dataset_name} | |
  # Caption configs | |
  captions: | |
    vocab: | |
      min_count: 2 | |
      type: pretrained | |
      embedding_name: glove.6B.300d | |
      vocab_file: ${data_dir}/vocab.txt | |
  # Audio configs | |
  audio: | |
    sr: 16000 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Training configuration | ||
training: | |
  # Experiment id to use in logging and checkpoints | |
  experiment_id: null | |
  # Maximum number of epochs | |
  epochs: 300 | |
  # Device to be used to train the model. Can be "cuda" or "cpu" | |
  device: cuda | |
  # Optimizer configuration | |
  optimizer: adam | |
  # Learning rate | |
  lr: 0.0001 | |
  patience: 5 | |
  batch_size: 8 | |
  # Number of workers to be used in dataloader | |
  num_workers: 4 | |
  # Whether to shuffle in dataloader | |
  shuffle: true | |
  # Whether to pin memory in dataloader | |
  pin_memory: false | |
  # Whether gradient clipping should be applied | |
  clip_gradients: true | |
|
||
# Environment configuration | ||
env: | |
  # Base directory of the repo, populated when config is loaded | |
  base_dir: null | |
  # Directory for storing datasets and models | |
  data_root: ${env.base_dir}/data | |
  # Directory for experiments, logs, output samples etc. | |
  save_dir: ${env.base_dir}/save | |
  # Directory for saving models, logs and checkpoints for each experiment | |
  experiments_dir: ${env.save_dir}/experiments |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
model_config: | ||
  model_name: cnn_attention_lstm | |
  fusion: early | |
  loss: cross_entropy | |
  # Feature Extraction | |
  finetune: false | |
  pool_type: avg | |
  feature_extractor_type: hcnn | |
  pretrained_version: mtt | |
  feature_extractor_path: ${env.data_root}/audio_encoders/${model_config.pretrained_version}_${model_config.feature_extractor_type}.pth | |
  # Encoder | |
  word_embed_dim: 300 | |
  hidden_dim_encoder: 256 | |
  vocab_size: 0 | |
  # Attention | |
  attention_dim: 256 | |
  attention_type: mlp | |
  attention_dropout: 0.25 | |
  # Decoder | |
  dropout_decoder: 0.25 | |
  hidden_dim_decoder: 256 | |
  # Inference | |
  inference_type: greedy | |
  beam_size: 2 | |
  max_caption_len: 22 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
model_config: | ||
  model_name: cnn_lstm_caption | |
  fusion: early | |
  loss: cross_entropy | |
  # Feature Extraction | |
  finetune: false | |
  pool_type: avg | |
  feature_extractor_type: hcnn | |
  pretrained_version: mtt | |
  feature_extractor_path: ${env.data_root}/audio_encoders/${model_config.pretrained_version}_${model_config.feature_extractor_type}.pth | |
  # Encoder | |
  word_embed_dim: 300 | |
  hidden_dim_encoder: 256 | |
  vocab_size: 0 | |
  # Decoder | |
  dropout_decoder: 0.25 | |
  hidden_dim_decoder: 256 | |
  # Inference | |
  inference_type: beam_search | |
  beam_size: 2 | |
  max_caption_len: 22 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Dataset Preparation | ||
|
||
MusCaps can be trained on datasets of (audio, caption) pairs, organised as follows | ||
|
||
``` | ||
dataset_name | ||
├── audio | ||
│ ├── track_1.npy | ||
│ ├── track_2.npy | ||
│ └── ... | |
├── dataset_train.json | ||
├── dataset_val.json | ||
└── dataset_test.json | ||
``` | ||
|
||
### Captions | ||
Prepare a JSON file for each data split (`train`, `val`, `test`) in which every annotation contains the following fields | |
|
||
```json | ||
{ | ||
"track_id": "track_1", | ||
"caption": | ||
{ | ||
"raw": "First caption!", | ||
"tokens": ["first", "caption"] | ||
} | ||
} | ||
``` | ||
and place it in the `data/datasets/<dataset_name>` folder. A toy example is provided in [`data/datasets/audiocaption`](datasets/audiocaption). To preprocess the raw caption text and obtain the tokens, you'll need to lower case and remove punctuation. You may also want to ensure all captions have a suitable length (e.g. between 3 and 22 tokens). | ||
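
The caption preprocessing described above can be sketched as follows. This is a minimal illustration and not code from the repository: the helper name `preprocess_caption` and its `min_len`/`max_len` parameters are assumptions, and only ASCII punctuation is stripped.

```python
import string


# Hypothetical helper, not part of the muscaps package.
def preprocess_caption(raw, min_len=3, max_len=22):
    """Lowercase a raw caption, strip ASCII punctuation and split on whitespace.

    Returns a dict with the "raw" and "tokens" fields used in the annotation
    files, or None when the token count falls outside [min_len, max_len].
    """
    # Translation table that deletes every ASCII punctuation character
    table = str.maketrans("", "", string.punctuation)
    tokens = raw.lower().translate(table).split()
    if not min_len <= len(tokens) <= max_len:
        return None
    return {"raw": raw, "tokens": tokens}


example = preprocess_caption("First caption that describes this amazing track!")
print(example["tokens"])
```

Captions that come back as `None` can be dropped or edited by hand before the annotation files are written.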
|
||
### Audio | ||
Audio files should be preprocessed and stored as `numpy` arrays in `data/datasets/<dataset_name>/audio/`. Each file name should correspond to the `track_id` field in the annotations (e.g. `track_id.npy`). A preprocessing example for wav files is provided below: | |
|
||
```python | ||
import os | ||
import librosa | ||
import numpy as np | ||
|
||
audio_dir = "data/datasets/audiocaption/audio" | ||
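for audio_file in os.listdir(audio_dir): | |
    # Only convert wav files, so arrays saved by a previous run are skipped | |
    if not audio_file.endswith(".wav"): | |
        continue | |
    audio_path = os.path.join(audio_dir, audio_file) | |
    # Resample to 16 kHz; sr must be passed by keyword in recent librosa versions | |
    audio, sr = librosa.load(audio_path, sr=16000) | |
    # Swap only the file extension, leaving the rest of the path untouched | |
    array_path = os.path.splitext(audio_path)[0] + ".npy" | |
    np.save(array_path, audio) | |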
``` | ||
|
||
If the audio file names do not correspond to the track IDs, you'll need an additional field in the annotation files containing the file paths. In this case, remember to edit the code in [`muscaps/datasets/`](../muscaps/datasets/) to point to the corresponding file paths when loading the data. |
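
Before training, it may help to confirm that every `track_id` in an annotation file has a matching array in the audio folder. The check below is a rough sketch using only the standard library. The function name `find_missing_audio` is an assumption and not part of the repository, and the demo builds a throwaway dataset in a temporary directory instead of reading real data.

```python
import json
import os
import tempfile


# Hypothetical helper, not part of the muscaps package.
def find_missing_audio(annotation_path, audio_dir):
    """Return the track_ids in annotation_path with no <track_id>.npy in audio_dir."""
    with open(annotation_path) as f:
        annotations = json.load(f)
    return [
        item["track_id"]
        for item in annotations
        if not os.path.isfile(os.path.join(audio_dir, item["track_id"] + ".npy"))
    ]


# Demo on a temporary dataset: track_1 has an array, track_2 does not.
with tempfile.TemporaryDirectory() as root:
    audio_dir = os.path.join(root, "audio")
    os.makedirs(audio_dir)
    # Empty placeholder file; only its presence matters for this check
    open(os.path.join(audio_dir, "track_1.npy"), "wb").close()
    annotation_path = os.path.join(root, "dataset_train.json")
    with open(annotation_path, "w") as f:
        json.dump([{"track_id": "track_1"}, {"track_id": "track_2"}], f)
    missing = find_missing_audio(annotation_path, audio_dir)

print(missing)
```

An empty list means the annotations and the audio folder are consistent for that split.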
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
[ | ||
{ | ||
"track_id": "track_13", | ||
"caption": | ||
{ | ||
"raw": "First caption that describes this amazing track!", | ||
"tokens": ["first", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_14", | ||
"caption": | ||
{ | ||
"raw": "Second caption that describes this amazing track!", | ||
"tokens": ["second", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_15", | ||
"caption": | ||
{ | ||
"raw": "Third caption that describes this amazing track!", | ||
"tokens": ["third", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_16", | ||
"caption": | ||
{ | ||
"raw": "Fourth caption that describes this amazing track!", | ||
"tokens": ["fourth", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
} | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
[ | ||
{ | ||
"track_id": "track_1", | ||
"caption": | ||
{ | ||
"raw": "First caption that describes this amazing track!", | ||
"tokens": ["first", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_2", | ||
"caption": | ||
{ | ||
"raw": "Second caption that describes this amazing track!", | ||
"tokens": ["second", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_3", | ||
"caption": | ||
{ | ||
"raw": "Third caption that describes this amazing track!", | ||
"tokens": ["third", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_4", | ||
"caption": | ||
{ | ||
"raw": "Fourth caption that describes this amazing track!", | ||
"tokens": ["fourth", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_5", | ||
"caption": | ||
{ | ||
"raw": "Fifth caption that describes this amazing track!", | ||
"tokens": ["fifth", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_6", | ||
"caption": | ||
{ | ||
"raw": "Sixth caption that describes this amazing track!", | ||
"tokens": ["sixth", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_7", | ||
"caption": | ||
{ | ||
"raw": "Seventh caption that describes this amazing track!", | ||
"tokens": ["seventh", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
}, | ||
{ | ||
"track_id": "track_8", | ||
"caption": | ||
{ | ||
"raw": "Eighth caption that describes this amazing track!", | ||
"tokens": ["eighth", "caption", "that", "describes", "this", "amazing", "track"] | ||
} | ||
} | ||
] |