word-level timestamps in transcribe() #869

Merged: 26 commits from the word-level-timestamps branch into main, Mar 6, 2023.

Commits

8f9357f
word-level timestamps in `transcribe()`
jongwook Jan 20, 2023
46ea501
moving to `timing.py`
jongwook Jan 21, 2023
cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook Jan 21, 2023
742d2f4
numba implementation for dtw, replacing dtw-python
jongwook Jan 22, 2023
fb12414
Merge branch 'main' into word-level-timestamps
jongwook Jan 22, 2023
80331c0
triton implementation for dtw
jongwook Jan 23, 2023
1d2ed66
add test for dtw implementations
jongwook Jan 23, 2023
b61e8f4
triton implementation of median_filter
jongwook Jan 24, 2023
54f2901
a simple word-level timestamps test
jongwook Jan 24, 2023
8ce6277
add scipy as dev dependency
jongwook Jan 24, 2023
812f446
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook Jan 24, 2023
f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
89133bd
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
d4f9399
fix broken merge
jongwook Jan 24, 2023
040aa04
Merge branch 'main' into word-level-timestamps
jongwook Jan 24, 2023
8e2756b
loosen nvcc version match regex
jongwook Jan 25, 2023
6c431c4
find_alignment() function
jongwook Jan 25, 2023
ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook Feb 2, 2023
5fa4356
miscellaneous improvements
jongwook Feb 2, 2023
48537aa
skip median filtering when the input is too small
jongwook Feb 2, 2023
8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise Feb 16, 2023
6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook Mar 6, 2023
31cd418
fix merge error
jongwook Mar 6, 2023
145f325
fix merge error 2
jongwook Mar 6, 2023
2b079c4
annotating that word_timestamps is experimental
jongwook Mar 6, 2023
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
@@ -21,6 +21,5 @@ jobs:
       - run: conda install -n test ffmpeg python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} cpuonly -c pytorch
       - uses: actions/checkout@v2
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
-      - run: pip install pytest
-      - run: pip install .
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]'
+      - run: pip install .["dev"]
+      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
+numba
 numpy
 torch
 tqdm
20 changes: 18 additions & 2 deletions setup.py
@@ -1,4 +1,5 @@
 import os
+import sys
 
 import pkg_resources
 from setuptools import setup, find_packages
@@ -9,6 +10,21 @@ def read_version(fname="whisper/version.py"):
     return locals()["__version__"]
 
 
+requirements = []
+if sys.platform.startswith("linux"):
+    triton_requirement = "triton>=2.0.0.dev20221202"
+    try:
+        import re
+        import subprocess
+        version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1]
+        major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
+        if (int(major), int(minor)) < (11, 4):
+            # the last version supporting CUDA < 11.4
+            triton_requirement = "triton==2.0.0.dev20221011"
+    except (IndexError, OSError, subprocess.SubprocessError):
+        pass
+    requirements.append(triton_requirement)
+
 setup(
     name="openai-whisper",
     py_modules=["whisper"],
@@ -22,7 +38,7 @@ def read_version(fname="whisper/version.py"):
     url="https://github.com/openai/whisper",
     license="MIT",
     packages=find_packages(exclude=["tests*"]),
-    install_requires=[
+    install_requires=requirements + [
         str(r)
         for r in pkg_resources.parse_requirements(
             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
@@ -32,5 +48,5 @@ def read_version(fname="whisper/version.py"):
         "console_scripts": ["whisper=whisper.transcribe:cli"],
     },
     include_package_data=True,
-    extras_require={"dev": ["pytest"]},
+    extras_require={"dev": ["pytest", "scipy"]},
 )
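For illustration, here is what that regex pulls out of the last line of `nvcc --version` (the sample line below is an assumed example, not output captured in this PR):

import re

version_line = b"Build cuda_11.3.r11.3/compiler.29920130_0"  # assumed sample of nvcc's last line
major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0]
assert (int(major), int(minor)) == (11, 3)  # < (11, 4), so the older Triton pin would be selected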
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,14 @@
+import random as rand
+
+import numpy
+import pytest
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "requires_cuda")
+
+
+@pytest.fixture
+def random():
+    rand.seed(42)
+    numpy.random.seed(42)
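A hedged sketch of how these hooks are consumed (the test names here are hypothetical): requesting the `random` fixture by name seeds both RNGs, and the `requires_cuda` marker is what `-m 'not requires_cuda'` deselects in the CI workflow above.

import pytest

def test_deterministic(random):
    # requesting the fixture seeds `random` and `numpy` to 42 before the test body runs
    ...

@pytest.mark.requires_cuda  # deselected on CPU-only runners via -m 'not requires_cuda'
def test_gpu_only():
    ...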
87 changes: 87 additions & 0 deletions tests/test_timing.py
@@ -0,0 +1,87 @@
+import pytest
+import numpy as np
+import scipy.ndimage
+import torch
+
+from whisper.timing import dtw_cpu, dtw_cuda, median_filter
+
+
+sizes = [
+    (10, 20), (32, 16), (123, 1500), (234, 189),
+]
+shapes = [
+    (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512),
+]
+
+
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw(N: int, M: int):

[Review comment] saunair (Feb 23, 2023): Was your reason to not use the dtw library licensing concerns or just speedup?
[Reply] dtw-python is GPL, as mentioned here: #869 (comment)

+    steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)])
+    np.random.shuffle(steps)
+    x = np.random.random((N, M)).astype(np.float32)
+
+    i, j, k = 0, 0, 0
+    trace = []
+    while True:
+        x[i, j] -= 1
+        trace.append((i, j))
+
+        if k == len(steps):
+            break
+
+        if k + 1 < len(steps) and steps[k] != steps[k + 1]:
+            i += 1
+            j += 1
+            k += 2
+            continue
+
+        if steps[k] == 0:
+            i += 1
+        if steps[k] == 1:
+            j += 1
+        k += 1
+
+    trace = np.array(trace).T
+    dtw_trace = dtw_cpu(x)
+
+    assert np.allclose(trace, dtw_trace)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("N, M", sizes)
+def test_dtw_cuda_equivalence(N: int, M: int):
+    x_numpy = np.random.randn(N, M).astype(np.float32)
+    x_cuda = torch.from_numpy(x_numpy).cuda()
+
+    trace_cpu = dtw_cpu(x_numpy)
+    trace_cuda = dtw_cuda(x_cuda)
+
+    assert np.allclose(trace_cpu, trace_cuda)
+
+
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter(shape):

[Review comment] Question: Is there a licensing issue using scipy.median_filter or is this cuda implementation just faster?

+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered = median_filter(x, filter_width)
+
+        # using np.pad to reflect-pad, because Scipy's behavior is different near the edges.
+        pad_width = filter_width // 2
+        padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect")
+        scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width])
+        scipy_filtered = scipy_filtered[..., pad_width:-pad_width]
+
+        assert np.allclose(filtered, scipy_filtered)
+
+
+@pytest.mark.requires_cuda
+@pytest.mark.parametrize("shape", shapes)
+def test_median_filter_equivalence(shape):
+    x = torch.randn(*shape)
+
+    for filter_width in [3, 5, 7, 13]:
+        filtered_cpu = median_filter(x, filter_width)
+        filtered_gpu = median_filter(x.cuda(), filter_width).cpu()
+
+        assert np.allclose(filtered_cpu, filtered_gpu)
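`test_dtw` builds a random cost matrix, picks a random monotonic path from (0, 0) to (N-1, M-1), and lowers the cost of every cell on that path by 1 so the path becomes the unique optimum; `dtw_cpu` must then recover exactly that trace. As a reference for what is being tested, here is a plain-NumPy sketch of the dynamic program and backtrace (illustrative only, not the PR's numba/triton code; tie-breaking may differ on inputs without a unique optimum):

import numpy as np

def dtw_reference(x: np.ndarray) -> np.ndarray:
    # hypothetical reference: cumulative cost over (diagonal, down, right) steps, then backtrace
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            cost[i, j] = x[i - 1, j - 1] + min(cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1])

    i, j, trace = N - 1, M - 1, []
    while True:
        trace.append((i, j))
        if i == 0 and j == 0:
            break
        # predecessors of x[i, j], expressed in `cost` coordinates: diagonal, up, left
        _, i, j = min((cost[i, j], i - 1, j - 1), (cost[i, j + 1], i - 1, j), (cost[i + 1, j], i, j - 1))
    return np.array(trace[::-1]).T  # shape (2, K): rows are the i- and j-indices along the path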
14 changes: 13 additions & 1 deletion tests/test_transcribe.py
@@ -13,10 +13,22 @@ def test_transcribe(model_name: str):
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
 
     language = "en" if model_name.endswith(".en") else None
-    result = model.transcribe(audio_path, language=language, temperature=0.0)
+    result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True)
     assert result["language"] == "en"
 
     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
     assert "your country" in transcription
     assert "do for you" in transcription
+
+    timing_checked = False
+    for segment in result["segments"]:
+        for timing in segment["words"]:
+            assert timing["start"] < timing["end"]
+            if timing["word"].strip(" ,") == "Americans":
+                assert timing["start"] <= 1.8
+                assert timing["end"] >= 1.8
+                print(timing)
+                timing_checked = True
+
+    assert timing_checked
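Based on the result structure this test exercises, downstream use of the new option looks roughly like this (a sketch; the file name is a placeholder, and the final commit annotates `word_timestamps` as experimental):

import whisper

model = whisper.load_model("tiny.en")
result = model.transcribe("speech.wav", word_timestamps=True)  # experimental option

for segment in result["segments"]:
    for timing in segment["words"]:  # each entry carries "word", "start", "end" (seconds)
        print(f'{timing["start"]:7.2f} -> {timing["end"]:7.2f} {timing["word"]}')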
22 changes: 22 additions & 0 deletions whisper/__init__.py
@@ -29,6 +29,23 @@
     "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }
 
+# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+_ALIGNMENT_HEADS = {
+    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+    "large-v2": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+    "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj',
+}
+
+
 def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     os.makedirs(root, exist_ok=True)
@@ -106,8 +123,10 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
 
     if name in _MODELS:
         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+        alignment_heads = _ALIGNMENT_HEADS[name]
     elif os.path.isfile(name):
         checkpoint_file = open(name, "rb").read() if in_memory else name
+        alignment_heads = None
     else:
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
@@ -119,4 +138,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])
 
+    if alignment_heads is not None:
+        model.set_alignment_heads(alignment_heads)
+
     return model.to(device)
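`set_alignment_heads` in whisper/model.py (further down) decodes these dumps with `np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool)`. The inverse follows directly; shown here as a hypothetical helper for producing such a dump from a boolean mask, not code from this PR:

import base64
import gzip

import numpy as np

def encode_alignment_heads(mask: np.ndarray) -> bytes:
    # mask is a boolean (n_text_layer, n_text_head) array marking the alignment heads
    return base64.b85encode(gzip.compress(mask.astype(bool).tobytes()))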
4 changes: 4 additions & 0 deletions whisper/audio.py
@@ -18,6 +18,10 @@
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
 N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
 
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 mel frames in 1s (10ms each)
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 audio tokens in 1s (20ms each)
+
 
 def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
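For concreteness, the arithmetic behind those comments, using Whisper's standard audio constants (SAMPLE_RATE and HOP_LENGTH are defined earlier in whisper/audio.py; the values here are assumed from that context, not shown in this diff):

SAMPLE_RATE = 16000  # assumed from whisper/audio.py
HOP_LENGTH = 160     # assumed from whisper/audio.py

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2             # 320 samples per audio token
assert SAMPLE_RATE // HOP_LENGTH == 100          # mel frames per second (10 ms hop)
assert SAMPLE_RATE // N_SAMPLES_PER_TOKEN == 50  # audio tokens per second (20 ms each)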
13 changes: 12 additions & 1 deletion whisper/model.py
@@ -1,3 +1,5 @@
+import base64
+import gzip
 from dataclasses import dataclass
 from typing import Dict
 from typing import Iterable, Optional
@@ -8,8 +10,8 @@
 from torch import Tensor
 from torch import nn
 
-from .transcribe import transcribe as transcribe_function
 from .decoding import detect_language as detect_language_function, decode as decode_function
+from .transcribe import transcribe as transcribe_function
 
 
 @dataclass
@@ -213,6 +215,15 @@ def __init__(self, dims: ModelDimensions):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
+        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
+        all_heads[self.dims.n_text_layer // 2:] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
     def embed_audio(self, mel: torch.Tensor):
         return self.encoder(mel)
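The default `alignment_heads` buffer enables every head in the last half of the decoder layers; `set_alignment_heads()` replaces it with a model-specific mask. A quick sketch of the default for a hypothetical 4-layer, 6-head decoder (dims assumed for illustration):

import torch

n_text_layer, n_text_head = 4, 6  # assumed tiny-sized decoder dims
all_heads = torch.zeros(n_text_layer, n_text_head, dtype=torch.bool)
all_heads[n_text_layer // 2:] = True  # layers 2 and 3: all heads marked
print(all_heads.to_sparse().indices().shape)  # torch.Size([2, 12]): 12 (layer, head) pairs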