make sure regular contrastive loss supports distributed data parallel. work towards completing the sigmoid contrastive loss
lucidrains committed Sep 5, 2023
1 parent 80188c5 commit 0876027
Showing 3 changed files with 84 additions and 17 deletions.
51 changes: 51 additions & 0 deletions musiclm_pytorch/distributed.py
@@ -0,0 +1,51 @@
import torch
from torch.autograd import Function
import torch.distributed as distributed

from einops import rearrange

# distributed helpers

def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    if not exists(sizes):
        # gather the size of the given dimension from every rank
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    # pad every rank's tensor to the largest size before the fixed-size all_gather
    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    # mask out the padding, keeping only each rank's true number of elements
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        assert distributed.is_initialized() and distributed.get_world_size() > 1
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # route each rank's slice of the incoming gradient back to the rank that produced it
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

all_gather = AllGather.apply
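
Note: the new file references two helpers, exists and pad_dim_to, without defining or importing them, so as committed it would raise a NameError. A minimal sketch of what they presumably look like (an assumption, not part of this commit):

import torch.nn.functional as F

def exists(val):
    return val is not None

def pad_dim_to(t, length, dim = 0):
    # right-pad `t` with zeros along `dim` until it has size `length`
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))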
48 changes: 32 additions & 16 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,9 @@
from audiolm_pytorch import AudioLM
from audiolm_pytorch.utils import AudioConditionerBase

+import torch.distributed as dist
+from musiclm_pytorch.distributed import all_gather

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

@@ -263,15 +266,27 @@ def __init__(
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.decoupled_contrastive_learning = decoupled_contrastive_learning

+        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1

    @property
    def device(self):
        return next(self.parameters()).device

-    def forward(self, sims):
-        batch = sims.shape[-1]
+    def forward(self, audio_latents, text_latents):
+        if audio_latents.ndim == 2:
+            audio_latents = rearrange(audio_latents, '... -> 1 ...')
+
+        if text_latents.ndim == 2:
+            text_latents = rearrange(text_latents, '... -> 1 ...')

-        if sims.ndim == 2:
-            sims = rearrange(sims, 'i j -> 1 i j')
+        batch = audio_latents.shape[1]
+
+        if self.needs_all_gather:
+            latents = torch.stack((audio_latents, text_latents))
+            latents, _ = all_gather(latents, 2, None)  # all_gather returns (gathered tensor, sizes); keep only the tensor
+            audio_latents, text_latents = latents
+
+        sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

        sims = sims * self.temperatures.exp()
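
With this change, every rank contributes its latents as negatives for all other ranks, and because AllGather is an autograd Function with a custom backward, each rank receives only the gradient slice for the latents it produced. A minimal sketch of the gather semantics (illustrative names; assumes world size 2 and a per-rank batch of 4):

latents = torch.stack((audio_latents, text_latents))   # (2, layers, 4, dim) on each rank

gathered, sizes = all_gather(latents, 2, None)         # concatenates along the batch dim across ranks
# gathered: (2, layers, 8, dim), sizes: tensor([4, 4])

audio_all, text_all = gathered
sims = einsum('l i d, l j d -> l i j', audio_all, text_all)   # (layers, 8, 8) similarity grid per layer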

@@ -309,11 +324,17 @@ def __init__(
    def device(self):
        return next(self.parameters()).device

-    def forward(self, sims):
-        if sims.ndim == 2:
-            sims = rearrange(sims, 'i j -> 1 i j')
+    def forward(self, audio_latents, text_latents):
+        if audio_latents.ndim == 2:
+            audio_latents = rearrange(audio_latents, '... -> 1 ...')
+
+        if text_latents.ndim == 2:
+            text_latents = rearrange(text_latents, '... -> 1 ...')
+
+        n = audio_latents.shape[1]
+
+        sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

-        n = sims.shape[-1]
        sims = sims * self.temperatures.exp() + self.bias
        labels = 2 * rearrange(torch.eye(n, device = self.device), 'i j -> 1 i j') - torch.ones_like(sims)
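
The remainder of this forward pass is collapsed in the diff. For reference, a SigLIP-style sigmoid contrastive loss over these labels would typically finish along these lines (a sketch of the standard objective, not necessarily the elided code; assumes import torch.nn.functional as F):

# labels are +1 on the diagonal (matched audio/text pairs) and -1 off-diagonal
return -F.logsigmoid(labels * sims).sum() / n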

@@ -643,9 +664,7 @@ def forward(self, *, audio_layers, text_layers):
        text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
        text_latents = l2norm(text_latents)

-        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
-
-        return self.contrast(cosine_sims)
+        return self.contrast(audio_latents, text_latents)

# main classes

@@ -743,14 +762,11 @@ def forward(
        if return_similarities:
            return einsum('i d, i d -> i', audio_latents, text_latents)

-        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)
-
-        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'
-
        if return_pairwise_similarities:
+            cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)
            return cosine_sim

-        cl_loss = self.contrast(cosine_sim)
+        cl_loss = self.contrast(audio_latents, text_latents)

        if not exists(self.multi_layer_contrastive_learning):
            return cl_loss
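
Because the gather is gated on dist.is_initialized() and a world size above 1, the same training script works on a single GPU or under DistributedDataParallel. A hypothetical launch sketch (build_mulan, wavs, and texts are placeholders, not part of the library):

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun --nproc_per_node=2 train.py
dist.init_process_group('nccl')
rank = dist.get_rank()

# init the process group before constructing the model,
# since needs_all_gather is captured in __init__
mulan = build_mulan().cuda(rank)        # placeholder constructor
mulan = DDP(mulan, device_ids = [rank])

loss = mulan(wavs = wavs, texts = texts)   # contrastive negatives now come from all ranks
loss.backward()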
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.2.2',
+  version = '0.2.3',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',
