Skip to content

Commit

Permalink
complete distributed logic for sigmoid contrastive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 6, 2023
1 parent 0876027 commit b5ca8ae
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
18 changes: 18 additions & 0 deletions musiclm_pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,21 @@ def backward(ctx, grads, _):
return grads_by_rank[rank], None, None

all_gather = AllGather.apply

class AllGatherAllReduceGrads(Function):
@staticmethod
def forward(ctx, x, dim, sizes):
assert distributed.is_initialized() and distributed.get_world_size() > 1
x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
ctx.batch_sizes = batch_sizes.tolist()
ctx.dim = dim
return x, batch_sizes

@staticmethod
def backward(ctx, grads, _):
distributed.all_reduce(grads)
batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
return grads_by_rank[rank], None, None

all_gather_all_reduce_grads = AllGatherAllReduceGrads.apply
20 changes: 17 additions & 3 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from audiolm_pytorch.utils import AudioConditionerBase

import torch.distributed as dist
from musiclm_pytorch.distributed import all_gather
from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ
Expand Down Expand Up @@ -320,23 +320,37 @@ def __init__(
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)

self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1

@property
def device(self):
return next(self.parameters()).device

def forward(self, audio_latents, text_latents):
device = self.device

if audio_latents.ndim == 2:
audio_latents = rearrange(audio_latents, '... -> 1 ...')

if text_latents.ndim == 2:
text_latents = rearrange(text_latents, '... -> 1 ...')

n = audio_latents.shape[1]
if self.needs_all_gather:
text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)

n = text_latents.shape[1]

sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

sims = sims * self.temperatures.exp() + self.bias
labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)

labels = torch.eye(n, device = device)

if self.needs_all_gather:
labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
labels = labels_by_ranks[dist.get_rank()]

labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)

return -F.logsigmoid(labels * sims).sum() / n

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.2.3',
version = '0.2.4',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
Expand Down

0 comments on commit b5ca8ae

Please sign in to comment.