Skip to content

Commit

Permalink
word-level timestamps in transcribe() (openai#869)
Browse files Browse the repository at this point in the history
* word-level timestamps in `transcribe()`

* moving to `timing.py`

* numba implementation for dtw, replacing dtw-python

* triton implementation for dtw

* add test for dtw implementations

* triton implementation of median_filter

* a simple word-level timestamps test

* add scipy as dev dependency

* installs an older version of Triton if CUDA < 11.4

* fix broken merge

* loosen nvcc version match regex

* find_alignment() function

* miscellaneous improvements

* skip median filtering when the input is too small

* Expose punctuation options in cli and transcribe() (openai#973)

* fix merge error

* fix merge error 2

* annotating that word_timestamps is experimental

---------

Co-authored-by: ryanheise <[email protected]>
  • Loading branch information
2 people authored and abyesilyurt committed Nov 13, 2023
1 parent 19874b7 commit 04ecf97
Show file tree
Hide file tree
Showing 14 changed files with 802 additions and 88 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@ jobs:
- run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
- uses: actions/checkout@v2
- run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
- run: pip install pytest
- run: pip install .
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
- run: pip install .["dev"]
- run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numba
numpy
torch
tqdm
Expand Down
20 changes: 18 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys

import pkg_resources
from setuptools import setup, find_packages
Expand All @@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
return locals()["__version__"]


requirements = []
if sys.platform.startswith("linux"):
triton_requirement = "triton>=2.0.0.dev20221202"
try:
import re
import subprocess
version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
if (int(major), int(minor)) < (11, 4):
# the last version supporting CUDA < 11.4
triton_requirement = "triton==2.0.0.dev20221011"
except (IndexError, OSError, subprocess.SubprocessError):
pass
requirements.append(triton_requirement)

setup(
name="openai-whisper",
py_modules=["whisper"],
Expand All @@ -22,7 +38,7 @@ def read_version(fname="whisper/version.py"):
url="https://github.com/openai/whisper",
license="MIT",
packages=find_packages(exclude=["tests*"]),
install_requires=[
install_requires=requirements + [
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
Expand All @@ -32,5 +48,5 @@ def read_version(fname="whisper/version.py"):
"console_scripts": ["whisper=whisper.transcribe:cli"],
},
include_package_data=True,
extras_require={"dev": ["pytest"]},
extras_require={"dev": ["pytest", "scipy"]},
)
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import random as rand

import numpy
import pytest


def pytest_configure(config):
config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
rand.seed(42)
numpy.random.seed(42)
87 changes: 87 additions & 0 deletions tests/test_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
import numpy as np
import scipy.ndimage
import torch

from whisper.timing import dtw_cpu, dtw_cuda, median_filter


sizes = [
(10, 20), (32, 16), (123, 1500), (234, 189),
]
shapes = [
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
]


@pytest.mark.parametrize("N, M", sizes)
def test_dtw(N: int, M: int):
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
np.random.shuffle(steps)
x = np.random.random((N, M)).astype(np.float32)

i, j, k = 0, 0, 0
trace = []
while True:
x[i, j] -= 1
trace.append((i, j))

if k == len(steps):
break

if k + 1 < len(steps) and steps[k] != steps[k + 1]:
i += 1
j += 1
k += 2
continue

if steps[k] == 0:
i += 1
if steps[k] == 1:
j += 1
k += 1

trace = np.array(trace).T
dtw_trace = dtw_cpu(x)

assert np.allclose(trace, dtw_trace)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("N, M", sizes)
def test_dtw_cuda_equivalence(N: int, M: int):
x_numpy = np.random.randn(N, M).astype(np.float32)
x_cuda = torch.from_numpy(x_numpy).cuda()

trace_cpu = dtw_cpu(x_numpy)
trace_cuda = dtw_cuda(x_cuda)

assert np.allclose(trace_cpu, trace_cuda)


@pytest.mark.parametrize("shape", shapes)
def test_median_filter(shape):
x = torch.randn(*shape)

for filter_width in [3, 5, 7, 13]:
filtered = median_filter(x, filter_width)

# using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
pad_width = filter_width // 2
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
scipy_filtered = scipy_filtered[..., pad_width:-pad_width]

assert np.allclose(filtered, scipy_filtered)


@pytest.mark.requires_cuda
@pytest.mark.parametrize("shape", shapes)
def test_median_filter_equivalence(shape):
x = torch.randn(*shape)

for filter_width in [3, 5, 7, 13]:
filtered_cpu = median_filter(x, filter_width)
filtered_gpu = median_filter(x.cuda(), filter_width).cpu()

assert np.allclose(filtered_cpu, filtered_gpu)
15 changes: 14 additions & 1 deletion tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def test_transcribe():
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

language = "en" if model_name.endswith(".en") else None
result = model.transcribe(audio_path, language=language, temperature=0.0)
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
)
assert result["language"] == "en"

transcription = result["text"].lower()
Expand Down Expand Up @@ -43,3 +45,14 @@ def test_transcribe_callback():
assert "my fellow americans" in transcription
assert "your country" in transcription
assert "do for you" in transcription
timing_checked = False
for segment in result["segments"]:
for timing in segment["words"]:
assert timing["start"] < timing["end"]
if timing["word"].strip(" ,") == "Americans":
assert timing["start"] <= 1.8
assert timing["end"] >= 1.8
print(timing)
timing_checked = True

assert timing_checked
22 changes: 22 additions & 0 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
"large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
}



def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
Expand Down Expand Up @@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

Expand All @@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
4 changes: 4 additions & 0 deletions whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each)
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each)


def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Expand Down
13 changes: 12 additions & 1 deletion whisper/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
Expand All @@ -8,8 +10,8 @@
from torch import Tensor
from torch import nn

from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
from .transcribe import transcribe as transcribe_function


@dataclass
Expand Down Expand Up @@ -213,6 +215,15 @@ def __init__(self, dims: ModelDimensions):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2:] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
Expand Down
Loading

0 comments on commit 04ecf97

Please sign in to comment.