Skip to content

Commit

Permalink
add sigmoid contrastive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 30, 2023
1 parent b95e54b commit 80ad8cc
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 40 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples
}
```

```bibtex
@inproceedings{Zhai2023SigmoidLF,
title = {Sigmoid Loss for Language Image Pre-Training},
author = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
year = {2023}
}
```

*The only truth is music.* - Jack Kerouac

*Music is the universal language of mankind.* - Henry Wadsworth Longfellow
123 changes: 84 additions & 39 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from functools import wraps

import torch
Expand Down Expand Up @@ -248,6 +249,76 @@ def forward(

return x, torch.stack(layers[:-1])

# contrastive losses

class SoftmaxContrastiveLearning(nn.Module):
def __init__(
self,
*,
layers = 1,
decoupled_contrastive_learning = False,
init_temp = 10
):
super().__init__()
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
self.decoupled_contrastive_learning = decoupled_contrastive_learning

@property
def device(self):
return next(self.parameters()).device

def forward(self, sims):
batch = sims.shape[-1]

if sims.ndim == 2:
sims = rearrange(sims, 'i j -> 1 i j')

sims = sims * self.temperatures.exp()

cosine_sims_exp = sims.exp()

numerator = matrix_diag(cosine_sims_exp)

if self.decoupled_contrastive_learning:
eye = torch.eye(batch, device = self.device, dtype = torch.bool)
cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
return contrastive_loss.sum()

class SigmoidContrastiveLearning(nn.Module):
""" https://arxiv.org/abs/2303.15343 """

def __init__(
self,
*,
layers = 1,
init_temp = 10,
init_bias = -10
):
super().__init__()
self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)

@property
def device(self):
return next(self.parameters()).device

def forward(self, sims):
if sims.ndim == 2:
sims = rearrange(sims, 'i j -> 1 i j')

n = sims.shape[-1]
sims = sims * self.temperatures.exp() + self.bias
labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)

return -F.logsigmoid(labels * sims).sum() / n

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
Expand Down Expand Up @@ -539,7 +610,8 @@ def __init__(
text_dim,
dim_latent,
layers,
decoupled_contrastive_learning = False
decoupled_contrastive_learning = False,
sigmoid_contrastive_loss = False
):
super().__init__()
self.layers = layers
Expand All @@ -554,9 +626,8 @@ def __init__(
self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))

self.temperatures = nn.Parameter(torch.ones(layers, 1, 1))

self.decoupled_contrastive_learning = decoupled_contrastive_learning
klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
self.contrast = klass(layers = layers)

def forward(self, *, audio_layers, text_layers):
device, batch = audio_layers.device, audio_layers.shape[1]
Expand All @@ -571,23 +642,9 @@ def forward(self, *, audio_layers, text_layers):
text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
text_latents = l2norm(text_latents)

cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp()

cosine_sims_exp = cosine_sims.exp()
cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

numerator = matrix_diag(cosine_sims_exp)

if self.decoupled_contrastive_learning:
eye = torch.eye(batch, device = device, dtype = torch.bool)
cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
return contrastive_loss.sum()
return self.contrast(cosine_sims)

# main classes

Expand All @@ -600,20 +657,21 @@ def __init__(
dim_latent = 128, # they use 128
decoupled_contrastive_learning = True, # think this was used, make it optional
hierarchical_contrastive_loss = False,
hierarchical_contrastive_loss_layers = None
hierarchical_contrastive_loss_layers = None,
sigmoid_contrastive_loss = False
):
super().__init__()
self.dim_latent = dim_latent

self.audio = audio_transformer
self.text = text_transformer

self.temperature = nn.Parameter(torch.tensor(1.))

self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

self.decoupled_contrastive_learning = decoupled_contrastive_learning
klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
self.contrast = klass()

self.multi_layer_contrastive_learning = None

Expand All @@ -629,7 +687,8 @@ def __init__(
text_dim = self.text.dim,
dim_latent = dim_latent,
layers = num_layers,
decoupled_contrastive_learning = decoupled_contrastive_learning
decoupled_contrastive_learning = decoupled_contrastive_learning,
sigmoid_contrastive_loss = sigmoid_contrastive_loss
)

def get_audio_latents(
Expand Down Expand Up @@ -688,21 +747,7 @@ def forward(
if return_pairwise_similarities:
return cosine_sim

cosine_sim = cosine_sim * self.temperature.exp()

cosine_sim_exp = cosine_sim.exp()

numerator = cosine_sim_exp.diag()

if self.decoupled_contrastive_learning:
eye = torch.eye(batch, device = device, dtype = torch.bool)
cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum')
denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum')

contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
cl_loss = contrastive_loss.mean()
cl_loss = self.contrast(cosine_sim)

if not exists(self.multi_layer_contrastive_learning):
return cl_loss
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.2.0',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
Expand Down

0 comments on commit 80ad8cc

Please sign in to comment.