
word-level timestamps in transcribe() #869

Merged: 26 commits, Mar 6, 2023

Conversation

jongwook (Collaborator)

No description provided.

ryanheise (Contributor) commented Jan 21, 2023

This DTW dependency introduces a licence incompatibility, although as I recall an alternative was suggested earlier in the Discussions.

Edit: Alternative library recommended in #813 (reply in thread)

segment["words"] = []

for i, (word, start, end) in enumerate(zip(words, start_times, end_times)):
if word.startswith("<|") or word.strip() in ".,!?、。":
ryanheise (Contributor) commented Jan 21, 2023

It would actually be convenient to re-insert the punctuation tokens so that concatenating all the words gives the same result as concatenating all the tokens. That would make processing easier on the consumer end. For reference, Amazon Transcribe includes timestamped punctuation tokens in its results.

Contributor:

I notice there is now a TODO comment leaning toward the first approach:

if word.startswith("<|") or word.strip() in ".,!?、。":  # TODO: expand

I'm not sure if you've already committed to that approach, but I would vote for not removing the punctuation, so that a consumer can traverse the entire content by token, by word, or by segment, and in all three cases all of the content is there (the concatenation of each result is identical). Otherwise, if I consume the results by word, I would need to simultaneously look up one of the other two results to cross-reference them and find the bits that were omitted from the sequence. Here is how Amazon Transcribe does it, for example.

On the other hand, if that is not persuasive, you might consider instead making it an option whether or not to strip out the punctuation.

(I note also that if you just left the punctuation in, the consumer would still have the ability to filter them out.)

jongwook (Collaborator, Author):

Thanks for the suggestions; I've updated so that the punctuation characters are attached to the adjacent words, while keeping the word's timestamps.
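
As a rough sketch of that behaviour (not the exact code from this PR; the punctuation sets and helper name below are just illustrative), trailing punctuation gets appended to the preceding word and leading punctuation gets prepended to the following word, while each word keeps its own timestamps:

PREPEND = "\"'“¿([{-"
APPEND = "\"'.。,，!！?？:：”)]}、"

def merge_punctuations(words, prepend=PREPEND, append=APPEND):
    merged = []
    pending = ""  # leading punctuation waiting to be glued onto the next word
    for w in words:
        text = w["word"].strip()
        if merged and text and text in append:
            merged[-1]["word"] += w["word"]   # ".", "," etc. join the previous word
        elif text and text in prepend:
            pending += w["word"]              # "¿", "(" etc. join the next word
        else:
            entry = dict(w)
            entry["word"] = pending + entry["word"]
            pending = ""
            merged.append(entry)
    return merged

# merge_punctuations([{"word": "Hello", "start": 0.0, "end": 0.4},
#                     {"word": ",", "start": 0.4, "end": 0.42},
#                     {"word": "world", "start": 0.5, "end": 0.9}])
# -> [{"word": "Hello,", "start": 0.0, "end": 0.4},
#     {"word": "world", "start": 0.5, "end": 0.9}]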

Contributor:

Thanks, that looks good. I think if the prepend_punctuations and append_punctuations parameters were propagated in transcribe() and cli() that would be quite helpful, since then I could set them to empty strings to emulate the Amazon Transcribe behaviour.
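
Assuming those parameters do get exposed in transcribe() alongside word_timestamps (as requested here), a call emulating the Amazon Transcribe behaviour might look roughly like this:

import whisper

model = whisper.load_model("small")
# Empty punctuation lists would leave punctuation marks as standalone
# timestamped entries instead of merging them into neighbouring words.
result = model.transcribe(
    "audio.mp3",
    word_timestamps=True,
    prepend_punctuations="",
    append_punctuations="",
)
for segment in result["segments"]:
    for word in segment["words"]:
        print(word["word"], word["start"], word["end"])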

KaiserChr commented Jan 21, 2023

Hi!
I tried out this branch with kwargs['word_level_timestamps'] = True, but the model performed very slowly. In addition (or rather, because of that) it started to hallucinate like mad.
I'm using short chunks (a couple of seconds) of German audio produced by a VAD for live transcription.

Maybe it's a problem on my side; can anyone else try to reproduce?

jongwook (Collaborator, Author):

Thanks for the comments, all -- this is a work in progress and not quite ready for merging. I'm trying to address both the hallucination and the performance concerns.

glangford:
Yet another DTW implementation, FYI. I can't vouch for it other than to say that it is Apache licensed, recently updated, and has both pure Python and C implementations.

https://github.com/wannesm/dtaidistance

Jeronymous left a comment

That feature of getting word-level timestamps is useful for plenty of applications. Excellent work!

EDIT
Question: Would it be possible to hook the cross-attention weights on the fly, while the model is transcribing, and avoid running inference with the decoder twice?
It would be more efficient, and might better take into account factors like the context (conditioning on previous text, etc.).
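
As a minimal sketch of what such on-the-fly hooking could look like, assuming each decoder block exposes a cross_attn module whose forward output includes the QK attention matrix as its last element (an illustration, not the code in this PR):

def install_cross_attention_hooks(model):
    # QKs[i] will receive the cross-attention weights of decoder layer i
    # every time the decoder runs, so no second decoding pass is needed.
    QKs = [None] * len(model.decoder.blocks)
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda module, inputs, outputs, i=i: QKs.__setitem__(i, outputs[-1])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]
    return QKs, hooks

# Usage: install the hooks before decoding, remove them afterwards.
#   QKs, hooks = install_cross_attention_hooks(model)
#   ...run transcription/decoding...
#   for hook in hooks:
#       hook.remove()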

Comment on lines 143 to 144
tokenizer.timestamp_begin + mel.shape[-1] // 2,
tokenizer.eot,


I don't think these last two tokens are needed to estimate word timestamps.

jongwook (Collaborator, Author):

This was a trick to nudge the DTW path to go along these tokens so that the last few words have more accurate timestamps. It's still not perfect, but I settled on using the <|no_timestamps|> token and no timestamp tokens in the recent commit.


I understand that it's important to have the attention weights used to predict the timestamp token for the end of the speech segment, but those attention weights are the ones you get when the input is the last predicted (sub)word token. I think that is enough: when the input token is the final timestamp, the decoder is already focusing on predicting the next thing.
I wonder if things are not shifted by one, because that was a problem I saw with your notebook (timestamps were assigned to the token before the one they should have been).


We are anecdotally seeing that too in our tests: the timestamps lag by a word, for example. We have no empirical proof (very anecdotal).

Contributor:

I don't see the same observations.

I did notice in earlier commits that the next token after a comma could lag, as if the comma was taking up too much time. That seems to have become more accurate in later commits.


Hmm, we could not come up with any empirical evidence either. Maybe it was the previous version.

for hook in hooks:
hook.remove()

weights = torch.cat(QKs[-6:]) # layers * heads * tokens * frames


Why consider only (at most) the last 6 layers?

jongwook (Collaborator, Author):

This was because the attention weights in the later layers were more indicative of the time alignment. I've updated this part, and now it uses a mask to select which layers and heads to find the alignment.
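
A minimal sketch of that masking idea, with made-up shapes and hand-picked example heads rather than the PR's actual mask:

import torch

n_layers, n_heads, n_tokens, n_frames = 6, 8, 10, 150
# Boolean mask over (layers, heads) marking which cross-attention heads
# are used for alignment.
alignment_heads = torch.zeros(n_layers, n_heads, dtype=torch.bool)
alignment_heads[4, 3] = alignment_heads[5, 1] = True

# QKs[l] is assumed to hold the captured weights of layer l: (heads, tokens, frames)
QKs = [torch.randn(n_heads, n_tokens, n_frames) for _ in range(n_layers)]

# Stack only the selected heads; shape: (selected_heads, tokens, frames)
weights = torch.stack([QKs[l][h] for l, h in alignment_heads.nonzero()])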


OK, interesting. I'll have to check this masking trick.
I don't understand why the later layers are more indicative. Is it an intuition that I am missing, or some empirical result you got from experiments?

continue

segment = segments[token_sources[word_boundaries[i]]]
segment["words"].append(dict(word=word, start=round(start, 2), end=round(end, 2)))


Would it be possible to add a confidence score based on the average log probability of each word?
This could be a useful feature, available with very little additional computation.
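
As a hypothetical illustration of the suggestion (not the code that was later added), a word-level confidence could be the exponential of the mean log-probability of the word's tokens:

import math

def word_probability(token_logprobs):
    """Mean per-token log-probability of a word, mapped back to a probability."""
    mean_logprob = sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_logprob)

# e.g. a word split into two tokens with log-probs -0.1 and -0.3
print(round(word_probability([-0.1, -0.3]), 3))  # ~0.819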

jongwook (Collaborator, Author):

Great point! Added in 5fa4356


Awesome 👍

Comment on lines +64 to +65
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]


Maybe one of these two is not needed, as it doesn't really make sense to attribute the same timestamp to several tokens.
Well... I'm not fully sure. Maybe it's useful when a lot of text has to be aligned with a small portion of audio (which can happen when Whisper's "inner language model" gets stuck).

jongwook (Collaborator, Author):

Yeah, it totally makes sense to force a token to have at least one timestamp, which is only about 20 milliseconds. I left this as-is to handle some failure cases like repetition looping, as you mentioned; in the post-processing, zero-length segments are removed, and those usually came from generation that got stuck in repetition looping.
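
For reference, a self-contained sketch of the recurrence and backtrace being discussed; it mirrors the idea of three allowed moves per cell rather than the PR's exact implementation:

import numpy as np

def dtw(x):
    """x: (N tokens, M frames) cost matrix; returns (token, frame) index pairs."""
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf)
    trace = -np.ones((N + 1, M + 1), dtype=np.int8)
    cost[0, 0] = 0.0
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            c0 = cost[i - 1, j - 1]  # advance both token and frame
            c1 = cost[i - 1, j]      # advance token only: several tokens share a frame
            c2 = cost[i, j - 1]      # advance frame only: a token spans several frames
            c, t = min((c0, 0), (c1, 1), (c2, 2))
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t
    # Backtrack from the bottom-right corner to recover the alignment path.
    i, j, path = N, M, []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i, j = i - 1, j - 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(path[::-1])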


We have seen words repeating with the same start and end timestamps, while the segments were fine. I think we need to apply the segment-level code where you handle duplication (repetition looping) to words too.



@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
Contributor:

triton jit requires gcc/clang, python-dev, and cuda-dev at runtime. Please consider some lighter-weight alternatives.


+1.

jongwook (Collaborator, Author):

Thanks for the heads up; I made it fall back to the PyTorch and numba implementations if Triton fails with RuntimeError or subprocess.CalledProcessError. I haven't tested this in a non-dev environment, so please feel free to ping me if the fallback does not work for any reason.
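
A rough sketch of that fallback pattern (the import path here is hypothetical, purely for illustration):

import subprocess

def dtw_with_fallback(x):
    try:
        from whisper.triton_ops import dtw_cuda  # hypothetical import for the Triton kernel
        return dtw_cuda(x)
    except (ImportError, RuntimeError, subprocess.CalledProcessError):
        # Triton needs a compiler toolchain and CUDA headers at runtime; when any
        # of that is missing, fall back to the CPU (numba / PyTorch) implementation.
        return dtw(x)  # e.g. the pure-NumPy version sketched earlier in this thread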


+1. Thank you.

hojinYang commented Feb 11, 2023

Hi, thanks for the great work!

I would like to ask whether it is safe to swap in a smaller model (e.g. tiny) for word-level alignment to compute the attention scores, instead of using the same model (e.g. medium or large) that was used to generate the transcription. I suspect this could improve performance in terms of inference speed if the option were supported.


# heads * tokens * frames
weights = torch.stack([QKs[l][h] for l, h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
saunair commented Feb 23, 2023

Question: Why get weights only for the first num_frames/2 frames?



@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
saunair commented Feb 23, 2023

Was your reason for not using the dtw library licensing concerns, or just speed?


dtw-python is GPL, as mentioned here -
#869 (comment)



@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):

Question: Is there a licensing issue using scipy.median_filter or is this cuda implementation just faster?

ryanheise (Contributor):

I found an interesting edge case with the small model where enabling the word-level timestamps option causes it to repeat the prompt at the end of the audio while also failing to infer the last word.

$ ffmpeg -t 29 -i https://audio2.redcircle.com/episodes/6b196013-8672-43d9-be52-4332b3207d93/stream.mp3 test.mp3

$ whisper --model small test.mp3
.../whisper/transcribe.py:98: UserWarning: FP16 is not supported on CPU; using FP32 instead
  warnings.warn("FP16 is not supported on CPU; using FP32 instead")
Detecting language using up to the first 30 seconds. Use `--language` to specify the language
Detected language: English
[00:00.000 --> 00:15.920]  Military veteran Eric Weinstein began 69 Whiskey as a college radio show on 107.7 The
[00:15.920 --> 00:21.720]  Bronx, located on the campus of Ryder University in Lawrenceville, New Jersey.
[00:21.720 --> 00:27.560]  A show once restrained by rules and boundaries now comes straight to you raw, uncensored and
[00:27.560 --> 00:28.960]  unapologetic.

$ whisper --model small --output_format json --word_timestamps True test.mp3
.../whisper/transcribe.py:98: UserWarning: FP16 is not supported on CPU; using FP32 instead
  warnings.warn("FP16 is not supported on CPU; using FP32 instead")
Detecting language using up to the first 30 seconds. Use `--language` to specify the language
Detected language: English
[00:08.040 --> 00:15.940]  Military veteran Eric Weinstein began 69 Whiskey as a college radio show on 107.7 The
[00:15.940 --> 00:21.320]  Bronx, located on the campus of Ryder University in Lawrenceville, New Jersey.
[00:21.720 --> 00:28.980]  A show once restrained by rules and boundaries now comes straight to you raw, uncensored and
[00:28.960 --> 00:28.960]  Military veteran Eric Weinstein began 69 Whiskey as a college radio show on 107.7 The
[00:28.960 --> 00:28.960]  Bronx, located on the campus of Ryder University in Lawrenceville, New Jersey.
[00:28.960 --> 00:28.960]  A show once restrained by rules and boundaries now comes straight to you raw, uncensored and
[00:28.960 --> 00:28.960]  Military veteran Eric Weinstein began 69 Whiskey as a college radio show on 107.7 The
[00:28.960 --> 00:28.960]  Bronx, located on the campus of Ryder University in Lawrenceville, New Jersey.

IgnacioSan22:
Hi @jongwook,
Since you first released the notebook for obtaining word-level timestamps I've been working on adding this to the whisper process, and I've tried to test other alignment methods than DTW. Have you tried something else and found that it works better?

Also, I've been struggling a lot with hallucinations, especially for Spanish content. I've created a cleaner function at segment level; is there any smarter way?

ioskevinshah:
Is there any chance of having word-level timestamps in the Whisper API?

jongwook (Collaborator, Author) commented Mar 6, 2023

Hi @IgnacioSan22, the custom DTW implementation in this PR was for the license issue, as noted by others, and also for speed. An alternative is to use the timestamp predictions from the model, but we found that to be less reliable than using the attention patterns as in this PR. If you have solutions using any other alignment algorithms, please let me know!

The community had some success handling hallucinations by preprocessing the inputs with VAD, like:


Hi @ioskevinshah, this feature is still experimental but we do plan to add it to the API as an option, once we're sure that it's reliable enough.

jongwook merged commit 500d0fe into main on Mar 6, 2023
jongwook deleted the word-level-timestamps branch on March 6, 2023 22:01
JeffreyWardman:
@jongwook is there a way to access it via a beta flag for instance? How can we know when something is/isn't added to the API?

jongwook (Collaborator, Author) commented Mar 8, 2023

For the API, the speech-to-text guide and the audio API reference provide the full documentation of the available features. These documents will be updated accordingly as we roll out new features.

IgnacioSan22:

Hi @jongwook, I've tried the Hungarian algorithm and in some cases the results are better; however, due to a lack of resources I'm not able to perform a proper study to find the best alignment algorithm. For hallucinations I've developed a post-processing function that cleans the segments. It improves things quite a lot, but I'll check those references.

Thanks

glinft commented Mar 15, 2023


One more question: when will this new feature be rolled out?

ioskevinshah:

Any workaround or logic after the API response?

zackees pushed a commit to zackees/whisper that referenced this pull request May 5, 2023
* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (openai#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <[email protected]>
ilanit1997 pushed a commit to ilanit1997/whisper that referenced this pull request May 16, 2023
samuelbradshaw:
This is awesome! Is there a way to pass in pre-transcribed text that whisper can use for more accurate alignment?

abyesilyurt pushed a commit to abyesilyurt/whisper that referenced this pull request Nov 13, 2023