add best relative positional encoding for AST
lucidrains committed Feb 12, 2023
1 parent 598aa34 commit 6b63067
Showing 3 changed files with 80 additions and 31 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -142,8 +142,8 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
- [x] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [x] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [x] audiolm and mulan goes into musiclm and generate, filter with mulan
- [x] give dynamic positional bias to self attention in AST

- [ ] give dynamic positional bias to self attention in AST
- [ ] add a version of mulan to <a href="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/mlfoundations/open_clip">open clip</a>
- [ ] set all the proper spectrogram hyperparameters

@@ -189,6 +189,17 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
}
```

```bibtex
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

*The only truth is music.* - Jack Kerouac

*Music is the universal language of mankind.* - Henry Wadsworth Longfellow
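
The new checklist item and the Swin Transformer V2 citation above refer to the dynamic, MLP-based relative positional bias this commit wires into the audio spectrogram transformer. Below is a consolidated sketch of the idea, with hyperparameters mirroring the diff that follows (the class name is illustrative, not the repo's):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

class DynamicPositionBias(nn.Module):
    # maps 2d relative patch offsets to a per-head additive attention bias
    # (hidden dim = dim // 4 and SiLU activations, as in the commit)
    def __init__(self, dim, heads):
        super().__init__()
        hidden = dim // 4
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, heads),
            Rearrange('b i j h -> b h i j')
        )

    def forward(self, grid):
        # grid: (batch, n, 2) patch coordinates
        rel_dist = grid[:, :, None, :] - grid[:, None, :, :]  # (b, n, n, 2)
        return self.mlp(rel_dist.float())                     # (b, heads, n, n)
```

The resulting `(batch, heads, n, n)` tensor is added to the pre-softmax attention logits, which is exactly what the `sim = sim + rel_pos_bias` hunk below does.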
96 changes: 67 additions & 29 deletions musiclm_pytorch/musiclm_pytorch.py
@@ -136,6 +136,7 @@ def __init__(
def forward(
self,
x,
rel_pos_bias = None,
mask = None
):
b, n, _, device = *x.shape, x.device
@@ -158,6 +159,9 @@ def forward

sim = einsum('b h i d, b h j d -> b h i j', q, k)

if exists(rel_pos_bias):
sim = sim + rel_pos_bias

if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
@@ -202,35 +206,19 @@ def __init__(
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
]))

def forward(self, x, mask = None):
def forward(
self,
x,
rel_pos_bias = None,
mask = None
):

for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
x = ff(x) + x

return x

# Patch Dropout - https://arxiv.org/abs/2208.07220

class PatchDropout(nn.Module):
def __init__(self, prob):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob

def forward(self, x, force_keep_all = False):
if not self.training or self.prob == 0. or force_keep_all:
return x

b, n, _, device = *x.shape, x.device

batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * (1 - self.prob)))
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

return x[batch_indices, patch_indices_keep]

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -302,9 +290,30 @@ def __init__(

self.norm = LayerNorm(dim)

self.patch_dropout = PatchDropout(patch_dropout_prob)
# patch dropout

self.patch_dropout_prob = patch_dropout_prob

# 2d dynamic positional bias

mlp_hidden_dim = dim // 4

self.dynamic_pos_bias_mlp = nn.Sequential(
nn.Linear(2, mlp_hidden_dim),
nn.SiLU(),
nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
nn.SiLU(),
nn.Linear(mlp_hidden_dim, heads),
Rearrange('b i j h -> b h i j')
)

def forward(
self,
x,
force_no_patch_dropout = False
):
batch, device = x.shape[0], x.device

def forward(self, x):
x = self.spec(x)

if self.training:
@@ -326,17 +335,46 @@ def forward(self, x):

x = self.to_patch_tokens(x)

# get number of patches along height and width

num_patch_height, num_patch_width = x.shape[-2:]

# get 2d relative positions

grid = torch.stack(torch.meshgrid(
torch.arange(num_patch_height, device = device),
torch.arange(num_patch_width, device = device)
, indexing = 'ij'), dim = -1)

grid = repeat(grid, '... c -> b (...) c', b = batch)

# 2d sinusoidal positional embedding

x = x + posemb_sincos_2d(x)

# attention, what else

x = rearrange(x, 'b ... c -> b (...) c')

x = self.patch_dropout(x)
# patch dropout

if self.training and self.patch_dropout_prob > 0. and not force_no_patch_dropout:
n, device = x.shape[1], x.device

batch_indices = torch.arange(batch, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * (1 - self.patch_dropout_prob)))
patch_indices_keep = torch.randn(batch, n, device = device).topk(num_patches_keep, dim = -1).indices

x = x[batch_indices, patch_indices_keep]
grid = grid[batch_indices, patch_indices_keep]

# 2d relative positional bias

rel_dist = rearrange(grid, 'b i c -> b i 1 c') - rearrange(grid, 'b j c -> b 1 j c')
rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())

# attention, what else

x = self.transformer(x)
x = self.transformer(x, rel_pos_bias = rel_pos_bias)

# final global average and norm (most recent papers show this is superior to CLS token)

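Taken together, the hunks above build a 2d coordinate grid alongside the patch tokens, apply patch dropout to tokens and grid with the same indices, and compute the pairwise bias from the surviving coordinates only. A minimal shape check of that flow (a sketch with hypothetical sizes, not the module itself):

```python
import torch
from einops import rearrange, repeat

batch = 2
num_patch_height, num_patch_width = 4, 53  # hypothetical spectrogram patch grid

grid = torch.stack(torch.meshgrid(
    torch.arange(num_patch_height),
    torch.arange(num_patch_width),
indexing = 'ij'), dim = -1)

grid = repeat(grid, '... c -> b (...) c', b = batch)  # (b, n, 2)
n = grid.shape[1]

# patch dropout must gather tokens and coordinates with the same indices
num_keep = int(n * 0.75)
batch_indices = rearrange(torch.arange(batch), 'b -> b 1')
keep = torch.randn(batch, n).topk(num_keep, dim = -1).indices
grid = grid[batch_indices, keep]  # (b, num_keep, 2)

# pairwise relative offsets feed the bias MLP
rel_dist = rearrange(grid, 'b i c -> b i 1 c') - rearrange(grid, 'b j c -> b 1 j c')
assert rel_dist.shape == (batch, num_keep, num_keep, 2)
# the MLP then maps this to (batch, heads, num_keep, num_keep), matching sim
```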
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.19',
version = '0.0.20',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',

6 comments on commit 6b63067


@ukemamaster ukemamaster commented on 6b63067 Feb 13, 2023


@lucidrains Upgrading to this version, I always get a memory allocation error, even for batch size = 2:

RuntimeError: CUDA out of memory. Tried to allocate 26.82 GiB (GPU 0; 23.65 GiB total capacity; 1.60 GiB already allocated; 21.01 GiB free; 1.68 GiB reserved in total by PyTorch) 
If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This is my training code:

import torch
from musiclm_pytorch import MusicLM, MuLaNTrainer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer

from dataset_class.mulan_dataset import MuLaNDataset

DATA_PATH = 'music_generation/data'
LIST_PATH = DATA_PATH + '/lst'
WAV_PATH = DATA_PATH + '/wavs'

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    max_seq_len = 128
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

dataset = MuLaNDataset(LIST_PATH, WAV_PATH)

trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = dataset,
    num_train_steps = 500,
    batch_size = 4,
    save_model_every = 100,
    force_clear_prev_results = False
)

trainer.train()

@Lunariz

Having the same issue as @ukemamaster:

  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\trainer.py", line 648, in train
    logs = self.train_step()
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\trainer.py", line 604, in train_step
    loss = self.train_wrapper(**data_kwargs, return_loss = True)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\audiolm_pytorch.py", line 1058, in forward
    text_embeds = self.audio_conditioner(wavs = raw_wave, namespace = 'semantic')
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 620, in forward
    latents = self.mulan.get_audio_latents(wavs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 500, in get_audio_latents
    audio_embeds = self.audio(wavs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 373, in forward
    rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
    input = module(input)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 351.13 GiB (GPU 0; 24.00 GiB total capacity; 17.64 GiB already allocated; 1.43 GiB free; 17.75 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Training the MuLaN model works fine for me, but training the Semantic model breaks.
I've also tried running it with smaller hyperparameters (replacing all 1024 -> 256, 512 -> 128, depth 6 -> depth 1); then I get the following error instead:

File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\trainer.py", line 648, in train
    logs = self.train_step()
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\trainer.py", line 604, in train_step
    loss = self.train_wrapper(**data_kwargs, return_loss = True)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\audiolm_pytorch\audiolm_pytorch.py", line 1058, in forward
    text_embeds = self.audio_conditioner(wavs = raw_wave, namespace = 'semantic')
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 620, in forward
    latents = self.mulan.get_audio_latents(wavs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 500, in get_audio_latents
    audio_embeds = self.audio(wavs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 377, in forward
    x = self.transformer(x, rel_pos_bias = rel_pos_bias)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 217, in forward
    x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Lukas\AppData\Local\Programs\Python\Python310\lib\site-packages\musiclm_pytorch\musiclm_pytorch.py", line 163, in forward
    sim = sim + rel_pos_bias
RuntimeError: The size of tensor a (212) must match the size of tensor b (6784) at non-singleton dimension 3

It's specifically this commit (0.0.20) that breaks things; both sets of hyperparameters train without a hitch on 0.0.19.
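
For what it's worth, these tracebacks are consistent with the bias MLP materializing pairwise activations: its input is `(batch, n, n, 2)` and each hidden layer is `(batch, n, n, dim // 4)`, so memory grows quadratically in the number of patches. A back-of-the-envelope sketch, assuming `dim = 512` and taking `n` from the shape error above (real usage is higher once autograd saves tensors for backward):

```python
# rough size of a single hidden activation of the bias MLP, float32
batch, n = 4, 6784   # n taken from the shape error above
hidden = 512 // 4    # mlp_hidden_dim = dim // 4 in this commit

gib = batch * n * n * hidden * 4 / 1024**3
print(f'{gib:.1f} GiB')  # ~87.8 GiB for one activation
```

The shape error itself (212 vs 6784) also suggests that, for that configuration, the attention sequence length and the flattened grid length disagree.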

@ukemamaster ukemamaster commented on 6b63067 Feb 13, 2023


@Lunariz You said MuLaN training is fine. May I have a look at your code? I want to know where I am messing it up.
Also, which data are you training MuLaN with?

@Lunariz

> @Lunariz You said MuLaN training is fine. May I have a look at your code? I want to know where I am messing it up. Also, which data are you training MuLaN with?

I'm still using a mock dataset. My code is adapted from this comment:

import os
import torch
from random import randrange
from torch.utils.data import Dataset

from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer, MuLaNTrainer

# dim_small, depth, and modelsize are defined elsewhere in my script

audio_transformer = AudioSpectrogramTransformer(
    dim = dim_small,
    depth = depth,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = dim_small,
    depth = depth,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)


class MockTextAudioDataset(Dataset):
    def __init__(self, length = 100, audio_length = 320 * 32):
        super().__init__()
        self.audio_length = audio_length
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        mock_audio = torch.randn(randrange(self.audio_length // 2, self.audio_length))
        mock_text = torch.randint(0, 12, (256,)).long()
        return mock_text, mock_audio



saved_model_path = "./{}/mulan.pt".format(modelsize)

if os.path.exists(saved_model_path):
    print("Loading mulan model")
    mulan = torch.load(saved_model_path)
else:
    print("Training mulan model")
    trainer = MuLaNTrainer(
        mulan = mulan,
        dataset = MockTextAudioDataset(),
        batch_size = 4,
        force_clear_prev_results = True
    )
    trainer.train()
    
    torch.save(mulan, saved_model_path)

@ukemamaster

OK. I am trying with the test dataset (5.5k samples) that they used in the original paper.

@ukemamaster

Given the model architecture, the GPU runs out of memory. However, reducing the length of the audio input from the dataset solves the problem.
In my case, it runs fine for at most 3 seconds of 24 kHz audio input on a 24 GB RTX.
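
That lines up with the bias being pairwise over patches: the patch count grows linearly with clip length, so the bias memory grows quadratically. A toy illustration (`patches_per_second` is a made-up rate, purely for the arithmetic):

```python
# doubling the clip length roughly quadruples the bias memory
def bias_gib(seconds, batch = 2, hidden = 128, patches_per_second = 500):
    n = int(seconds * patches_per_second)  # hypothetical patch rate
    return batch * n * n * hidden * 4 / 1024**3

for s in (3, 6, 12):
    print(f'{s} s -> {bias_gib(s):.1f} GiB')  # 2.1, 8.6, 34.3 GiB
```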
