Skip to content

Commit

Permalink
move beartype onto methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 8, 2023
1 parent 7422185 commit 87e835e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions musiclm_pytorch/musiclm_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ def forward(

# text transformer

@beartype
class TextTransformer(nn.Module):
@beartype
def __init__(
self,
dim,
Expand Down Expand Up @@ -546,6 +546,7 @@ def __init__(
def device(self):
return next(self.parameters()).device

@beartype
def forward(
self,
x = None,
Expand Down Expand Up @@ -648,8 +649,8 @@ def forward(self, *, audio_layers, text_layers):

# main classes

@beartype
class MuLaN(nn.Module):
@beartype
def __init__(
self,
audio_transformer: AudioSpectrogramTransformer,
Expand Down Expand Up @@ -705,6 +706,7 @@ def get_audio_latents(

return out, audio_layers

@beartype
def get_text_latents(
self,
texts = None,
Expand All @@ -720,6 +722,7 @@ def get_text_latents(

return out, text_layers

@beartype
def forward(
self,
wavs,
Expand Down Expand Up @@ -766,8 +769,8 @@ def forward(

# music lm

@beartype
class MuLaNEmbedQuantizer(AudioConditionerBase):
@beartype
def __init__(
self,
mulan: MuLaN,
Expand Down Expand Up @@ -851,8 +854,8 @@ def forward(
cond_embeddings = cond_embeddings.gather(2, indices)
return rearrange(cond_embeddings, 'b q 1 d -> b q d')

@beartype
class MusicLM(nn.Module):
@beartype
def __init__(
self,
audio_lm: AudioLM,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'musiclm-pytorch',
packages = find_packages(exclude=[]),
version = '0.2.1',
version = '0.2.2',
license='MIT',
description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
author = 'Phil Wang',
Expand Down

0 comments on commit 87e835e

Please sign in to comment.