diff --git a/examples/aishell/asr1/conf/chunk_roformer.yaml b/examples/aishell/asr1/conf/chunk_roformer.yaml
index 1b752f87db..a4051a021c 100644
--- a/examples/aishell/asr1/conf/chunk_roformer.yaml
+++ b/examples/aishell/asr1/conf/chunk_roformer.yaml
@@ -18,7 +18,7 @@ encoder_conf:
     cnn_module_kernel: 15
     use_cnn_module: True
     activation_type: 'swish'
-    pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
     selfattention_layer_type: 'rel_selfattn' # unused
     causal: true
     use_dynamic_chunk: true
@@ -30,7 +30,7 @@ decoder_conf:
     attention_heads: 4
     linear_units: 2048
     num_blocks: 6
-    r_num_blocks: 3 # only for bitransformer
+    r_num_blocks: 0 # only for bitransformer
     dropout_rate: 0.1 # sublayer output dropout
     positional_dropout_rate: 0.1
     self_attention_dropout_rate: 0.0
@@ -39,7 +39,7 @@ decoder_conf:
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1 # label smoothing option
-    reverse_weight: 0.3 # only for bitransformer
+    reverse_weight: 0.0 # only for bitransformer
     length_normalized_loss: false
     init_type: 'kaiming_uniform' # !Warning: need to convergence
 
diff --git a/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
index 8bf81fa078..aa3a0aca76 100644
--- a/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
+++ b/examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
@@ -18,7 +18,7 @@ encoder_conf:
     cnn_module_kernel: 15
     use_cnn_module: True
     activation_type: 'swish'
-    pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+    pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
     selfattention_layer_type: 'rel_selfattn' # unused
     causal: true
     use_dynamic_chunk: true
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index f716fa3b57..2e1c14ac10 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -145,7 +145,6 @@ def forward(
                                 text_lengths)
             ctc_time = time.time() - start
             #logger.debug(f"ctc time: {ctc_time}")
-
         if loss_ctc is None:
             loss = loss_att
         elif loss_att is None:
@@ -916,6 +915,8 @@ def _init_from_config(cls, configs: dict):
         decoder_type = configs.get('decoder', 'transformer')
         logger.debug(f"U2 Decoder type: {decoder_type}")
         if decoder_type == 'transformer':
+            configs['model_conf'].pop('reverse_weight', None)
+            configs['decoder_conf'].pop('r_num_blocks', None)
             decoder = TransformerDecoder(vocab_size,
                                          encoder.output_size(),
                                          **configs['decoder_conf'])
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 386977cd2d..2ab931d64c 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -16,6 +16,7 @@
 """Multi-Head Attention layer definition."""
 import math
 from typing import Tuple
+from typing import List
 
 import paddle
 from paddle import nn
@@ -418,25 +419,27 @@ def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
     def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
         """应用RoPE到tensors中
         其中,sinusoidal.shape=[B, T, D],tensors为tensor的列表,而
-        tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
+        tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
         """
         assert len(tensors) > 0, 'at least one input tensor'
         assert all(
             [tensor.shape == tensors[0].shape
              for tensor in tensors[1:]]), 'all tensors must have the same shape'
+        # (B,H,T,D)
         ndim = tensors[0].dim()
+        _,H,T,D = tensors[0].shape
 
         # sinusoidal shape same with tensors[0]
-        # [B,T,D] -> [B,T,1,D]
-        sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+        # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+        sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])
 
         # http://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
         # like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
         # [b,T, ..., d/2] -> [b,T, ..., d]
         cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
         sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
-
         outputs = []
         for tensor in tensors:
             # x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
@@ -501,7 +504,7 @@ def forward(self,
         new_cache = paddle.concat((k, v), axis=-1)
 
         # f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
-        q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
+        q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
         # dot(q, k)
         scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask), new_cache
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index 2c3b8c39f3..91247d9779 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -477,9 +477,10 @@ def __init__(self,
         activation = get_activation(activation_type)
 
         # self-attention module definition
+        encoder_dim = output_size
         if pos_enc_layer_type == "abs_pos":
             encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (attention_heads, output_size,
+            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                            attention_dropout_rate)
         elif pos_enc_layer_type == "rel_pos":
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -495,16 +496,16 @@ def __init__(self,
 
         # feed-forward module definition
         positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (output_size, linear_units, dropout_rate,
+        positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
                                    activation)
         # convolution module definition
         convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
 
         self.encoders = nn.LayerList([
             ConformerEncoderLayer(
-                size=output_size,
+                size=encoder_dim,
                 self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
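For reviewers of the attention.py hunks above, here is a minimal standalone sketch of the RoPE rotation the patched method applies to the (B, H, T, D/H) query/key tensors: the interleaved sin/cos table is expanded with repeat_interleave, a rotate-half tensor [-x2, x1, -x4, x3, ...] is built, and the two are combined. This is an illustration only, not code from the patch: `rope_table` and `rope_rotate` are hypothetical helper names, the toy sizes are made up, and for simplicity one D/H-dimensional table is broadcast to every head, whereas the patched method reshapes a model-dimension table across heads.

```python
import math

import paddle


def rope_table(T, Dh):
    """Interleaved RoPE table of shape (1, 1, T, Dh): sin at even indices,
    cos at odd indices, matching the 0::2 / 1::2 convention in the patch."""
    pos = paddle.arange(T, dtype='float32').unsqueeze(-1)                      # (T, 1)
    inv_freq = paddle.exp(
        -math.log(10000.0) * paddle.arange(0, Dh, 2, dtype='float32') / Dh)    # (Dh/2,)
    angle = pos * inv_freq                                                     # (T, Dh/2)
    table = paddle.stack([paddle.sin(angle), paddle.cos(angle)], axis=-1)      # (T, Dh/2, 2)
    return table.reshape([1, 1, T, Dh])


def rope_rotate(x, sinusoidal):
    """Rotate a (B, H, T, Dh) query or key tensor by the RoPE table."""
    # duplicate each sin/cos value so it covers one (even, odd) feature pair
    cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
    sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)
    # rotate-half: (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...)
    x2 = paddle.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)
    return x * cos_pos + x2 * sin_pos


# toy sizes for illustration; the real ones come from encoder_conf in the yaml
B, H, T, Dh = 2, 4, 16, 64
q, k = paddle.randn([B, H, T, Dh]), paddle.randn([B, H, T, Dh])
table = rope_table(T, Dh)
q, k = rope_rotate(q, table), rope_rotate(k, table)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(Dh)  # (B, H, T, T)
```

Because the same rotation is applied to q and k before the dot product, the resulting attention scores depend only on the relative offset between positions, which is the property the rope_pos encoder relies on.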