Skip to content

Commit

Permalink
support detect language
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzilin committed Sep 30, 2022
1 parent 3a0e935 commit edb6944
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
19 changes: 4 additions & 15 deletions whisper/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,15 @@ def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = None
if model.type == "tiny.en":
self.kv_cache_size = lambda x, y: [8, x, y, 384]
elif model.type == "base.en":
self.kv_cache_size = lambda x, y: [12, x, y, 512]
elif model.type == "small.en":
self.kv_cache_size = lambda x, y: [24, x, y, 768]
elif model.type == "medium.en":
self.kv_cache_size = lambda x, y: [48, x, y, 1024]
else:
raise ValueError(f"Unsupported model type: {model.type}")

def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
n_group = tokens.shape[0]
if self.kv_cache is None:
self.kv_cache = np.zeros(
self.kv_cache_size(n_group, self.initial_token_length), dtype=np.float32)
self.kv_cache = self.model.new_kv_cache(n_group, self.initial_token_length)
offset = 0
else:
offset = self.kv_cache.shape[2]
new_kv_cache = np.zeros(self.kv_cache_size(n_group, offset + 1), dtype=np.float32)
new_kv_cache = self.model.new_kv_cache(n_group, offset + 1)
new_kv_cache[:, :, :-1, :] = self.kv_cache
self.kv_cache = new_kv_cache

Expand All @@ -161,7 +150,7 @@ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
tokens = tokens[:, -1:]

# export decoder as onnx
if False and self.kv_cache.shape[2] > self.initial_token_length:
if True and self.kv_cache.shape[2] > self.initial_token_length:
print(f"tokens: {tokens.shape}")
print(f"audio_features: {audio_features.shape}")
print(f"kv_cache: {self.kv_cache.shape}")
Expand Down Expand Up @@ -631,7 +620,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
print(f"step: {i}, logits: {logits}", flush=True)
print(f"step: {i}", flush=True)

if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
Expand Down
48 changes: 33 additions & 15 deletions whisper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ def forward(
v = self.value(x if xa is None else xa)
if kv_cache is not None and k.shape[1] <= self.n_ctx:
# here is hard coded
# tiny.en: 4
# base.en: 6
# small.en: 12
# medium.en: 24
key_id = self.layer_id - 24
# tiny: 4
# base: 6
# small: 12
# medium: 24
# large: 32
key_id = self.layer_id - 4
value_id = key_id + 1
size = k.shape[1]
kv_cache[key_id, :, -size:, :] = k
Expand Down Expand Up @@ -215,10 +216,8 @@ def __init__(self, model: str):

self.core = Core()
self._model = self.core.read_model(
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
"encoder.xml",
"encoder.bin",
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.xml"),
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="encoder.bin"),
)
self.model = self.core.compile_model(self._model, "CPU")

Expand All @@ -233,10 +232,8 @@ def __init__(self, model: str):

self.core = Core()
self._model = self.core.read_model(
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
# hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
"decoder.xml",
"decoder.bin",
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.xml"),
hf_hub_download(repo_id=f"zhuzilin/whisper-openvino-{model}", filename="decoder.bin"),
)
self.model = self.core.compile_model(self._model, "CPU")

Expand Down Expand Up @@ -278,10 +275,16 @@ def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)

def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder.forward(tokens, audio_features)
kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
output, _ = self.decoder.forward(tokens, audio_features, kv_cache=torch.from_numpy(kv_cache), offset=0)
# output, _ = self.decoder.forward(tokens, audio_features, kv_cache=kv_cache, offset=0)
return output

def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
kv_cache = self.new_kv_cache(tokens.shape[0], tokens.shape[-1])
output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=torch.from_numpy(kv_cache), offset=0)
# output, _ = self.decoder(tokens, self.encoder(mel), kv_cache=kv_cache, offset=0)
return output

@property
def device(self):
Expand All @@ -291,6 +294,21 @@ def device(self):
def is_multilingual(self):
return self.dims.n_vocab == 51865

def new_kv_cache(self, n_group: int, length: int):
if self.type == "tiny.en" or self.type == "tiny":
size = [8, n_group, length, 384]
elif self.type == "base.en" or self.type == "base":
size = [12, n_group, length, 512]
elif self.type == "small.en" or self.type == "small":
size = [24, n_group, length, 768]
elif self.type == "medium.en" or self.type == "medium":
size = [48, n_group, length, 1024]
elif self.type == "large":
size = [64, n_group, length, 1280]
else:
raise ValueError(f"Unsupported model type: {self.type}")
return np.zeros(size, dtype=np.float32)

detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function

0 comments on commit edb6944

Please sign in to comment.