further cleanup the distributed code
lucidrains committed Sep 6, 2023
1 parent f2daaf6 commit 90021a1
Showing 3 changed files with 39 additions and 34 deletions.
musiclm_pytorch/distributed.py (30 additions, 24 deletions)
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 from torch.autograd import Function
 import torch.distributed as distributed

@@ -33,37 +34,42 @@ def all_gather_variable_dim(t, dim = 0, sizes = None):
 
     return gathered_tensor, sizes
 
-class AllGather(Function):
+class AllGatherFunction(Function):
     @staticmethod
-    def forward(ctx, x, dim, sizes):
-        assert distributed.is_initialized() and distributed.get_world_size() > 1
+    def forward(ctx, x, dim, sizes, all_reduce_grads):
         x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
-        ctx.batch_sizes = batch_sizes.tolist()
-        ctx.dim = dim
-        return x, batch_sizes
-
-    @staticmethod
-    def backward(ctx, grads, _):
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None, None
-
-all_gather = AllGather.apply
-
-class AllGatherAllReduceGrads(Function):
-    @staticmethod
-    def forward(ctx, x, dim, sizes):
-        assert distributed.is_initialized() and distributed.get_world_size() > 1
-        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+        ctx.all_reduce_grads = all_reduce_grads
         ctx.batch_sizes = batch_sizes.tolist()
         ctx.dim = dim
         return x, batch_sizes
 
     @staticmethod
     def backward(ctx, grads, _):
-        distributed.all_reduce(grads)
         batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
-        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
-        return grads_by_rank[rank], None, None
 
-all_gather_all_reduce_grads = AllGatherAllReduceGrads.apply
+        if ctx.all_reduce_grads:
+            distributed.all_reduce(grads)
+
+        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+        return grads_by_rank[rank], None, None, None
+
+class AllGather(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        all_reduce_grads = False
+    ):
+        super().__init__()
+        self.dim = dim
+        self.all_reduce_grads = all_reduce_grads
+        self.is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
+
+    def forward(
+        self,
+        x,
+        sizes = None
+    ):
+        if not self.is_distributed:
+            return x, None
+
+        return AllGatherFunction.apply(x, self.dim, sizes, self.all_reduce_grads)
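
A minimal usage sketch of the new AllGather module (not part of the commit; it assumes two processes launched with torchrun --nproc_per_node 2 example.py and the gloo backend):

import torch
import torch.distributed as distributed
from musiclm_pytorch.distributed import AllGather

distributed.init_process_group(backend = 'gloo')
rank = distributed.get_rank()

all_gather = AllGather(dim = 0)

# each rank contributes a different number of rows along the gather dim:
# rank 0 contributes 1 row, rank 1 contributes 2
local = torch.randn(rank + 1, 4)

gathered, sizes = all_gather(local)

# on both ranks: gathered has shape (3, 4) and sizes is tensor([1, 2])
print(rank, gathered.shape, sizes)

distributed.destroy_process_group()

In a single, non-distributed process, forward short-circuits to (x, None), which is exactly the case the exists(rank_sizes) check below now guards against.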
musiclm_pytorch/musiclm_pytorch.py (8 additions, 9 deletions)
@@ -11,7 +11,7 @@
 from audiolm_pytorch.utils import AudioConditionerBase
 
 import torch.distributed as dist
-from musiclm_pytorch.distributed import all_gather, all_gather_all_reduce_grads
+from musiclm_pytorch.distributed import AllGather
 
 from x_clip.tokenizer import tokenizer
 from vector_quantize_pytorch import ResidualVQ
@@ -266,7 +266,7 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.decoupled_contrastive_learning = decoupled_contrastive_learning
 
-        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+        self.all_gather = AllGather(dim = 2)
 
     @property
     def device(self):
@@ -281,9 +281,9 @@ def forward(self, audio_latents, text_latents):
 
         batch = audio_latents.shape[1]
 
-        if self.needs_all_gather:
+        if self.all_gather.is_distributed:
             latents = torch.stack((audio_latents, text_latents))
-            latents, _ = all_gather(latents, 2, None)
+            latents, _ = self.all_gather(latents)
             audio_latents, text_latents = latents
 
         sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
@@ -320,7 +320,7 @@ def __init__(
         self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
         self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
 
-        self.needs_all_gather = dist.is_initialized() and dist.get_world_size() > 1
+        self.all_gather = AllGather(dim = 1, all_reduce_grads = True)
 
     @property
     def device(self):
@@ -335,8 +335,7 @@ def forward(self, audio_latents, text_latents):
         if text_latents.ndim == 2:
             text_latents = rearrange(text_latents, '... -> 1 ...')
 
-        if self.needs_all_gather:
-            text_latents, batch_sizes = all_gather_all_reduce_grads(text_latents, 1, None)
+        text_latents, rank_sizes = self.all_gather(text_latents)
 
         n = text_latents.shape[1]

@@ -346,8 +345,8 @@
 
         labels = torch.eye(n, device = device)
 
-        if self.needs_all_gather:
-            labels_by_ranks = labels.split(batch_sizes.tolist(), dim = 0)
+        if exists(rank_sizes):
+            labels_by_ranks = labels.split(rank_sizes.tolist(), dim = 0)
             labels = labels_by_ranks[dist.get_rank()]
 
         labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
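
A note on the all_reduce_grads = True flag used above: in the sigmoid-style loss only the text latents are gathered, so every rank's similarity matrix pushes gradient into the same gathered text tensor, and AllGatherFunction's backward all-reduces those gradients across ranks before handing each rank its own slice. The label bookkeeping is easier to see in isolation; a sketch with made-up sizes, where rank_sizes stands in for the tensor returned by self.all_gather:

import torch

# pretend 3 ranks contributed 2, 1 and 3 text latents respectively
rank_sizes = torch.tensor([2, 1, 3])
n = rank_sizes.sum().item()  # 6 texts after the gather

labels = torch.eye(n)

# same split as above: each rank keeps only the identity rows
# belonging to its own audio-text pairs
labels_by_ranks = labels.split(rank_sizes.tolist(), dim = 0)

# e.g. rank 1 keeps a (1, 6) slab whose single positive sits at column 2
print(labels_by_ranks[1])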
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.5',
+  version = '0.2.6',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',