Skip to content

Commit

Permalink
add roformer
Browse files Browse the repository at this point in the history
  • Loading branch information
zh794390558 committed Jul 12, 2023
1 parent 94987f2 commit 03e9ea9
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 12 deletions.
98 changes: 98 additions & 0 deletions examples/aishell/asr1/conf/chunk_roformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: transformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence

###########################################
# Data #
###########################################

train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test


###########################################
# Dataloader #
###########################################

vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
98 changes: 98 additions & 0 deletions examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rpoe_pos' # abs_pos, rel_pos, rope_pos
selfattention_layer_type: 'rel_selfattn' # unused
causal: true
use_dynamic_chunk: true
cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
use_dynamic_left_chunk: false
# decoder related
decoder: bitransformer # transformer, bitransformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 3
r_num_blocks: 3 # only for bitransformer
dropout_rate: 0.1 # sublayer output dropout
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
reverse_weight: 0.3 # only for bitransformer
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence

###########################################
# Data #
###########################################

train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test


###########################################
# Dataloader #
###########################################

vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: ''
unit_type: 'char'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 10.0
window_ms: 25.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 32
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1

###########################################
# Training #
###########################################
n_epoch: 240
accum_grad: 1
global_grad_clip: 5.0
dist_sampler: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
134 changes: 133 additions & 1 deletion paddlespeech/s2t/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

logger = Log(__name__).getlog()

__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
__all__ = [
"MultiHeadedAttention", "RelPositionMultiHeadedAttention",
"RoPERelPositionMultiHeadedAttention"
]

# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
Expand Down Expand Up @@ -165,6 +168,7 @@ def forward(self,
and `head * d_k == size`
"""
# (B,T,D) -> (B,T,H,D/H)
q, k, v = self.forward_qkv(query, key, value)

# when export onnx model, for 1st chunk, we feed
Expand Down Expand Up @@ -373,3 +377,131 @@ def forward(self,
self.d_k) # (batch, head, time1, time2)

return self.forward_attention(v, scores, mask), new_cache


class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with RoPE relative position encoding."""

def __init__(self,
n_head,
n_feat,
dropout_rate,
adaptive_scale=False,
init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
super().__init__(n_head, n_feat, dropout_rate)

def align(self, tensor: paddle.Tensor, axes: List[int], ndim=None):
"""重新对齐tensor(批量版expand_dims)
axes:原来的第i维对齐新tensor的第axes[i]维;
ndim:新tensor的维度。
"""
assert len(axes) == tensor.dim()
assert ndim or min(axes) >= 0

ndim = ndim or max(axes) + 1

# a[0, None, 1] = a[0, np.newaxis, 1]
indices = [None] * ndim
for i in axes:
# slice nothing, a[0, slice(None), 1] = a[0, :, 1]
indices[i] = slice(None)

return tensor[indices]

def apply_rotary_position_embeddings(self, sinusoidal, *tensors):
"""应用RoPE到tensors中
其中,sinusoidal.shape=[B, T, D],tensors为tensor的列表,而
tensor.shape=[B, T, ..., D], or (B,T,H,D/H)
"""
assert len(tensors) > 0, 'at least one input tensor'
assert all(
[tensor.shape == tensors[0].shape
for tensor in tensors[1:]]), 'all tensors must have the same shape'

ndim = tensors[0].dim()

# sinusoidal shape same with tensors[0]
# [B,T,D] -> [B,T,1,D]
sinusoidal = self.align(sinusoidal, [0, 1, -1], ndim)

# https://man.hubwiz.com/docset/TensorFlow.docset/Contents/Resources/Documents/api_docs/python/tf/keras/backend/repeat_elements.html
# like np.repeat, x (s1, s2, s3), axis 1, (s1, s2*rep, s3)
# [b,T, ..., d/2] -> [b,T, ..., d]
cos_pos = paddle.repeat_interleave(sinusoidal[..., 1::2], 2, axis=-1)
sin_pos = paddle.repeat_interleave(sinusoidal[..., 0::2], 2, axis=-1)

outputs = []
for tensor in tensors:
# x2 = [-x2, x1, -x4, x3, ..., -x_d, x_{d-1}]
tensor2 = paddle.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
tensor2 = paddle.reshape(tensor2, paddle.shape(tensor))

# 公式 34, out = x * cos_pos + x2 * sin_pos
outputs.append(tensor * cos_pos + tensor2 * sin_pos)
return outputs[0] if len(outputs) == 1 else outputs

def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
pos_emb: paddle.Tensor=paddle.empty([0]),
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
value (paddle.Tensor): Value tensor (#batch, time2, size).
mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (paddle.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (paddle.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
paddle.Tensor: Output tensor (#batch, time1, d_model).
paddle.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)

# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.shape[0] > 0:
# last dim `d_k * 2` for (key, val)
key_cache, value_cache = paddle.split(cache, 2, axis=-1)
k = paddle.concat([key_cache, k], axis=2)
v = paddle.concat([value_cache, v], axis=2)
# We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)

# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
q, k = self.apply_rotary_position_embeddings(pos_emb, [q, k])
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
7 changes: 5 additions & 2 deletions paddlespeech/s2t/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,17 @@ def __init__(self,
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate)
self.base = 10000.0
self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]

position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
# base^{-2(i-1)/d)}, i \in (1,2...,d/2)
div_term = paddle.exp(
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model))
-paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
(math.log(self.base) / self.d_model))

# [B,T,D]
self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, :, 1::2] = paddle.cos(position * div_term)

Expand Down
Loading

0 comments on commit 03e9ea9

Please sign in to comment.