Commit

Final cleaning; Modified SSL/infer.py and README for wavlm inclusion in model options
jiamingkong committed May 31, 2023
1 parent ba874db commit 8432e86
Showing 5 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion demos/speech_ssl/README.md
@@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
Arguments:
- `input`(required): Audio file to recognize.
-- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert].
+- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert, wavlm].
- `task`: Output type. Default: `asr`.
- `lang`: Model language. Default: `en`.
- `sample_rate`: Sample rate of the model. Default: `16000`.
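The `model` argument documented above is backed by an argparse `choices` list in `paddlespeech/cli/ssl/infer.py`. As a minimal standalone sketch of that pattern (the parser and the `is_accepted` helper below are hypothetical and are not the PaddleSpeech parser itself; only the option name, default, and choices are taken from this commit), the new value is accepted while any other value is rejected:

```python
import argparse

# Hypothetical stand-alone parser for illustration; only '--model', its
# default, and its choices come from the commit.
parser = argparse.ArgumentParser(prog="ssl-demo")
parser.add_argument(
    "--model",
    type=str,
    default="wav2vec2",
    choices=["wav2vec2", "hubert", "wavlm"],
    help="Choose model type of asr task.")

args = parser.parse_args(["--model", "wavlm"])
print(args.model)


def is_accepted(value):
    """Return True if argparse accepts `value` for --model (hypothetical helper)."""
    try:
        parser.parse_args(["--model", value])
        return True
    except SystemExit:
        # argparse reports an invalid choice by printing usage and exiting.
        return False
```

Because validation happens in the `choices` list, extending the CLI to a new model type requires touching both this list and the tag lookup in `_init_from_path`.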
2 changes: 1 addition & 1 deletion demos/speech_ssl/README_cn.md
@@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
参数:
- `input`(必须输入):用于识别的音频文件。
-- `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]
+- `model`:ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert, wavlm]
- `task`:输出类别,默认值:`asr`
- `lang`:模型语言,默认值:`en`
- `sample_rate`:音频采样率,默认值:`16000`
2 changes: 1 addition & 1 deletion examples/librispeech/asr5/run.sh
@@ -4,7 +4,7 @@ set -e
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;

-gpus=1,2,3
+gpus=0,1,2
stage=0
stop_stage=3
conf_path=conf/wavlmASR.yaml
8 changes: 7 additions & 1 deletion paddlespeech/cli/ssl/infer.py
@@ -52,7 +52,7 @@ def __init__(self):
             '--model',
             type=str,
             default='wav2vec2',
-            choices=['wav2vec2', 'hubert'],
+            choices=['wav2vec2', 'hubert', "wavlm"],
             help='Choose model type of asr task.')
         self.parser.add_argument(
             '--task',
@@ -157,6 +157,12 @@ def _init_from_path(self,
                 elif lang == 'zh':
                     logger.error("zh hubertASR is not supported yet")
                 tag = model_prefix + '-' + lang + '-' + sample_rate_str
+            elif model_type == 'wavlm':
+                if lang == "en":
+                    model_prefix = "wavlmASR_librispeech"
+                elif lang == "zh":
+                    logger.error("zh wavlmASR is not supported yet")
+                tag = model_prefix + '-' + lang + '-' + sample_rate_str
             else:
                 tag = model_type + '-' + lang + '-' + sample_rate_str
             self.task_resource.set_task_model(tag, version=None)
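The new `wavlm` branch builds the resource tag from a model prefix, the language, and the sample-rate string. In the lines shown, `model_prefix` is assigned only when `lang == "en"`, so unless it is set earlier in the function (outside this hunk), a `zh` request would log the error and then reach the `tag = ...` line with the name unassigned. A hedged sketch of the same lookup as a standalone function makes the unsupported case explicit; the function name `resolve_tag`, the `ValueError`, and the `'16k'` sample-rate string are assumptions for illustration and are not part of the commit:

```python
def resolve_tag(model_type, lang, sample_rate_str):
    """Build a '<prefix>-<lang>-<rate>' tag, mirroring the hunk's string layout.

    Hypothetical helper: the real code lives inline in `_init_from_path`.
    """
    if model_type == "wavlm":
        if lang == "en":
            model_prefix = "wavlmASR_librispeech"
        else:
            # Assumption: fail fast instead of logging and continuing,
            # so `model_prefix` is never read while unassigned.
            raise ValueError(f"{lang} wavlmASR is not supported yet")
        return model_prefix + "-" + lang + "-" + sample_rate_str
    # Fallback branch from the hunk: the model type itself is the prefix.
    return model_type + "-" + lang + "-" + sample_rate_str


print(resolve_tag("wavlm", "en", "16k"))
```

Raising here is one possible design choice; the commit itself follows the existing `hubert` branch and only logs the error.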
4 changes: 2 additions & 2 deletions paddlespeech/s2t/exps/wavlm/bin/test_wav.py
@@ -29,7 +29,7 @@
logger = Log(__name__).getlog()


-class Wav2vec2Infer():
+class WavLMInfer():
     def __init__(self, config, args):
         self.args = args
         self.config = config
@@ -99,7 +99,7 @@ def check(audio_file):


 def main(config, args):
-    Wav2vec2Infer(config, args).run()
+    WavLMInfer(config, args).run()


if __name__ == "__main__":
