Commit: fix bugs
zh794390558 committed Jul 12, 2023
1 parent 03e9ea9 commit 55870ff
Showing 5 changed files with 19 additions and 14 deletions.
6 changes: 3 additions & 3 deletions examples/aishell/asr1/conf/chunk_roformer.yaml
@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
- pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+ pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
@@ -30,7 +30,7 @@ decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
- r_num_blocks: 3 # only for bitransformer
+ r_num_blocks: 0 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
@@ -39,7 +39,7 @@ decoder_conf:
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
- reverse_weight: 0.3 # only for bitransformer
+ reverse_weight: 0.0 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence

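The two value changes above follow the `# only for bitransformer` comments: this recipe uses a plain left-to-right transformer decoder, so the right-to-left options are zeroed. A minimal sketch of how a reverse weight is typically blended in U2++-style bitransformer training (illustrative only, not the PaddleSpeech loss code):

```python
def attention_loss(loss_l2r: float, loss_r2l: float, reverse_weight: float = 0.0) -> float:
    """Hedged sketch: blend left-to-right and right-to-left decoder losses.
    With reverse_weight = 0.0 (and r_num_blocks = 0) the reverse branch contributes
    nothing, which is the consistent setting for a non-bidecoder config like this one."""
    return (1.0 - reverse_weight) * loss_l2r + reverse_weight * loss_r2l
```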
2 changes: 1 addition & 1 deletion examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
@@ -18,7 +18,7 @@ encoder_conf:
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
- pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
+ pos_enc_layer_type: 'rope_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
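The same `rpoe_pos` to `rope_pos` typo fix is applied here. The spelling matters because the encoder dispatches on `pos_enc_layer_type` to choose the self-attention class (the `abs_pos` and `rel_pos` branches are visible in the encoder.py hunk below). A minimal sketch of that dispatch pattern; the `rope_pos` return value is an assumption, only the first two names come from the diff:

```python
def select_selfattention_layer(pos_enc_layer_type: str) -> str:
    """Illustrative dispatch only; not the PaddleSpeech encoder code."""
    if pos_enc_layer_type == "abs_pos":
        return "MultiHeadedAttention"
    elif pos_enc_layer_type == "rel_pos":
        return "RelPositionMultiHeadedAttention"
    elif pos_enc_layer_type == "rope_pos":
        return "RoPE self-attention"  # the branch a misspelled 'rpoe_pos' would never reach
    else:
        raise ValueError(f"unknown pos_enc_layer_type: {pos_enc_layer_type}")
```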
3 changes: 2 additions & 1 deletion paddlespeech/s2t/models/u2/u2.py
@@ -145,7 +145,6 @@ def forward(
text_lengths)
ctc_time = time.time() - start
#logger.debug(f"ctc time: {ctc_time}")

if loss_ctc is None:
loss = loss_att
elif loss_att is None:
@@ -916,6 +915,8 @@ def _init_from_config(cls, configs: dict):
decoder_type = configs.get('decoder', 'transformer')
logger.debug(f"U2 Decoder type: {decoder_type}")
if decoder_type == 'transformer':
+ configs['model_conf'].pop('reverse_weight', None)
+ configs['decoder_conf'].pop('r_num_blocks', None)
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
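The two added `pop()` calls keep a plain transformer decoder from receiving bitransformer-only options through `**configs['decoder_conf']`. A self-contained sketch of the failure they prevent; the class below is a stand-in, not the real `TransformerDecoder`:

```python
class PlainDecoder:
    """Stand-in for a decoder whose __init__ does not know about r_num_blocks."""
    def __init__(self, vocab_size: int, encoder_dim: int, num_blocks: int = 6,
                 dropout_rate: float = 0.1):
        self.vocab_size, self.encoder_dim, self.num_blocks = vocab_size, encoder_dim, num_blocks

decoder_conf = {'num_blocks': 6, 'dropout_rate': 0.1, 'r_num_blocks': 0}

# Without the fix, forwarding the raw config fails:
#   PlainDecoder(5000, 256, **decoder_conf)
#   -> TypeError: __init__() got an unexpected keyword argument 'r_num_blocks'

# With the fix, the unknown key is dropped first:
decoder_conf.pop('r_num_blocks', None)
decoder = PlainDecoder(5000, 256, **decoder_conf)
```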
13 changes: 8 additions & 5 deletions paddlespeech/s2t/modules/attention.py
@@ -16,6 +16,7 @@
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple
+ from typing import List

import paddle
from paddle import nn
@@ -418,25 +419,27 @@ def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""Apply RoPE to the given tensors.
Here sinusoidal.shape = [B, T, D] and tensors is a list of tensors with
- tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
+ tensor.shape=[B, T, ..., D], or (B,H,T,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'

# (B,H,T,D)
ndim = tensors[0].dim()
+ _,H,T,D = tensors[0].shape

# sinusoidal shape same with tensors[0]
- # [B,T,D] -> [B,T,1,D]
- sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+ # [B,T,D] -> [B,T,H,D/H] -> (B,H,T,D/H)
+ # sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)
+ sinusoidal = sinusoidal.reshape((1, T, H, D)).transpose([0, 2, 1, 3])

# http:https://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)

outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
@@ -501,7 +504,7 @@ def forward(self,
new_cache = paddle.concat((k, v), axis=-1)

# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
- q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
+ q, k = self.apply_rotary_position_embeddings(pos_emb, q, k)
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
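For reference, a self-contained NumPy sketch of the rotary-embedding step the two hunks above converge on: the sin/cos table is brought to the head-split layout `[B, H, T, D/H]`, each sin/cos value is duplicated across a rotation pair, and a rotate-half of the input supplies the second component. This mirrors the code above but is not the PaddleSpeech implementation; the function name and shapes are assumptions:

```python
import numpy as np

def apply_rope(sin_table: np.ndarray, *tensors: np.ndarray):
    """sin_table: [B, H, T, d_head] with sin at even channels and cos at odd channels,
    already reshaped/transposed as in the diff above.
    tensors: one or more arrays of shape [B, H, T, d_head], e.g. q and k."""
    # Duplicate each cos/sin value so it covers both members of a rotation pair,
    # like paddle.repeat_interleave(x, 2, axis=-1) in the original code.
    cos_pos = np.repeat(sin_table[..., 1::2], 2, axis=-1)
    sin_pos = np.repeat(sin_table[..., 0::2], 2, axis=-1)
    outputs = []
    for x in tensors:
        # Rotate-half: [x1, x2, x3, x4, ...] -> [-x2, x1, -x4, x3, ...]
        x_rot = np.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(x.shape)
        outputs.append(x * cos_pos + x_rot * sin_pos)
    return outputs[0] if len(outputs) == 1 else outputs

# The call-site fix above matters because the function takes *tensors, not a list:
#   q, k = apply_rope(pos_emb, q, k)     # unpacked, as in the fixed code
#   q, k = apply_rope(pos_emb, [q, k])   # old call: the list itself becomes one "tensor" and fails
```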
9 changes: 5 additions & 4 deletions paddlespeech/s2t/modules/encoder.py
@@ -477,9 +477,10 @@ def __init__(self,
activation = get_activation(activation_type)

# self-attention module definition
+ encoder_dim = output_size
if pos_enc_layer_type == "abs_pos":
encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (attention_heads, output_size,
+ encoder_selfattn_layer_args = (attention_heads, encoder_dim,
attention_dropout_rate)
elif pos_enc_layer_type == "rel_pos":
encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -495,16 +496,16 @@

# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (output_size, linear_units, dropout_rate,
+ positionwise_layer_args = (encoder_dim, linear_units, dropout_rate,
activation)
# convolution module definition
convolution_layer = ConvolutionModule
- convolution_layer_args = (output_size, cnn_module_kernel, activation,
+ convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
cnn_module_norm, causal)

self.encoders = nn.LayerList([
ConformerEncoderLayer(
- size=output_size,
+ size=encoder_dim,
self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
feed_forward=positionwise_layer(*positionwise_layer_args),
feed_forward_macaron=positionwise_layer(
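The remaining changes are a rename: `output_size` is aliased to `encoder_dim` and that single name is threaded through the attention, feed-forward, and convolution argument tuples, so the model width is defined in one place. A minimal sketch of the pattern with stand-in paddle layers (not the real Conformer modules):

```python
from paddle import nn

encoder_dim, attention_heads, linear_units, cnn_kernel = 256, 4, 2048, 15

# Every sub-module of the block is built from the same encoder_dim, so the
# residual connections inside a Conformer-style layer all see the same width.
self_attn = nn.MultiHeadAttention(encoder_dim, attention_heads)
feed_forward = nn.Sequential(
    nn.Linear(encoder_dim, linear_units),
    nn.Swish(),
    nn.Linear(linear_units, encoder_dim),
)
conv_module = nn.Conv1D(encoder_dim, encoder_dim, cnn_kernel, padding=cnn_kernel // 2)

block = nn.LayerList([self_attn, feed_forward, conv_module])
```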
