Commit

handle the case where batch sizes are equal across devices
lucidrains committed Sep 6, 2023
1 parent 80a9826 commit 4a78b41
Showing 2 changed files with 19 additions and 11 deletions.
28 changes: 18 additions & 10 deletions musiclm_pytorch/distributed.py
@@ -1,26 +1,34 @@
 import torch
 from torch import nn
 from torch.autograd import Function
-import torch.distributed as distributed
+import torch.distributed as dist

 from einops import rearrange

 # distributed helpers

+def all_gather_same_dim(t):
+    world_size = dist.get_world_size()
+    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+    dist.all_gather(gathered_tensors, t)
+    return gathered_tensors
+
 def all_gather_variable_dim(t, dim = 0, sizes = None):
-    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()
+    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

     if not exists(sizes):
         size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
-        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
-        distributed.all_gather(sizes, size)
+        sizes = all_gather_same_dim(size)
         sizes = torch.stack(sizes)

+    if torch.unique(sizes).numel() == 1:
+        gathered_tensors = all_gather_same_dim(t)
+        return torch.cat(gathered_tensors, dim = dim), sizes
+
     max_size = sizes.amax().item()
-    padded_t = pad_dim_to(t, max_size, dim = dim)

-    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
-    distributed.all_gather(gathered_tensors, padded_t)
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+    gathered_tensors = all_gather_same_dim(padded_t)

     gathered_tensor = torch.cat(gathered_tensors, dim = dim)
     seq = torch.arange(max_size, device = device)
@@ -45,9 +53,9 @@ def forward(ctx, x, dim, sizes, all_reduce_grads):

     @staticmethod
     def backward(ctx, grads, _):
-        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
+        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
         if ctx.all_reduce_grads:
-            distributed.all_reduce(grads)
+            dist.all_reduce(grads)

         grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
         return grads_by_rank[rank], None, None, None
@@ -62,7 +70,7 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.all_reduce_grads = all_reduce_grads
-        self.is_distributed = distributed.is_initialized() and distributed.get_world_size() > 1
+        self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

     def forward(
         self,
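Note: the change above adds a fast path for the common case where every device contributes the same batch size along the gathered dimension, so no padding is needed before the all_gather. Below is a minimal sketch (not part of this commit) of exercising that path; it assumes all_gather_variable_dim is importable from musiclm_pytorch.distributed as shown in the diff and that the script is launched with torchrun across two or more processes.

# example_all_gather.py - hypothetical demo script, not from the repository
import torch
import torch.distributed as dist

from musiclm_pytorch.distributed import all_gather_variable_dim

dist.init_process_group(backend = 'gloo')  # use 'nccl' when each rank has a GPU
rank = dist.get_rank()

# every rank contributes a tensor of the same size along dim 0, so
# torch.unique(sizes).numel() == 1 and the new all_gather_same_dim
# shortcut is taken, skipping pad_dim_to entirely
t = torch.randn(4, 8) + rank

gathered, sizes = all_gather_variable_dim(t, dim = 0)
print(rank, gathered.shape, sizes.tolist())  # (4 * world_size, 8), [4, 4, ...]

dist.destroy_process_group()

Run with, for example: torchrun --nproc_per_node 2 example_all_gather.py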
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.7',
+  version = '0.2.8',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',
