Skip to content

Commit

Permalink
support bitransformer decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Zth9730 committed Sep 21, 2022
1 parent 027535d commit d3e5937
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion paddlespeech/audio/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor,
# >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
# >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype)
max_len = paddle.max(ys_lens)
index_range = paddle.arange(0, max_len, 1)
seq_len_expand = ys_lens.unsqueeze(1)
Expand Down Expand Up @@ -279,6 +279,7 @@ def paddle_gather(x, dim, index):
# >>> tensor([[3, 2, 1],
# >>> [4, 8, 9],
# >>> [2, 2, 2]])
eos = paddle.full([1], eos, dtype=r_hyps.dtype)
r_hyps = paddle.where(seq_mask, r_hyps, eos)
# >>> r_hyps
# >>> tensor([[3, 2, 1],
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def get_dataloader(mode: str, config, args):
elif mode == 'valid':
config['manifest'] = config.dev_manifest
config['train_mode'] = False
elif model == 'test' or mode == 'align':
elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
config['dither'] = 0.0
Expand Down

0 comments on commit d3e5937

Please sign in to comment.