Skip to content

Commit

Permalink
cli add ds2-librispeech offline, fix version, test=asr
Browse files: browse the repository at this point in the history
  • Loading branch information
Jackwaterveg committed Jan 27, 2022
1 parent 4128f4d commit 2a42421
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
35 changes: 23 additions & 12 deletions paddlespeech/cli/asr/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
}

model_alias = {
Expand Down Expand Up @@ -328,18 +342,15 @@ def infer(self, model_type: str):
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
self.text_feature.vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
decode_batch_size = audio.shape[0]
self.model.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)

result_transcripts = self.model.decode(audio, audio_len)
self.model.decoder.del_decoder()
self._outputs["result"] = result_transcripts[0]

elif "conformer" in model_type or "transformer" in model_type:
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
try:
from .. import __version__
except ImportError:
__version__ = 0.0.0 # for develop branch
__version__ = "0.0.0" # for develop branch

requests.adapters.DEFAULT_RETRIES = 3

Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/t2s/modules/transformer/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def repeat(N, fn):
MultiSequential
Repeated model instance.
"""
return MultiSequential(* [fn(n) for n in range(N)])
return MultiSequential(*[fn(n) for n in range(N)])

0 comments on commit 2a42421

Please sign in to comment.