Skip to content

Commit

Permalink
fix all_tokens handling that caused more repetitions and discrepancy in JSON (openai#1060)
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed Mar 8, 2023
1 parent aac47c9 commit 38f2f4d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
1 change: 1 addition & 0 deletions tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_transcribe(model_name: str):
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])

transcription = result["text"].lower()
assert "my fellow americans" in transcription
Expand Down
2 changes: 1 addition & 1 deletion whisper/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def add_word_timestamps(
if len(segments) == 0:
return

text_tokens = [t for segment in segments for t in segment["tokens"]]
text_tokens = [t for s in segments for t in s["tokens"] if t < tokenizer.eot]
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)

Expand Down
22 changes: 12 additions & 10 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
text_tokens = [token for token in tokens.tolist() if token < tokenizer.eot]
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"id": len(all_segments),
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": text_tokens,
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
Expand Down Expand Up @@ -245,7 +245,6 @@ def new_segment(

previous_seek = seek
current_segments = []
current_tokens = []

timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
Expand Down Expand Up @@ -275,7 +274,6 @@ def new_segment(
result=result,
)
)
current_tokens.append(sliced_tokens.tolist())
last_slice = current_slice

if single_timestamp_ending:
Expand All @@ -287,7 +285,6 @@ def new_segment(
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
all_tokens.extend(tokens[: last_slice + 1].tolist())
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
Expand All @@ -309,7 +306,6 @@ def new_segment(
result=result,
)
)
current_tokens.append(tokens.tolist())
seek += segment_size

if not condition_on_previous_text or result.temperature > 0.5:
Expand Down Expand Up @@ -348,11 +344,17 @@ def new_segment(
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
current_tokens[i] = []

all_segments.extend(current_segments)
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_tokens for token in segment]
[token for segment in current_segments for token in segment["tokens"]]
)

# update progress bar
Expand Down

0 comments on commit 38f2f4d

Please sign in to comment.