Skip to content

Commit

Permalink
[ASR] rm transformers import and modify variable name consistent with infer.py, test=asr (#2929)
Browse files Browse the repository at this point in the history

* rm transformers import and modify variable name consistent with infer.py

* add condition for ctc_prefix_beam_search decode.
  • Loading branch information
zxcd authored Feb 15, 2023
1 parent 71bda24 commit 004a4d6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
8 changes: 3 additions & 5 deletions paddlespeech/s2t/exps/wav2vec2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import jsonlines
import numpy as np
import paddle
import transformers
from hyperpyyaml import load_hyperpyyaml
from paddle import distributed as dist
from paddlenlp.transformers import AutoTokenizer

from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
Expand Down Expand Up @@ -530,8 +530,7 @@ def dataio_prepare(self, hparams):
datasets = [train_data, valid_data, test_data]

# Defining tokenizer and loading it
tokenizer = transformers.BertTokenizer.from_pretrained(
'bert-base-chinese')
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
self.tokenizer = tokenizer
# 2. Define audio pipeline:
@data_pipeline.takes("wav")
Expand Down Expand Up @@ -867,8 +866,7 @@ def test(self):
vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size

with jsonlines.open(
self.args.result_file, 'w', encoding='utf8') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
if self.use_sb:
metrics = self.sb_compute_metrics(**batch, fout=fout)
Expand Down
18 changes: 13 additions & 5 deletions paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from turtle import Turtle
from typing import Dict
from typing import List
from typing import Tuple
Expand Down Expand Up @@ -83,6 +84,7 @@ def decode(self,
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int,
tokenizer: str=None,
sb_pipeline=False):
batch_size = feats.shape[0]

Expand All @@ -93,12 +95,15 @@ def decode(self,
logger.error(f"current batch_size is {batch_size}")

if decoding_method == 'ctc_greedy_search':
if not sb_pipeline:
if tokenizer is None and sb_pipeline is False:
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
else:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
if sb_pipeline is True:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
else:
hyps = self.ctc_greedy_search(feats)
res = []
res_tokenids = []
for sequence in hyps:
Expand All @@ -123,13 +128,16 @@ def decode(self,
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
if not sb_pipeline:
if tokenizer is None and sb_pipeline is False:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
if sb_pipeline is True:
hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = []
res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
"paddleslim>=2.3.4",
"paddleaudio>=1.1.0",
"hyperpyyaml",
"transformers",
]

server = ["pattern_singleton", "websockets"]
Expand Down

0 comments on commit 004a4d6

Please sign in to comment.