Skip to content

Commit

Permalink
precompute audio and text layers to be used for hierarchical contrast…
Browse files Browse the repository at this point in the history
…ive loss
  • Loading branch information
lucidrains committed Mar 2, 2023
1 parent 5f53d41 commit cad7b49
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
21 changes: 11 additions & 10 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,13 +526,10 @@ def forward(

# hierarchical cl loss

def pick_layers_evenly_interspersed(layers, tensor):
total_layers, device = tensor.shape[0], tensor.device
def interspersed_indices(layers, total_layers):
assert total_layers >= layers

step = total_layers / layers
indices = (torch.arange(0, layers) * step).floor().long()
return tensor[indices]
return (torch.arange(0, layers) * step).floor().long()

class MultiLayerContrastiveLoss(nn.Module):
def __init__(
Expand Down Expand Up @@ -564,9 +561,6 @@ def __init__(
def forward(self, *, audio_layers, text_layers):
device, batch = audio_layers.device, audio_layers.shape[1]

audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers)
text_layers = pick_layers_evenly_interspersed(self.layers, text_layers)

audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean')
audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma
audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias
Expand Down Expand Up @@ -603,7 +597,8 @@ def __init__(
text_transformer: TextTransformer,
dim_latent = 128, # they use 128
decoupled_contrastive_learning = True, # think this was used, make it optional
hierarchical_contrastive_loss = False
hierarchical_contrastive_loss = False,
hierarchical_contrastive_loss_layers = None
):
super().__init__()
self.dim_latent = dim_latent
Expand All @@ -621,9 +616,12 @@ def __init__(
self.multi_layer_contrastive_learning = None

if hierarchical_contrastive_loss:
num_layers = min(audio_transformer.depth, text_transformer.depth) - 1
num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1)
assert num_layers > 0

self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth))
self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth))

self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss(
audio_dim = self.audio.dim,
text_dim = self.text.dim,
Expand Down Expand Up @@ -706,6 +704,9 @@ def forward(
if not exists(self.multi_layer_contrastive_learning):
return cl_loss

audio_layers = audio_layers[self.audio_layers_indices]
text_layers = text_layers[self.text_layers_indices]

# whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628

hierarchical_cl_loss = self.multi_layer_contrastive_learning(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
Expand Down

0 comments on commit cad7b49

Please sign in to comment.