From 14eb31c7ee09a92fc06d59ad075aa602237b7d50 Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Tue, 26 Nov 2019 17:09:22 +0800
Subject: [PATCH 1/9] init qk attention

---
 src/operator/contrib/transformer.cc | 94 +++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index 2ca6f8c71093..a863968e31d8 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -122,6 +122,99 @@ static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
   return true;
 }
 
+void strided_batch_sgemm(bool transA, bool transB,
+                         index_t m, index_t n, index_t k,
+                         float alpha, const float *a, index_t lda,
+                         index_t strideA, const float *b, index_t ldb,
+                         index_t strideB, float beta, float *c, index_t ldc,
+                         index_t strideC, int32_t batchCount) {
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+  const int GROUP_SIZE = 1;
+  MKL_INT p_m[GROUP_SIZE] = {m};
+  MKL_INT p_n[GROUP_SIZE] = {n};
+  MKL_INT p_k[GROUP_SIZE] = {k};
+  MKL_INT p_lda[GROUP_SIZE] = {lda};
+  MKL_INT p_ldb[GROUP_SIZE] = {ldb};
+  MKL_INT p_ldc[GROUP_SIZE] = {ldc};
+
+  float p_alpha[GROUP_SIZE] = {alpha};
+  float p_beta[GROUP_SIZE] = {beta};
+
+  CBLAS_TRANSPOSE cblas_a_trans = transA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE cblas_b_trans = transB ? CblasTrans : CblasNoTrans;
+
+  MKL_INT p_group_sizeb[GROUP_SIZE] = {batchCount};
+  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
+  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
+
+  std::vector<const float*> pp_A(batchCount, nullptr);
+  std::vector<const float*> pp_B(batchCount, nullptr);
+  std::vector<float*> pp_C(batchCount, nullptr);
+
+  for (int i = 0; i < batchCount; i++) {
+    pp_A[i] = a + i * strideA;
+    pp_B[i] = b + i * strideB;
+    pp_C[i] = c + i * strideC;
+  }
+
+  cblas_sgemm_batch(CblasRowMajor, p_transa, p_transb,
+                    p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
+                    p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
+#else
+  for (int i = 0; i < batchCount; ++i) {
+    cblas_sgemm(CblasRowMajor, transa, transb, m, n, k, alpha,
+                pp_A[i], lda, pp_B[i], ldb, beta, pp_C[i], ldc);
+  }
+#endif
+}
+
+void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext &ctx,
+                                   const std::vector<TBlob> &inputs,
+                                   const std::vector<OpReqType> &req,
+                                   const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+
+  const index_t qkv_seq_len = inputs[0].shape_[0];
+  const index_t sequences = inputs[0].shape_[1];
+  const index_t output_lin_dim = inputs[0].shape_[2];
+  const index_t embed_dim = output_lin_dim / 3;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim = attn_batches * 3 * head_dim;
+  const index_t batch_stride = 3 * head_dim;
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+
+  strided_batch_sgemm(false,
+                      true,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      head_dim,
+                      scale,
+                      queries_keys_values + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      queries_keys_values,
+                      lead_dim,
+                      batch_stride,
+                      beta,
+                      output,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      attn_batches);
+}
+
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
 .describe(R"code(Compute the matrix multiplication between the projections of
 queries and keys in multihead attention used as self attention.
@@ -152,6 +245,7 @@ This Op is GPU only
 })
 .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulSelfAttQKCPU)
 .set_attr<nnvm::FGradient>("FGradient",
     ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"})
 .add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
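Note on the forward kernel above: `interleaved_matmul_selfatt_qk` consumes a single tensor of
interleaved query/key/value projections of shape (seq_length, batch, heads * head_dim * 3) and
produces scale * Q.K^T for each of the heads * batch attention batches. A rough NumPy rendering
of that contract, following the operator docstring (an illustration only, not part of the patch;
the helper name is hypothetical):

    import numpy as np

    def selfatt_qk_reference(qkv, heads):
        # qkv: (seq_len, batch, heads * head_dim * 3), q/k/v interleaved per head
        seq_len, batch, lin_dim = qkv.shape
        head_dim = lin_dim // (3 * heads)
        tmp = qkv.reshape(seq_len, batch, heads, 3, head_dim)
        q = tmp[:, :, :, 0, :].transpose(1, 2, 0, 3).reshape(-1, seq_len, head_dim)
        k = tmp[:, :, :, 1, :].transpose(1, 2, 0, 3).reshape(-1, seq_len, head_dim)
        scale = 1.0 / np.sqrt(head_dim)
        # output: (batch * heads, seq_len, seq_len)
        return scale * np.einsum('bij,bkj->bik', q, k)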
From 29c0623869d7f44e67a5fd647b6808b48d45f3fa Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Sun, 1 Dec 2019 05:59:27 +0800
Subject: [PATCH 2/9] qk: fake backward

---
 src/operator/contrib/transformer.cc | 39 +++++++++++++++++++----------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index a863968e31d8..a608570e5735 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -128,6 +128,17 @@ void strided_batch_sgemm(bool transA, bool transB,
                          index_t strideA, const float *b, index_t ldb,
                          index_t strideB, float beta, float *c, index_t ldc,
                          index_t strideC, int32_t batchCount) {
+
+  std::vector<const float*> pp_A(batchCount, nullptr);
+  std::vector<const float*> pp_B(batchCount, nullptr);
+  std::vector<float*> pp_C(batchCount, nullptr);
+
+  for (int i = 0; i < batchCount; i++) {
+    pp_A[i] = a + i * strideA;
+    pp_B[i] = b + i * strideB;
+    pp_C[i] = c + i * strideC;
+  }
+
 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
   const int GROUP_SIZE = 1;
   MKL_INT p_m[GROUP_SIZE] = {m};
@@ -147,23 +158,17 @@ void strided_batch_sgemm(bool transA, bool transB,
   CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
   CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
 
-  std::vector<const float*> pp_A(batchCount, nullptr);
-  std::vector<const float*> pp_B(batchCount, nullptr);
-  std::vector<float*> pp_C(batchCount, nullptr);
-
-  for (int i = 0; i < batchCount; i++) {
-    pp_A[i] = a + i * strideA;
-    pp_B[i] = b + i * strideB;
-    pp_C[i] = c + i * strideC;
-  }
-
   cblas_sgemm_batch(CblasRowMajor, p_transa, p_transb,
                     p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
                     p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
 #else
-  for (int i = 0; i < batchCount; ++i) {
-    cblas_sgemm(CblasRowMajor, transa, transb, m, n, k, alpha,
-                pp_A[i], lda, pp_B[i], ldb, beta, pp_C[i], ldc);
-  }
+    for (int i = 0; i < batchCount; ++i) {
+      cblas_sgemm(CblasRowMajor,
+                  transA ? CblasTrans : CblasNoTrans,
+                  transB ? CblasTrans : CblasNoTrans,
+                  m, n, k,
+                  alpha, pp_A[i], lda,
+                  pp_B[i], ldb, beta, pp_C[i], ldc);
+    }
 #endif
 }
@@ -215,6 +220,13 @@ void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
                       attn_batches);
 }
 
+void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
+                                           const OpContext &ctx,
+                                           const std::vector<TBlob> &inputs,
+                                           const std::vector<OpReqType> &req,
+                                           const std::vector<TBlob> &outputs) {
+}
+
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
 .describe(R"code(Compute the matrix multiplication between the projections of
 queries and keys in multihead attention used as self attention.
@@ -255,7 +267,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
 .set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttQKCPU);
 
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt)
 .describe(R"code(Compute the matrix multiplication between the projections of
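Note: with the pointer vectors hoisted out of the MKL branch, both branches of
`strided_batch_sgemm` now share the same setup, and the non-MKL fallback (also fixed here to
use the actual transA/transB parameters) is simply a loop of independent SGEMMs: batch i
multiplies the matrix at `a + i*strideA` with the one at `b + i*strideB` into `c + i*strideC`.
In NumPy terms the routine computes, per batch (a sketch under those assumptions that ignores
leading-dimension padding; the helper name is hypothetical):

    import numpy as np

    def strided_batch_gemm_ref(A, B, C, alpha=1.0, beta=0.0,
                               trans_a=False, trans_b=False):
        # A, B, C: 3-D arrays; axis 0 plays the role of the batch stride.
        for i in range(C.shape[0]):
            a = A[i].T if trans_a else A[i]
            b = B[i].T if trans_b else B[i]
            # Same contract as each cblas_sgemm call in the fallback loop.
            C[i] = alpha * (a @ b) + beta * C[i]
        return C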
From 795de79bfb1ee73e69c61ef97b9ef40c00a70ffb Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Mon, 2 Dec 2019 22:46:55 +0800
Subject: [PATCH 3/9] cpu selfatt and encdec; move tests to test_operator.py

---
 src/operator/contrib/transformer.cc    | 449 ++++++++++++++++++++++++-
 tests/python/unittest/test_operator.py | 324 ++++++++++++++++++
 2 files changed, 758 insertions(+), 15 deletions(-)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index a608570e5735..b135a80afcc6 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -158,12 +158,12 @@ void strided_batch_sgemm(bool transA, bool transB,
   CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
   CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
 
-  cblas_sgemm_batch(CblasRowMajor, p_transa, p_transb,
+  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
                     p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
                     p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
 #else
     for (int i = 0; i < batchCount; ++i) {
-      cblas_sgemm(CblasRowMajor,
+      cblas_sgemm(CblasColMajor,
                   transA ? CblasTrans : CblasNoTrans,
                   transB ? CblasTrans : CblasNoTrans,
                   m, n, k,
@@ -201,8 +201,8 @@ void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
   const float beta = req[0] == kAddTo ? 1.f : 0.f;
   const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
 
-  strided_batch_sgemm(false,
-                      true,
+  strided_batch_sgemm(true,
+                      false,
                       qkv_seq_len,
                       qkv_seq_len,
                       head_dim,
                       scale,
                       queries_keys_values + head_dim,
                       lead_dim,
                       batch_stride,
                       queries_keys_values,
                       lead_dim,
                       batch_stride,
                       beta,
                       output,
                       qkv_seq_len,
                       qkv_seq_len * qkv_seq_len,
                       attn_batches);
 }
@@ -225,6 +225,427 @@ void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
                                            const std::vector<TBlob> &inputs,
                                            const std::vector<OpReqType> &req,
                                            const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries_keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len = inputs[1].shape_[0];
+  const index_t sequences = inputs[1].shape_[1];
+  const index_t output_lin_dim = inputs[1].shape_[2];
+  const index_t embed_dim = output_lin_dim / 3;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim = attn_batches * 3 * head_dim;
+  const index_t batch_stride = 3 * head_dim;
+  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+
+  if (req[0] == kWriteTo) {
+    memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float));
+  }
+
+  strided_batch_sgemm(false,
+        false,
+        head_dim,
+        qkv_seq_len,
+        qkv_seq_len,
+        scale,
+        queries_keys_values + head_dim,
+        lead_dim,
+        batch_stride,
+        output_grads,
+        qkv_seq_len,
+        qkv_seq_len * qkv_seq_len,
+        beta,
+        queries_keys_values_grads,
+        lead_dim,
+        batch_stride,
+        attn_batches);
+
+  strided_batch_sgemm(false,
+                      true,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      scale,
+                      queries_keys_values,
+                      lead_dim,
+                      batch_stride,
+                      output_grads,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      queries_keys_values_grads + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      attn_batches);
+}
+
+void InterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
+                                       const OpContext &ctx,
+                                       const std::vector<TBlob> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len = inputs[0].shape_[0];
+  const index_t sequences = inputs[0].shape_[1];
+  const index_t output_lin_dim = inputs[0].shape_[2];
+  const index_t embed_dim = output_lin_dim / 3;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim = attn_batches * 3 * head_dim;
+  const index_t batch_stride = 3 * head_dim;
+  const float alpha = 1.f;
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+
+  strided_batch_sgemm(false,
+                      false,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      alpha,
+                      queries_keys_values + 2 * head_dim,
+                      lead_dim,
+                      batch_stride,
+                      attention_maps,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      output,
+                      head_dim * attn_batches,
+                      head_dim,
+                      attn_batches);
+}
+
+void BackwardInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
+                                               const OpContext &ctx,
+                                               const std::vector<TBlob> &inputs,
+                                               const std::vector<OpReqType> &req,
+                                               const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries_keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps = inputs[2].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* attention_maps_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t qkv_seq_len = inputs[1].shape_[0];
+  const index_t sequences = inputs[1].shape_[1];
+  const index_t output_lin_dim = inputs[1].shape_[2];
+  const index_t embed_dim = output_lin_dim / 3;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim = attn_batches * 3 * head_dim;
+  const index_t batch_stride = 3 * head_dim;
+  const float alpha = 1.f;
+  if (req[0] != kNullOp) {
+
+    if (req[0] == kWriteTo) {
+      memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float));
+    }
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        true,
+                        head_dim,
+                        qkv_seq_len,
+                        qkv_seq_len,
+                        alpha,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        attention_maps,
+                        qkv_seq_len,
+                        qkv_seq_len * qkv_seq_len,
+                        beta,
+                        queries_keys_values_grads + 2 * head_dim,
+                        lead_dim,
+                        batch_stride,
+                        attn_batches);
+  }
+  if (req[1] != kNullOp) {
+    const float beta = req[1] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(true,
+                        false,
+                        qkv_seq_len,
+                        qkv_seq_len,
+                        head_dim,
+                        alpha,
+                        queries_keys_values + 2 * head_dim,
+                        lead_dim,
+                        batch_stride,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        beta,
+                        attention_maps_grads,
+                        qkv_seq_len,
+                        qkv_seq_len * qkv_seq_len,
+                        attn_batches);
+  }
+}
+
+void InterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs,
+                                  const OpContext &ctx,
+                                  const std::vector<TBlob> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* queries = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t q_seq_len = inputs[0].shape_[0];
+  const index_t sequences = inputs[0].shape_[1];
+  const index_t output_lin_q_dim = inputs[0].shape_[2];
+  const index_t kv_seq_len = inputs[1].shape_[0];
+  const index_t embed_dim = output_lin_q_dim;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim_q = attn_batches * head_dim;
+  const index_t lead_dim_kv = attn_batches * 2 * head_dim;
+  const index_t batch_stride_q = head_dim;
+  const index_t batch_stride_kv = head_dim * 2;
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+
+  strided_batch_sgemm(true,
+                      false,
+                      kv_seq_len,
+                      q_seq_len,
+                      head_dim,
+                      scale,
+                      keys_values,
+                      lead_dim_kv,
+                      batch_stride_kv,
+                      queries,
+                      lead_dim_q,
+                      batch_stride_q,
+                      beta,
+                      output,
+                      kv_seq_len,
+                      kv_seq_len * q_seq_len,
+                      attn_batches);
+}
+
+void BackwardInterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs,
+                                          const OpContext &ctx,
+                                          const std::vector<TBlob> &inputs,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* queries = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const float* keys_values = inputs[2].FlatTo2D<cpu, float>(s).dptr_;
+  float* queries_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* keys_values_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t q_seq_len = inputs[1].shape_[0];
+  const index_t sequences = inputs[1].shape_[1];
+  const index_t output_lin_q_dim = inputs[1].shape_[2];
+  const index_t kv_seq_len = inputs[2].shape_[0];
+  const index_t embed_dim = output_lin_q_dim;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t attn_batches = params.heads * sequences;
+  const index_t lead_dim_q = attn_batches * head_dim;
+  const index_t lead_dim_kv = attn_batches * 2 * head_dim;
+  const index_t batch_stride_q = head_dim;
+  const index_t batch_stride_kv = head_dim * 2;
+  const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+
+  if (req[0] != kNullOp) {
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        false,
+                        head_dim,
+                        q_seq_len,
+                        kv_seq_len,
+                        scale,
+                        keys_values,
+                        lead_dim_kv,
+                        batch_stride_kv,
+                        output_grads,
+                        kv_seq_len,
+                        kv_seq_len * q_seq_len,
+                        beta,
+                        queries_grads,
+                        lead_dim_q,
+                        batch_stride_q,
+                        attn_batches);
+  }
+  if (req[1] != kNullOp) {
+    if (req[1] == kWriteTo) {
+      memset(keys_values_grads, 0, outputs[1].shape_.Size() * sizeof (float));
+    }
+    const float beta = req[1] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        true,
+                        head_dim,
+                        kv_seq_len,
+                        q_seq_len,
+                        scale,
+                        queries,
+                        lead_dim_q,
+                        batch_stride_q,
+                        output_grads,
+                        kv_seq_len,
+                        kv_seq_len * q_seq_len,
+                        beta,
+                        keys_values_grads,
+                        lead_dim_kv,
+                        batch_stride_kv,
+                        attn_batches);
+  }
+}
+
+void InterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs,
+                                      const OpContext &ctx,
+                                      const std::vector<TBlob> &inputs,
+                                      const std::vector<OpReqType> &req,
+                                      const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  if (req[0] == kNullOp)
+    return;
+
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t kv_seq_len = inputs[0].shape_[0];
+  const index_t output_lin_kv_dim = inputs[0].shape_[2];
+  const index_t attn_batches = inputs[1].shape_[0];
+  const index_t q_seq_len = inputs[1].shape_[1];
+  const index_t embed_dim = output_lin_kv_dim / 2;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t lead_dim_kv = attn_batches * head_dim * 2;
+  const index_t batch_stride_kv = 2 * head_dim;
+  const float alpha = 1.f;
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+
+  strided_batch_sgemm(false,
+                      false,
+                      head_dim,
+                      q_seq_len,
+                      kv_seq_len,
+                      alpha,
+                      keys_values + head_dim,
+                      lead_dim_kv,
+                      batch_stride_kv,
+                      attention_maps,
+                      kv_seq_len,
+                      kv_seq_len * q_seq_len,
+                      beta,
+                      output,
+                      head_dim * attn_batches,
+                      head_dim,
+                      attn_batches);
+}
+
+void BackwardInterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs,
+                                              const OpContext &ctx,
+                                              const std::vector<TBlob> &inputs,
+                                              const std::vector<OpReqType> &req,
+                                              const std::vector<TBlob> &outputs) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32)
+    << "Only FP32 is supported on CPU at the moment";
+
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  const float* keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const float* attention_maps = inputs[2].FlatTo2D<cpu, float>(s).dptr_;
+  float* keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+  float* attention_maps_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  const index_t kv_seq_len = inputs[1].shape_[0];
+  const index_t output_lin_kv_dim = inputs[1].shape_[2];
+  const index_t attn_batches = inputs[2].shape_[0];
+  const index_t q_seq_len = inputs[2].shape_[1];
+  const index_t embed_dim = output_lin_kv_dim / 2;
+  const index_t head_dim = embed_dim / params.heads;
+  const index_t lead_dim_kv = attn_batches * head_dim * 2;
+  const index_t batch_stride_kv = 2 * head_dim;
+  const float alpha = 1.f;
+
+  if (req[0] != kNullOp) {
+    if (req[0] == kWriteTo) {
+      memset(keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float));
+    }
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(false,
+                        true,
+                        head_dim,
+                        kv_seq_len,
+                        q_seq_len,
+                        alpha,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        attention_maps,
+                        kv_seq_len,
+                        kv_seq_len * q_seq_len,
+                        beta,
+                        keys_values_grads + head_dim,
+                        lead_dim_kv,
+                        batch_stride_kv,
+                        attn_batches);
+  }
+  if (req[1] != kNullOp) {
+    const float beta = req[1] == kAddTo ? 1.f : 0.f;
+    strided_batch_sgemm(true,
+                        false,
+                        kv_seq_len,
+                        q_seq_len,
+                        head_dim,
+                        alpha,
+                        keys_values + head_dim,
+                        lead_dim_kv,
+                        batch_stride_kv,
+                        output_grads,
+                        head_dim * attn_batches,
+                        head_dim,
+                        beta,
+                        attention_maps_grads,
+                        kv_seq_len,
+                        kv_seq_len * q_seq_len,
+                        attn_batches);
+  }
+}
+
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
 .describe(R"code(Compute the matrix multiplication between the projections of
 queries and keys in multihead attention used as self attention.
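Note on the CblasColMajor switch at the top of this patch: a column-major routine reading
row-major buffers sees every matrix as its transpose, and whatever it writes is read back
transposed as well. The desired row-major C = Q.K^T is therefore obtained by asking the
column-major routine for K.Q^T over the same operand pointers, which is why the forward QK
call now passes (transA=true, transB=false). The identity, checked in NumPy (an illustration
only):

    import numpy as np

    q = np.random.rand(7, 5)   # row-major Q block: (seq_len, head_dim)
    k = np.random.rand(7, 5)   # row-major K block: (seq_len, head_dim)
    want = q @ k.T             # desired row-major attention scores

    # What a column-major GEMM writes into the row-major C buffer is C^T,
    # so requesting K @ Q^T yields Q @ K^T in the row-major view.
    np.testing.assert_allclose(want.T, k @ q.T)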
@@ -243,8 +664,6 @@ q_proj = mx.nd.contrib.div_sqrt_dim(q_proj)
 k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3))
 k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
 output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
-
-This Op is GPU only
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -289,8 +708,6 @@ output = mx.nd.batch_dot(attention, v_proj, transpose_b=True)
 output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
 output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
 output = mx.nd.reshape(output, shape=(0, 0, -1))
-
-This Op is GPU only
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -303,6 +720,7 @@ This Op is GPU only
 })
 .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulSelfAttValAttCPU)
 .set_attr<nnvm::FGradient>("FGradient",
     ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"})
 .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
@@ -313,7 +731,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
 .set_num_inputs(3)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttValAttCPU);
 
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk)
 .describe(R"code(Compute the matrix multiplication between the projections of
@@ -333,8 +752,6 @@ tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1))
 k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3))
 k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True)
 output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True)
-
-This Op is GPU only
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -347,6 +764,7 @@ This Op is GPU only
 })
 .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulEncDecQKCPU)
 .set_attr<nnvm::FGradient>("FGradient",
     ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"})
 .add_argument("queries", "NDArray-or-Symbol", "Queries")
@@ -357,7 +775,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
 .set_num_inputs(3)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulEncDecQKCPU);
 
 NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt)
 .describe(R"code(Compute the matrix multiplication between the projections of
@@ -379,8 +798,6 @@ output = mx.nd.batch_dot(attention, v_proj, transpose_b=True)
 output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True)
 output = mx.nd.transpose(output, axes=(0, 2, 1, 3))
 output = mx.nd.reshape(output, shape=(0, 0, -1))
-
-This Op is GPU only
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -393,6 +810,7 @@ This Op is GPU only
 })
 .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecValAttShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulEncDecValAttCPU)
 .set_attr<nnvm::FGradient>("FGradient",
     ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"})
 .add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
@@ -403,7 +821,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt)
 .set_num_inputs(3)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr_parser(ParamParser<InterleavedMatMulParam>);
+.set_attr_parser(ParamParser<InterleavedMatMulParam>)
+.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulEncDecValAttCPU);
 
 // relu
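Two conventions run through the backward kernels above. First, `beta` is 1.f for kAddTo so each
GEMM accumulates into the existing gradient (the C = alpha*A*B + beta*C contract), and 0.f
otherwise. Second, the extra memset for kWriteTo: the gradient buffer is interleaved per head,
and each GEMM writes only its own strided slice — in QK backward, for example, nothing ever
writes the value-projection slice — so the buffer is zero-filled first to avoid leaving stale
data in the untouched slices. A NumPy sketch of that slice-wise update (hypothetical helper,
batch dimension omitted):

    import numpy as np

    def write_qk_grads(dq, dk, grad_buf, heads, head_dim, add_to=False):
        # grad_buf: (seq_len, heads * head_dim * 3), interleaved dq/dk/dv slices
        # dq, dk:   (seq_len, heads, head_dim)
        if not add_to:                 # kWriteTo: clear slices no GEMM writes (dv)
            grad_buf[:] = 0.0
        beta = 1.0 if add_to else 0.0  # kAddTo accumulates, like GEMM's beta
        view = grad_buf.reshape(grad_buf.shape[0], heads, 3, head_dim)
        view[:, :, 0, :] = beta * view[:, :, 0, :] + dq
        view[:, :, 1, :] = beta * view[:, :, 1, :] + dk
        return grad_buf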
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7b0404d8abb7..b3478b158446 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9382,6 +9382,330 @@ def check_random_uniform():
         hight = 1
         assertRaises(MXNetError, mx.nd.random_uniform, alpha, beta, shape)
 
+def check_multihead_attention_selfatt(dtype):
+    def convert_weight(F, q_weight, k_weight, v_weight, num_heads):
+        q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True)
+        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+        all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2)
+        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+        return all_weights
+
+    def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
+        q_bias = F.reshape(q_bias, shape=(num_heads, -1))
+        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+        all_bias = F.stack(q_bias, k_bias, v_bias, axis=1)
+        all_bias = F.reshape(all_bias, shape=(-1,))
+        return all_bias
+
+    batch_size = 2
+    qkv_length = 7  # length of a sequence
+    qkv_dim = 9     # dimension of encoding
+    num_heads = 3   # number of attention head
+    head_dim = 5    # head size
+    out_dim = 13 * num_heads
+    qkv_units = num_heads * head_dim
+
+    arg_params = {
+        'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
+    }
+
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+    qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads)
+    qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads)
+    qkv = mx.sym.transpose(qkv, axes=(1, 0, 2))
+    qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False,
+                                     num_hidden=qkv_units * 3, no_bias=False)
+    att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk(
+        qkv_proj, heads=num_heads)
+    att_score = att_score + sonde
+    weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
+        qkv_proj, att_score, heads=num_heads)
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.transpose(output, axes=(1, 0, 2))
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  q_weight=(qkv_units, qkv_dim),
+                                  q_bias=(qkv_units,),
+                                  k_weight=(qkv_units, qkv_dim),
+                                  k_bias=(qkv_units,),
+                                  v_weight=(qkv_units, qkv_dim),
+                                  v_bias=(qkv_units,),
+                                  type_dict={'qkv': dtype,
+                                             'q_weight': dtype,
+                                             'k_weight': dtype,
+                                             'v_weight': dtype,
+                                             'q_bias': dtype,
+                                             'k_bias': dtype,
+                                             'v_bias': dtype,
+                                             'sonde': dtype},
+                                  grad_req='write', force_rebind=True)
+    output_shape = executor.outputs[0].shape
+    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_opti = executor.outputs[0].asnumpy()
+    att_score_opti = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+
+    q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+    q = mx.sym.contrib.div_sqrt_dim(q)
+    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+    att_score = att_score + sonde
+    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+    weighted_value = mx.sym.batch_dot(att_score, v)
+    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
+                                    reverse=True)
+    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'qkv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert(grads_orig[k].dtype == grads_opti[k].dtype)
+        assert(grads_orig[k].shape == grads_opti[k].shape)
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+
+@with_seed(12345)
+def test_multihead_attention_selfatt():
+    dtypes = ['float32']
+    if default_context().device_type == 'gpu':
+        dtypes += ['float16']
+
+    for dtype in dtypes:
+        check_multihead_attention_selfatt(dtype=dtype)
+
+def check_multihead_attention_encdec(dtype):
+    def convert_weight(F, k_weight, v_weight, num_heads):
+        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+        all_weights = F.concat(k_weight, v_weight, dim=-2)
+        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+        return all_weights
+
+    def convert_bias(F, k_bias, v_bias, num_heads):
+        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+        all_bias = F.stack(k_bias, v_bias, axis=1)
+        all_bias = F.reshape(all_bias, shape=(-1,))
+        return all_bias
+
+    batch_size = 2
+    qkv_length = 7  # length of a sequence
+    qkv_dim = 9     # dimension of encoding
+    num_heads = 3   # number of attention head
+    head_dim = 5    # head size
+    out_dim = 13 * num_heads
+    qkv_units = num_heads * head_dim
+
+    arg_params = {
+        'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
+    }
+
+    q = mx.sym.Variable('q')
+    kv = mx.sym.Variable('kv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+    kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads)
+    kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads)
+    kv = mx.sym.transpose(kv, axes=(1, 0, 2))
+    kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False,
+                                    num_hidden=qkv_units * 2, no_bias=False)
+    q = mx.sym.transpose(q, axes=(1, 0, 2))
+    q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
+                                   num_hidden=qkv_units, no_bias=False)
+    att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
+        q_proj, kv_proj, heads=num_heads)
+    att_score = att_score + sonde
+    weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
+        kv_proj, att_score, heads=num_heads)
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.transpose(output, axes=(1, 0, 2))
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  q=(batch_size, qkv_length, qkv_dim),
+                                  kv=(batch_size, qkv_length, qkv_dim),
+                                  q_weight=(qkv_units, qkv_dim),
+                                  q_bias=(qkv_units,),
+                                  k_weight=(qkv_units, qkv_dim),
+                                  k_bias=(qkv_units,),
+                                  v_weight=(qkv_units, qkv_dim),
+                                  v_bias=(qkv_units,),
+                                  out_weight=(out_dim, qkv_units),
+                                  out_bias=(out_dim,),
+                                  type_dict={'q': dtype,
+                                             'kv': dtype,
+                                             'q_weight': dtype,
+                                             'q_bias': dtype,
+                                             'k_weight': dtype,
+                                             'k_bias': dtype,
+                                             'v_weight': dtype,
+                                             'v_bias': dtype,
+                                             'out_weight': dtype,
+                                             'out_bias': dtype,
+                                             },
+                                  grad_req='write', force_rebind=True)
+    output_shape = executor.outputs[0].shape
+    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_opti = executor.outputs[0].asnumpy()
+    att_score_opti = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+
+    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+
+    q = mx.sym.Variable('q')
+    kv = mx.sym.Variable('kv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+
+    q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+    q = mx.sym.contrib.div_sqrt_dim(q)
+    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+    att_score = att_score + sonde
+    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+    weighted_value = mx.sym.batch_dot(att_score, v)
+    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
+                                    reverse=True)
+    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=default_context(),
+                                  q=(batch_size, qkv_length, qkv_dim),
+                                  kv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'q': dtype,
+                                             'kv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert(grads_orig[k].dtype == grads_opti[k].dtype)
+        assert(grads_orig[k].shape == grads_opti[k].shape)
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+@with_seed(12345)
+def test_multihead_attention_encdec():
+    dtypes = ['float32']
+    if default_context().device_type == 'gpu':
+        dtypes += ['float16']
+
+    for dtype in dtypes:
+        check_multihead_attention_encdec(dtype=dtype)
 
 if __name__ == '__main__':
     import nose
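The tests above build the fused graph and a plain batch_dot reference graph on the same
parameters and compare outputs and every gradient. The zero-valued `sonde` input added onto
att_score makes the attention map a second graph output with its own incoming gradient, which
forces the backward kernels to actually run. For orientation, a minimal imperative use of the
two self-attention ops on CPU (shapes follow the test constants; a sketch, not part of the
patch):

    import mxnet as mx

    seq_len, batch, heads, head_dim = 7, 2, 3, 5
    qkv = mx.nd.random.uniform(shape=(seq_len, batch, heads * head_dim * 3),
                               ctx=mx.cpu())

    att = mx.nd.contrib.interleaved_matmul_selfatt_qk(qkv, heads=heads)
    print(att.shape)   # (batch * heads, seq_len, seq_len)

    weights = mx.nd.softmax(att, axis=-1)
    out = mx.nd.contrib.interleaved_matmul_selfatt_valatt(qkv, weights, heads=heads)
    print(out.shape)   # (seq_len, batch, heads * head_dim)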
From 2f0181ae0e8aadab8018bd2b5adb478aac67d702 Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Mon, 2 Dec 2019 23:07:38 +0800
Subject: [PATCH 4/9] coding style

---
 src/operator/contrib/transformer.cc | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index b135a80afcc6..5a815c3fe1fc 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -162,14 +162,14 @@ void strided_batch_sgemm(bool transA, bool transB,
                     p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
                     p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
 #else
-    for (int i = 0; i < batchCount; ++i) {
-      cblas_sgemm(CblasColMajor,
-                  transA ? CblasTrans : CblasNoTrans,
-                  transB ? CblasTrans : CblasNoTrans,
-                  m, n, k,
-                  alpha, pp_A[i], lda,
-                  pp_B[i], ldb, beta, pp_C[i], ldc);
-    }
+  for (int i = 0; i < batchCount; ++i) {
+    cblas_sgemm(CblasColMajor,
+                transA ? CblasTrans : CblasNoTrans,
+                transB ? CblasTrans : CblasNoTrans,
+                m, n, k,
+                alpha, pp_A[i], lda,
+                pp_B[i], ldb, beta, pp_C[i], ldc);
+  }
 #endif
 }

From 544bd25fe6ba24e35e5b50252d6f133452bad52a Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Sat, 21 Dec 2019 23:16:01 +0800
Subject: [PATCH 5/9] fix lint

---
 src/operator/contrib/transformer.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index 5a815c3fe1fc..0ca8a3abef7b 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -128,7 +128,6 @@ void strided_batch_sgemm(bool transA, bool transB,
                          index_t strideA, const float *b, index_t ldb,
                          index_t strideB, float beta, float *c, index_t ldc,
                          index_t strideC, int32_t batchCount) {
-
   std::vector<const float*> pp_A(batchCount, nullptr);
   std::vector<const float*> pp_B(batchCount, nullptr);
   std::vector<float*> pp_C(batchCount, nullptr);
@@ -362,10 +361,10 @@ void BackwardInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
   const index_t batch_stride = 3 * head_dim;
   const float alpha = 1.f;
   if (req[0] != kNullOp) {
-
     if (req[0] == kWriteTo) {
       memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float));
     }
+    const float beta = req[0] == kAddTo ? 1.f : 0.f;
     strided_batch_sgemm(false,
                         true,

From 1c639eb10b00eac914e4e774f5caf5015cae653e Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Sun, 22 Dec 2019 22:15:44 +0800
Subject: [PATCH 6/9] use random seed in tests

---
 tests/python/unittest/test_operator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7bdede8aacef..72350de4d85f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9521,7 +9521,7 @@ def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
         assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
 
 
-@with_seed(12345)
+@with_seed()
 def test_multihead_attention_selfatt():
     dtypes = ['float32']
     if default_context().device_type == 'gpu':
@@ -9689,7 +9689,7 @@ def convert_bias(F, k_bias, v_bias, num_heads):
         assert(grads_orig[k].shape == grads_opti[k].shape)
         assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
 
-@with_seed(12345)
+@with_seed()
 def test_multihead_attention_encdec():
     dtypes = ['float32']
     if default_context().device_type == 'gpu':
From fa9db43187515e2d1c9a89f99f9ad4a9009d418f Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Mon, 23 Dec 2019 10:16:21 +0800
Subject: [PATCH 7/9] remove ut in test_operator_gpu.py

---
 tests/python/gpu/test_operator_gpu.py | 317 --------------------------
 1 file changed, 317 deletions(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index e548217b9369..721eaaebab31 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2548,323 +2548,6 @@ def test_arange_like_dtype():
         for v in out:
             assert v.dtype == t
 
-@with_seed()
-def check_multihead_attention_selfatt(dtype):
-    def convert_weight(F, q_weight, k_weight, v_weight, num_heads):
-        q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True)
-        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
-        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
-        all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2)
-        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
-        return all_weights
-
-    def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
-        q_bias = F.reshape(q_bias, shape=(num_heads, -1))
-        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
-        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
-        all_bias = F.stack(q_bias, k_bias, v_bias, axis=1)
-        all_bias = F.reshape(all_bias, shape=(-1,))
-        return all_bias
-
-    batch_size = 2
-    qkv_length = 7  # length of a sequence
-    qkv_dim = 9     # dimension of encoding
-    num_heads = 3   # number of attention head
-    head_dim = 5    # head size
-    out_dim = 13 * num_heads
-    qkv_units = num_heads * head_dim
-
-    arg_params = {
-        'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
-        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
-    }
-
-    qkv = mx.sym.Variable('qkv')
-    sonde = mx.sym.Variable('sonde')
-    q_weight = mx.sym.Variable('q_weight')
-    k_weight = mx.sym.Variable('k_weight')
-    v_weight = mx.sym.Variable('v_weight')
-    q_bias = mx.sym.Variable('q_bias')
-    k_bias = mx.sym.Variable('k_bias')
-    v_bias = mx.sym.Variable('v_bias')
-    out_weight = mx.sym.Variable('out_weight')
-    out_bias = mx.sym.Variable('out_bias')
-    qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads)
-    qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads)
-    qkv = mx.sym.transpose(qkv, axes=(1, 0, 2))
-    qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False,
-                                     num_hidden=qkv_units * 3, no_bias=False)
-    att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk(
-        qkv_proj, heads=num_heads)
-    att_score = att_score + sonde
-    weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
-        qkv_proj, att_score, heads=num_heads)
-    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
-                                   num_hidden=out_dim, no_bias=False)
-    output = mx.sym.transpose(output, axes=(1, 0, 2))
-    output = mx.sym.Group([output, att_score])
-    executor = output.simple_bind(ctx=mx.gpu(0),
-                                  qkv=(batch_size, qkv_length, qkv_dim),
-                                  q_weight=(qkv_units, qkv_dim),
-                                  q_bias=(qkv_units,),
-                                  k_weight=(qkv_units, qkv_dim),
-                                  k_bias=(qkv_units,),
-                                  v_weight=(qkv_units, qkv_dim),
-                                  v_bias=(qkv_units,),
-                                  type_dict={'qkv': dtype,
-                                             'q_weight': dtype,
-                                             'k_weight': dtype,
-                                             'v_weight': dtype,
-                                             'q_bias': dtype,
-                                             'k_bias': dtype,
-                                             'v_bias': dtype,
-                                             'sonde': dtype},
-                                  grad_req='write', force_rebind=True)
-    output_shape = executor.outputs[0].shape
-    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
-    executor.copy_params_from(arg_params, {})
-    executor.arg_dict['sonde'][:] = 0.
-    executor.arg_dict['sonde'].wait_to_read()
-    executor.forward(is_train=True)
-    output_opti = executor.outputs[0].asnumpy()
-    att_score_opti = executor.outputs[1].asnumpy()
-    executor.backward([mx.nd.array(output_grads, dtype=dtype),
-                       mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
-    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
-    qkv = mx.sym.Variable('qkv')
-    sonde = mx.sym.Variable('sonde')
-    q_weight = mx.sym.Variable('q_weight')
-    k_weight = mx.sym.Variable('k_weight')
-    v_weight = mx.sym.Variable('v_weight')
-    q_bias = mx.sym.Variable('q_bias')
-    k_bias = mx.sym.Variable('k_bias')
-    v_bias = mx.sym.Variable('v_bias')
-    out_weight = mx.sym.Variable('out_weight')
-    out_bias = mx.sym.Variable('out_bias')
-
-    q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
-    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
-    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
-    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
-    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
-    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
-    q = mx.sym.contrib.div_sqrt_dim(q)
-    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
-    att_score = att_score + sonde
-    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
-    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
-    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
-    weighted_value = mx.sym.batch_dot(att_score, v)
-    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
-                                    reverse=True)
-    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
-    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
-    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
-                                   num_hidden=out_dim, no_bias=False)
-    output = mx.sym.Group([output, att_score])
-    executor = output.simple_bind(ctx=mx.gpu(0),
-                                  qkv=(batch_size, qkv_length, qkv_dim),
-                                  type_dict={'qkv': dtype},
-                                  grad_req='write', force_rebind=True)
-    executor.copy_params_from(arg_params, {})
-    executor.arg_dict['sonde'][:] = 0.
-    executor.arg_dict['sonde'].wait_to_read()
-    executor.forward(is_train=True)
-    output_orig = executor.outputs[0].asnumpy()
-    att_score_orig = executor.outputs[1].asnumpy()
-    executor.backward([mx.nd.array(output_grads, dtype=dtype),
-                       mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
-    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
-    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
-    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
-
-    for k in grads_opti.keys():
-        assert(grads_orig[k].dtype == grads_opti[k].dtype)
-        assert(grads_orig[k].shape == grads_opti[k].shape)
-        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
-
-@assert_raises_cuda_not_satisfied(min_version='9.1')
-def test_multihead_attention_selfatt():
-    for dtype in ['float16', 'float32']:
-        check_multihead_attention_selfatt(dtype=dtype)
-
-def check_multihead_attention_encdec(dtype):
-    def convert_weight(F, k_weight, v_weight, num_heads):
-        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
-        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
-        all_weights = F.concat(k_weight, v_weight, dim=-2)
-        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
-        return all_weights
-
-    def convert_bias(F, k_bias, v_bias, num_heads):
-        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
-        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
-        all_bias = F.stack(k_bias, v_bias, axis=1)
-        all_bias = F.reshape(all_bias, shape=(-1,))
-        return all_bias
-
-    batch_size = 2
-    qkv_length = 7  # length of a sequence
-    qkv_dim = 9     # dimension of encoding
-    num_heads = 3   # number of attention head
-    head_dim = 5    # head size
-    out_dim = 13 * num_heads
-    qkv_units = num_heads * head_dim
-
-    arg_params = {
-        'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
-        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
-        'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
-        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
-    }
-
-    q = mx.sym.Variable('q')
-    kv = mx.sym.Variable('kv')
-    sonde = mx.sym.Variable('sonde')
-    q_weight = mx.sym.Variable('q_weight')
-    k_weight = mx.sym.Variable('k_weight')
-    v_weight = mx.sym.Variable('v_weight')
-    q_bias = mx.sym.Variable('q_bias')
-    k_bias = mx.sym.Variable('k_bias')
-    v_bias = mx.sym.Variable('v_bias')
-    out_weight = mx.sym.Variable('out_weight')
-    out_bias = mx.sym.Variable('out_bias')
-    kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads)
-    kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads)
-    kv = mx.sym.transpose(kv, axes=(1, 0, 2))
-    kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False,
-                                    num_hidden=qkv_units * 2, no_bias=False)
-    q = mx.sym.transpose(q, axes=(1, 0, 2))
-    q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
-                                   num_hidden=qkv_units, no_bias=False)
-    att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
-        q_proj, kv_proj, heads=num_heads)
-    att_score = att_score + sonde
-    weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
-        kv_proj, att_score, heads=num_heads)
-    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
-                                   num_hidden=out_dim, no_bias=False)
-    output = mx.sym.transpose(output, axes=(1, 0, 2))
-    output = mx.sym.Group([output, att_score])
-    executor = output.simple_bind(ctx=mx.gpu(0),
-                                  q=(batch_size, qkv_length, qkv_dim),
-                                  kv=(batch_size, qkv_length, qkv_dim),
-                                  q_weight=(qkv_units, qkv_dim),
-                                  q_bias=(qkv_units,),
-                                  k_weight=(qkv_units, qkv_dim),
-                                  k_bias=(qkv_units,),
-                                  v_weight=(qkv_units, qkv_dim),
-                                  v_bias=(qkv_units,),
-                                  out_weight=(out_dim, qkv_units),
-                                  out_bias=(out_dim,),
-                                  type_dict={'q': dtype,
-                                             'kv': dtype,
-                                             'q_weight': dtype,
-                                             'q_bias': dtype,
-                                             'k_weight': dtype,
-                                             'k_bias': dtype,
-                                             'v_weight': dtype,
-                                             'v_bias': dtype,
-                                             'out_weight': dtype,
-                                             'out_bias': dtype,
-                                             },
-                                  grad_req='write', force_rebind=True)
-    output_shape = executor.outputs[0].shape
-    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
-    executor.copy_params_from(arg_params, {})
-    executor.arg_dict['sonde'][:] = 0.
-    executor.arg_dict['sonde'].wait_to_read()
-    executor.forward(is_train=True)
-    output_opti = executor.outputs[0].asnumpy()
-    att_score_opti = executor.outputs[1].asnumpy()
-    executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
-
-    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
-
-    q = mx.sym.Variable('q')
-    kv = mx.sym.Variable('kv')
-    sonde = mx.sym.Variable('sonde')
-    q_weight = mx.sym.Variable('q_weight')
-    k_weight = mx.sym.Variable('k_weight')
-    v_weight = mx.sym.Variable('v_weight')
-    q_bias = mx.sym.Variable('q_bias')
-    k_bias = mx.sym.Variable('k_bias')
-    v_bias = mx.sym.Variable('v_bias')
-    out_weight = mx.sym.Variable('out_weight')
-    out_bias = mx.sym.Variable('out_bias')
-
-    q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False,
-                              num_hidden=qkv_units, no_bias=False)
-    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
-    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
-    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
-    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
-    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
-    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
-    q = mx.sym.contrib.div_sqrt_dim(q)
-    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
-    att_score = att_score + sonde
-    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
-    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
-    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
-    weighted_value = mx.sym.batch_dot(att_score, v)
-    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
-                                    reverse=True)
-    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
-    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
-    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
-                                   num_hidden=out_dim, no_bias=False)
-    output = mx.sym.Group([output, att_score])
-    executor = output.simple_bind(ctx=mx.gpu(0),
-                                  q=(batch_size, qkv_length, qkv_dim),
-                                  kv=(batch_size, qkv_length, qkv_dim),
-                                  type_dict={'q': dtype,
-                                             'kv': dtype},
-                                  grad_req='write', force_rebind=True)
-    executor.copy_params_from(arg_params, {})
-    executor.arg_dict['sonde'][:] = 0.
-    executor.arg_dict['sonde'].wait_to_read()
-    executor.forward(is_train=True)
-    output_orig = executor.outputs[0].asnumpy()
-    att_score_orig = executor.outputs[1].asnumpy()
-    executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
-    grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()}
-    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
-    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
-
-    for k in grads_opti.keys():
-        assert(grads_orig[k].dtype == grads_opti[k].dtype)
-        assert(grads_orig[k].shape == grads_opti[k].shape)
-        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
-
-@assert_raises_cuda_not_satisfied(min_version='9.1')
-def test_multihead_attention_encdec():
-    for dtype in ['float16', 'float32']:
-        check_multihead_attention_encdec(dtype=dtype)
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()

From 0037639ee25bd6dd422f9fcaeba1d6f1bb14ef2d Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Sat, 4 Jan 2020 00:28:48 +0800
Subject: [PATCH 8/9] coding style

---
 src/operator/contrib/transformer.cc | 32 ++++++++++++++---------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc
index 0ca8a3abef7b..58826a2d96a8 100644
--- a/src/operator/contrib/transformer.cc
+++ b/src/operator/contrib/transformer.cc
@@ -251,22 +251,22 @@ void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
   }
 
   strided_batch_sgemm(false,
-        false,
-        head_dim,
-        qkv_seq_len,
-        qkv_seq_len,
-        scale,
-        queries_keys_values + head_dim,
-        lead_dim,
-        batch_stride,
-        output_grads,
-        qkv_seq_len,
-        qkv_seq_len * qkv_seq_len,
-        beta,
-        queries_keys_values_grads,
-        lead_dim,
-        batch_stride,
-        attn_batches);
+                      false,
+                      head_dim,
+                      qkv_seq_len,
+                      qkv_seq_len,
+                      scale,
+                      queries_keys_values + head_dim,
+                      lead_dim,
+                      batch_stride,
+                      output_grads,
+                      qkv_seq_len,
+                      qkv_seq_len * qkv_seq_len,
+                      beta,
+                      queries_keys_values_grads,
+                      lead_dim,
+                      batch_stride,
+                      attn_batches);
 
   strided_batch_sgemm(false,
                       true,

From 2f22f686f3d65d3edde59885cb49950a3bc63eac Mon Sep 17 00:00:00 2001
From: Tao Lv
Date: Sat, 4 Jan 2020 07:28:52 +0800
Subject: [PATCH 9/9] retrigger ci
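End to end, the series makes all four interleaved attention ops and their backward passes
runnable on CPU (FP32 only), drops the "This Op is GPU only" caveat from their docs, and moves
the unit tests into the device-agnostic test_operator.py. A quick autograd smoke test of the
new CPU backward path (a sketch, not part of the series):

    import mxnet as mx

    seq_len, batch, heads, head_dim = 7, 2, 3, 5
    qkv = mx.nd.random.uniform(shape=(seq_len, batch, heads * head_dim * 3),
                               ctx=mx.cpu())
    qkv.attach_grad()

    with mx.autograd.record():
        att = mx.nd.contrib.interleaved_matmul_selfatt_qk(qkv, heads=heads)
    att.backward()

    # exercises BackwardInterleavedMatMulSelfAttQKCPU on the CPU context
    assert qkv.grad.shape == qkv.shape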