gwelliannau/improvements & Common Voice 8
DewiBrynJones committed Jan 31, 2022
1 parent 4f7656e commit 513be4f
Showing 29 changed files with 766 additions and 215 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -0,0 +1,2 @@
models
homedir
2 changes: 2 additions & 0 deletions README.md
@@ -1,6 +1,8 @@


[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5270295.svg)](https://doi.org/10.5281/zenodo.5270295)


# docker-wav2vec2-xlsr-ft-cy

[(click here to read the README in English)](README_en.md)
3 changes: 2 additions & 1 deletion inference/.dockerignore
@@ -1,2 +1,3 @@
models
recordings
recordings
data
3 changes: 2 additions & 1 deletion inference/.gitignore
@@ -1,2 +1,3 @@
data
models
recordings
recordings
40 changes: 31 additions & 9 deletions inference/Makefile
@@ -1,32 +1,54 @@
default: build

$(eval DEVICE = cpu)
#$(eval DEVICE = gpu)

config:
# to use a local model, provide the full /models/.... path for WAV2VEC2_MODEL_NAME and
# leave MODEL_VERSION as an empty string.
$(eval WAV2VEC2_MODEL_NAME = techiaith/wav2vec2-xlsr-ft-cy)
$(eval MODEL_VERSION = 21.08)
mkdir -p ${PWD}/data/


build: config
docker build --rm -t techiaith/wav2vec2-xlsr-ft-cy \
docker build --rm -t techiaith/wav2vec2-xlsr-ft-cy-${USER} \
--build-arg WAV2VEC2_MODEL_NAME=${WAV2VEC2_MODEL_NAME} \
--build-arg MODEL_VERSION=${MODEL_VERSION} \
.

run: config run-${DEVICE}

run: config
mkdir -p ${PWD}/recordings/
docker run --name techiaith-wav2vec2-xlsr-ft-cy \
run-gpu:
docker run --gpus all --name techiaith-wav2vec2-xlsr-ft-cy-${USER} \
--restart=always \
-it \
-v ${PWD}/models/:/models \
-v ${PWD}/recordings/:/recordings \
techiaith/wav2vec2-xlsr-ft-cy
-v ${PWD}/data/:/data \
techiaith/wav2vec2-xlsr-ft-cy-${USER}

run-cpu:
docker run --name techiaith-wav2vec2-xlsr-ft-cy-${USER} \
--restart=always \
-it \
-v ${PWD}/models/:/models \
-v ${PWD}/data/:/data \
techiaith/wav2vec2-xlsr-ft-cy-${USER}


fetch-test:
if [ ! -d "data/corpws-profi-adnabod-lleferydd" ]; then \
mkdir -p data; \
cd data && git clone -b fersiwn2 --single-branch https://git.techiaith.bangor.ac.uk/data-porth-technolegau-iaith/corpws-profi-adnabod-lleferydd.git; \
fi

stop: config
-docker stop techiaith-wav2vec2-xlsr-ft-cy
-docker rm techiaith-wav2vec2-xlsr-ft-cy
-docker stop techiaith-wav2vec2-xlsr-ft-cy-${USER}
-docker rm techiaith-wav2vec2-xlsr-ft-cy-${USER}


clean: config stop
-docker rmi techiaith/wav2vec2-xlsr-ft-cy
-docker rmi techiaith/wav2vec2-xlsr-ft-cy-${USER}

purge: clean
sudo rm -rf models
97 changes: 75 additions & 22 deletions inference/python/models.py
@@ -1,11 +1,14 @@
import os
import yaml
import tarfile
import urllib.request
from urllib.parse import urlparse

from pathlib import Path
from tqdm import tqdm

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from ctcdecode import CTCBeamDecoder


class DownloadProgressBar(tqdm):
@@ -14,40 +17,90 @@ def update_to(self, b=1, bsize=1, tsize=None):
self.total = tsize
self.update(b * bsize - self.n)


def default_root_dir():
return os.path.join("/", "models")


def create(model_path, revision):

def download_file(models_root_dir, model_name, version, file_name):
# expecting model name as HuggingFace Model name e.g. techiaith/wav2vec2-xlsr-ft-cy
# expecting file name within the HuggingFace model git repository e.g. kenlm.tar.gz
cache_dir=model_path

model_file = os.path.join(models_root_dir, file_name)

if Path(model_file).is_file():
print ("Model file {} already downloaded.".format(model_file))
# initialize acoustic model...
#
if Path(model_path).is_dir():
# from a local directory containing our own trained model
print("Initiaising wav2vec2 model from local directory: %s" % model_path)
processor = Wav2Vec2Processor.from_pretrained(model_path)
model = Wav2Vec2ForCTC.from_pretrained(model_path)
else:
print ("Downloading {} version {}".format(file_name, version))
Path(models_root_dir).mkdir(parents=True, exist_ok=True)

file_url = os.path.join("https://huggingface.co", model_name, "resolve", version, file_name)
download_and_extract(file_url, model_file)
# from the HuggingFace models repository.
print("Initialising wav2vec2 model \"%s\" from HuggingFace model repository" % model_path)
cache_dir = os.path.join('/models', model_path)
processor = Wav2Vec2Processor.from_pretrained(model_path, cache_dir=cache_dir, revision=revision)
model = Wav2Vec2ForCTC.from_pretrained(model_path, cache_dir=cache_dir, revision=revision)

# initialize language model...
#
targz_file_path=os.path.join(cache_dir, "kenlm.tar.gz")
if not Path(targz_file_path).is_file():
print ("Downloading kenlm language model version {}".format(revision))
try:
file_url = os.path.join("https://huggingface.co", model_path, "resolve", revision, 'kenlm.tar.gz')
download(file_url, os.path.join(cache_dir, targz_file_path))
except Exception as e:
print (e)

if not Path(os.path.join(cache_dir, "config_ctc.yaml")).is_file():
if Path(targz_file_path).is_file():
extract(targz_file_path)

return models_root_dir
if Path(os.path.join(cache_dir, "config_ctc.yaml")).is_file():
with open(os.path.join(cache_dir, "config_ctc.yaml"), 'r') as config_file:
ctc_lm_params=yaml.load(config_file, Loader=yaml.FullLoader)

#
vocab=processor.tokenizer.convert_ids_to_tokens(range(0, processor.tokenizer.vocab_size))
space_ix = vocab.index('|')
vocab[space_ix]=' '

ctcdecoder = CTCBeamDecoder(vocab,
model_path='',
alpha=0,
beta=0,
cutoff_top_n=40,
cutoff_prob=1.0,
beam_width=100,
num_processes=4,
blank_id=processor.tokenizer.pad_token_id,
log_probs_input=True
)

def download_and_extract(file_url, output_file_path):
kenlm_ctcdecoder=None
if Path(os.path.join(cache_dir, "lm.binary")).is_file():
if ctc_lm_params:
print ("Initializing KenLM language model...")
kenlm_ctcdecoder = CTCBeamDecoder(vocab,
model_path=os.path.join(cache_dir, "lm.binary"),
alpha=ctc_lm_params['alpha'],
beta=ctc_lm_params['beta'],
cutoff_top_n=40,
cutoff_prob=1.0,
beam_width=100,
num_processes=4,
blank_id=processor.tokenizer.pad_token_id,
log_probs_input=True
)

#
return processor, model, vocab, ctcdecoder, kenlm_ctcdecoder


def download(file_url, output_file_path):
with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=file_url.split('/')[-1]) as t:
urllib.request.urlretrieve(file_url, filename=output_file_path, reporthook=t.update_to)

def extract(targz_file_path):
# extract.
if output_file_path.endswith(".tar.gz"):
if targz_file_path.endswith(".tar.gz"):
print ("Extracting...")
model_dir = Path(output_file_path).parent.absolute()
tar = tarfile.open(output_file_path, "r:gz")
model_dir = Path(targz_file_path).parent.absolute()
tar = tarfile.open(targz_file_path, "r:gz")
tar.extractall(model_dir)
tar.close()

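For reference, a minimal usage sketch of the revised create() above (not part of this commit). It assumes the snippet runs alongside inference/python/models.py, reuses the Makefile defaults (techiaith/wav2vec2-xlsr-ft-cy, revision 21.08), and reads an illustrative 16 kHz mono recording at recordings/example.wav:

# Usage sketch only: the audio path below is illustrative, not from the repository.
import torch
import librosa
import models

# Build the processor, acoustic model and both CTC decoders exactly as create() returns them.
processor, model, vocab, ctcdecoder, kenlm_ctcdecoder = models.create(
    "techiaith/wav2vec2-xlsr-ft-cy", "21.08")

# Load and resample the audio to the 16 kHz rate wav2vec2 expects.
audio, _ = librosa.load("recordings/example.wav", sr=16_000)

inputs = processor(audio, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values).logits

# Both decoders were constructed with log_probs_input=True, so pass log-probabilities.
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
decoder = kenlm_ctcdecoder if kenlm_ctcdecoder is not None else ctcdecoder
beam_results, beam_scores, timesteps, out_lens = decoder.decode(log_probs)

# Best beam of the first (only) utterance, mapped back through the vocab
# (create() already replaced '|' with a space, so this joins into plain text).
best = beam_results[0][0][:out_lens[0][0]]
print("".join(vocab[token] for token in best))
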
9 changes: 7 additions & 2 deletions inference/python/requirements.txt
@@ -3,11 +3,16 @@ numpy
wave
jiwer
webrtcvad
transformers
transformers==4.9.2
tqdm==4.61.0
pyyaml==5.4.1
torchaudio==0.7.2
torch==1.7.1
datasets
librosa
srt
praatio<5
praatio<5
pydub
pandas
python_speech_features
scipy
