Merge branch 'main' into word-level-timestamps
jongwook committed Feb 2, 2023
2 parents: 6c431c4 + 7858aa9 · commit ff6cbfd
Showing 2 changed files with 8 additions and 2 deletions.
README.md (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ A Transformer sequence-to-sequence model is trained on various speech processing
 
 ## Setup
 
-We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8 or later and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
+We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
 
     pip install -U openai-whisper
 
whisper/decoding.py (7 additions, 1 deletion)

@@ -412,7 +412,8 @@ def apply(self, logits: Tensor, tokens: Tensor):

         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
         for k in range(tokens.shape[0]):
-            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+            sampled_tokens = tokens[k, self.sample_begin :]
+            seq = [t for t in sampled_tokens.tolist()]
             last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

@@ -422,6 +423,11 @@ def apply(self, logits: Tensor, tokens: Tensor):
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf

+            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
+            if timestamps.numel() > 0:
+                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
+                logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf
+
         if tokens.shape[1] == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             logits[:, : self.tokenizer.timestamp_begin] = -np.inf
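
Alongside the existing pairing rule (timestamps must appear in pairs, except directly before EOT), the added block makes predicted timestamps monotonic: once a timestamp token has been sampled, every timestamp token strictly earlier than the most recent one is masked to -inf. Below is a minimal standalone sketch of just that rule; the toy `timestamp_begin` and vocabulary size are hypothetical stand-ins for the tokenizer's real values, and a 1-D logits row stands in for `logits[k]`:

    import numpy as np
    import torch

    # hypothetical toy values; in Whisper these come from the tokenizer
    timestamp_begin = 10  # pretend token ids >= 10 are timestamp tokens
    vocab_size = 16

    # tokens sampled so far for one sequence: text, <|t1|>, <|t2|>, text
    sampled_tokens = torch.tensor([3, 11, 12, 5])
    logits = torch.zeros(vocab_size)  # stand-in for logits[k]

    # the commit's new rule, with self.tokenizer.timestamp_begin inlined
    timestamps = sampled_tokens[sampled_tokens.ge(timestamp_begin)]
    if timestamps.numel() > 0:
        # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
        logits[timestamp_begin : timestamps[-1]] = -np.inf

    print(logits)
    # ids 10 and 11 are now -inf; id 12 and above, and all text ids
    # below timestamp_begin, remain available

Note that the slice ends at `timestamps[-1]` and therefore excludes it, so re-emitting the most recent timestamp stays legal; only strictly smaller timestamp tokens are forbidden.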
