Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Interleaved MHA for CPU path #17138

Merged
merged 14 commits into from
Jan 3, 2020
Previous commit
Next commit
qk: fake backward
  • Loading branch information
TaoLv committed Nov 30, 2019
commit 29c0623869d7f44e67a5fd647b6808b48d45f3fa
39 changes: 26 additions & 13 deletions src/operator/contrib/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ void strided_batch_sgemm(bool transA, bool transB,
index_t strideA, const float *b, index_t ldb,
index_t strideB, float beta, float *c, index_t ldc,
index_t strideC, int32_t batchCount) {

std::vector<const float*> pp_A(batchCount, nullptr);
std::vector<const float*> pp_B(batchCount, nullptr);
std::vector<float*> pp_C(batchCount, nullptr);

for (int i = 0; i < batchCount; i++) {
pp_A[i] = a + i * strideA;
pp_B[i] = b + i * strideB;
pp_C[i] = c + i * strideC;
}

#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
const int GROUP_SIZE = 1;
MKL_INT p_m[GROUP_SIZE] = {m};
Expand All @@ -147,23 +158,17 @@ void strided_batch_sgemm(bool transA, bool transB,
CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};

std::vector<const float*> pp_A(batchCount, nullptr);
std::vector<const float*> pp_B(batchCount, nullptr);
std::vector<float*> pp_C(batchCount, nullptr);

for (int i = 0; i < batchCount; i++) {
pp_A[i] = a + i * strideA;
pp_B[i] = b + i * strideB;
pp_C[i] = c + i * strideC;
}

cblas_sgemm_batch(CblasRowMajor, p_transa, p_transb,
p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
#else
for (int i = 0; i < batchCount; ++i) {
cblas_sgemm(CblasRowMajor, transa, transb, m, n, k, alpha,
pp_A[i], lda, pp_B[i], ldb, beta, pp_C[i], ldc);
cblas_sgemm(CblasRowMajor,
transA ? CblasTrans : CblasNoTrans,
transB ? CblasTrans : CblasNoTrans,
m, n, k,
alpha, pp_A[i], lda,
pp_B[i], ldb, beta, pp_C[i], ldc);
}
#endif
}
Expand Down Expand Up @@ -215,6 +220,13 @@ void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
attn_batches);
}

/*!
 * \brief CPU compute function registered for the
 *        _backward_interleaved_matmul_selfatt_qk operator.
 *
 * Placeholder: the body is empty, so none of the arguments are read and
 * nothing is written to `outputs`.
 * NOTE(review): presumably the actual gradient computation is still to be
 * added here -- confirm before relying on gradients produced by this op.
 *
 * \param attrs   node attributes (currently unused)
 * \param ctx     operator context (currently unused)
 * \param inputs  input blobs (currently unused)
 * \param req     OpReqType values (currently unused)
 * \param outputs output blobs (currently left untouched)
 */
void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
// No-op: nothing is computed in this function.
}

NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk)
.describe(R"code(Compute the matrix multiplication between the projections of
queries and keys in multihead attention use as self attention.
Expand Down Expand Up @@ -255,7 +267,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InterleavedMatMulParam>);
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttQKCPU);

NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt)
.describe(R"code(Compute the matrix multiplication between the projections of
Expand Down