This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

No tensor cores for fp32 interleaved attention, remove div by 8 restriction (#17994) #18085

Merged: 1 commit merged on Apr 16, 2020
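In brief, this patch changes how `gemm_switch_fp32accum` in `src/operator/contrib/transformer.cu` picks the cuBLAS GEMM algorithm for the interleaved multi-head attention operators. Previously the tensor-op algorithm was requested whenever the leading dimensions `lda`, `ldb`, and `ldc` were all multiples of 8, and `CublasStridedBatchedGemm` defaulted to `CUBLAS_GEMM_DEFAULT_TENSOR_OP`. Now each operator derives a `using_fp16` flag from `inputs[0].type_flag_ == mshadow::kFloat16` and threads it through, so the tensor-op algorithm is requested only for fp16 data and fp32 always uses `CUBLAS_GEMM_DEFAULT`, removing the divisible-by-8 restriction for fp32. At the time of this change, cuBLAS tensor-core kernels targeted fp16 GEMMs, so the tensor-op path brought no benefit for fp32 while its alignment expectations forced that restriction. A minimal sketch of the before/after selection logic follows, assuming cuBLAS headers from CUDA 9.1 or later; `pick_algo_before` and `pick_algo_after` are illustrative names, not functions from the patch.

```cpp
// Minimal sketch of the algorithm-selection change, not the exact MXNet code.
// Requires the cuBLAS headers (CUDA >= 9.1) for cublasGemmAlgo_t.
#include <cublas_v2.h>

// Before: the tensor-op algorithm was requested whenever all leading
// dimensions were multiples of 8, regardless of the data type.
inline cublasGemmAlgo_t pick_algo_before(int lda, int ldb, int ldc) {
  const bool aligned = !(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7);
  return aligned ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}

// After: the tensor-op algorithm is requested only for fp16 inputs; the flag
// is computed by each operator as inputs[0].type_flag_ == mshadow::kFloat16
// and passed into gemm_switch_fp32accum.
inline cublasGemmAlgo_t pick_algo_after(bool using_fp16) {
  return using_fp16 ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}
```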
53 changes: 37 additions & 16 deletions src/operator/contrib/transformer.cu
@@ -43,7 +43,7 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
float alpha, const DType* a, int32_t lda, int32_t strideA,
const DType *b, int32_t ldb, int32_t strideB, float beta,
DType *c, int32_t ldc, int32_t strideC, int32_t batchCount,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT) {
#if CUDA_VERSION >= 9010
using namespace mxnet::common::cuda;
CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream<gpu>::OwnHandle)
@@ -142,9 +142,9 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
float alpha, const DType *a, int32_t lda,
int32_t strideA, const DType *b, int32_t ldb,
int32_t strideB, float beta, DType *c, int32_t ldc,
- int32_t strideC, int32_t batchCount) {
+ int32_t strideC, int32_t batchCount, bool using_fp16) {
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
- if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+ if (using_fp16) {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
@@ -175,6 +175,7 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride = 3 * head_dim;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -196,7 +197,8 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
output,
qkv_seq_len,
qkv_seq_len * qkv_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -220,7 +222,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim = attn_batches * 3 * head_dim;
const int32_t batch_stride = 3 * head_dim;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
- const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -247,7 +250,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
gemm_switch_fp32accum(s,
false,
true,
@@ -265,7 +269,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads + head_dim,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -290,6 +295,7 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride = 3 * head_dim;
const float alpha = 1.f;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -311,7 +317,8 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
output,
head_dim * attn_batches,
head_dim,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -337,6 +344,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim = attn_batches * 3 * head_dim;
const int32_t batch_stride = 3 * head_dim;
const float alpha = 1.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
+
if (req[0] != kNullOp) {
if (req[0] == kWriteTo) {
cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType),
@@ -360,7 +369,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads + 2 * head_dim,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -381,7 +391,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
attention_maps_grads,
qkv_seq_len,
qkv_seq_len * qkv_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}
@@ -412,6 +423,7 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_kv = head_dim * 2;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -433,7 +445,8 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
output,
kv_seq_len,
kv_seq_len * q_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -463,6 +476,7 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_q = head_dim;
const int32_t batch_stride_kv = head_dim * 2;
const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] != kNullOp) {
const float beta = req[0] == kAddTo ? 1.f : 0.f;
@@ -483,7 +497,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
queries_grads,
lead_dim_q,
batch_stride_q,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
if (req[1] == kWriteTo) {
@@ -508,7 +523,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
keys_values_grads,
lead_dim_kv,
batch_stride_kv,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}
@@ -535,6 +551,7 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_kv = 2 * head_dim;
const float alpha = 1.f;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -556,7 +573,8 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
output,
head_dim * attn_batches,
head_dim,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -583,6 +601,7 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim_kv = attn_batches * head_dim * 2;
const int32_t batch_stride_kv = 2 * head_dim;
const float alpha = 1.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] != kNullOp) {
if (req[0] == kWriteTo) {
@@ -607,7 +626,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
keys_values_grads + head_dim,
lead_dim_kv,
batch_stride_kv,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -628,7 +648,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
attention_maps_grads,
kv_seq_len,
kv_seq_len * q_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}