Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Perf] Add CublasLt extern support for better Igemm performance #4550

Merged
merged 6 commits into from
Dec 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions python/tvm/contrib/cublaslt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External function interface to cuBLASlt libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS

Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs

Returns
-------
C : Tensor
The result tensor.
"""
if n == 0:
n = lhs.shape[1] if transa else lhs.shape[0]
if m == 0:
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublaslt.matmul",
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
111 changes: 109 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,98 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
}
}

int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}

#if CUDART_VERSION >= 10010
inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
DLTensor *C = args[2];
bool transa = args[3];
bool transb = args[4];
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed(A) ? !transa : transa;
transb = IsInPlaceTransposed(B) ? !transb : transb;
int M = ColumnCount(B, transb);
int N = RowCount(A, transa);
int K = ColumnCount(A, transa);
int N_out = ColumnCount(C, false);
int m = M;
int n = m;
int k = m;
int lda = M * K / (roundoff(K, 32) / 32);
int ldb = K * N / (roundoff(K, 32) / 32);
int ldc = M * N_out / (roundoff(N_out, 32) / 32);
CHECK_EQ(A->ndim, 2);
CHECK_EQ(B->ndim, 2);
CHECK_EQ(C->ndim, 2);

CHECK_EQ(ElementStride(A), 1);
CHECK_EQ(ElementStride(B), 1);
CHECK_EQ(ElementStride(C), 1);

CHECK(TypeEqual(A->dtype, B->dtype));
CHECK(TypeMatch(A->dtype, kDLInt, 8));
CHECK(TypeMatch(C->dtype, kDLInt, 32));

CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
int32_t alpha = args.size() > 5 ? args[5] : 1;
int32_t beta = args.size() > 6 ? args[6] : 0;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
auto A_data = reinterpret_cast<void*>(static_cast<char*>(A->data) + A->byte_offset);
auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);

cublasOperation_t opTranspose = CUBLAS_OP_T;
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtMatmulDesc_t operationDesc = nullptr;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose)));
cublasOperation_t opTransA = BooleanToTranspose(transa);
cublasOperation_t opTransB = BooleanToTranspose(transb);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB)));
// Create descriptors for the original matrices
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k ,
opTransA == CUBLAS_OP_N ? k : m, lda));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(
&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n ,
opTransB == CUBLAS_OP_N ? n : k, ldb));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));

CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32)));

CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl,
operationDesc,
&alpha,
B_data,
Adesc,
A_data,
Bdesc,
&beta,
C_data,
Cdesc,
C_data,
Cdesc,
NULL,
NULL,
0,
0));
}
#endif

inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
DLTensor *A = args[0];
DLTensor *B = args[1];
Expand Down Expand Up @@ -342,12 +434,27 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul")
}
});

TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
#if CUDART_VERSION >= 10010
TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul")
masahi marked this conversation as resolved.
Show resolved Hide resolved
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* C = args[2];

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

TryEnableTensorCore(entry_ptr->handle);

CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
cublasLtHandle_t ltHandle;
CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
CallLtIgemm(args, ret, ltHandle);
CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
});
#endif // CUDART_VERSION >= 10010

TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* C = args[2];

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();

Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/cublas/cublas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
#include <dmlc/logging.h>
#include <dlpack/dlpack.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#if CUDART_VERSION >= 10010
#include <cublasLt.h>
#endif // CUDART_VERSION >= 10010

namespace tvm {
namespace contrib {
Expand Down
65 changes: 64 additions & 1 deletion tests/python/contrib/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import numpy as np
from tvm.contrib import cublas
from tvm.contrib import cublaslt

def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5):
n = 1024
Expand Down Expand Up @@ -44,6 +45,64 @@ def verify(target="cuda"):
c.asnumpy(), np.dot(a.asnumpy().astype(C.dtype), b.asnumpy().astype(C.dtype)), rtol=rtol)
verify()

def roundoff(v, d):
return int(np.floor((v + d - 1) / d) * d)

def verify_matmul_add_igemm(in_dtype, out_dtype, rtol=1e-5):
n = 1024
l = 1024
m = 1024
L = roundoff(l, 32)
N = roundoff(n, 8)
N_out = roundoff(n, 32)

A = tvm.placeholder((N, L), name='A', dtype=in_dtype)
B = tvm.placeholder((m, L), name='B', dtype=in_dtype)
# C has CUBLASLT_ORDER_COL32 layout, thus a different shape
C = cublaslt.matmul(A, B, False, True, m, N_out, dtype=out_dtype)
s = tvm.create_schedule(C.op)

def verify(target="cuda"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cublaslt.matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.gpu(0)
f = tvm.build(s, [A, B, C], target)
a_old = np.random.uniform(0, 128, size=(n, l))
b_old = np.random.uniform(0, 128, size=(l, m))

# Transform a to become CUBLASLT_ORDER_COL4_4R2_8C layout
a_new = np.hstack((a_old.astype(A.dtype), np.zeros([n, L-l])))
a_new = np.vstack((a_new.astype(A.dtype), np.zeros([N-n, L])))
a_even = np.vsplit(a_new[::2], N / 8)
a_odd = np.vsplit(a_new[1::2], N / 8)
a_new = [None]*(len(a_even) + len(a_odd))
a_new[::2] = a_even
a_new[1::2] = a_odd
a_new = np.vstack(a_new)
a_new = np.vstack(np.vstack(np.vstack(np.hsplit(i, 8)).reshape([4, 32]) for i in np.vsplit(j, N/4)) for j in np.hsplit(a_new, L/32))
a_new = a_new.reshape([N, L])
# Transform b to become CUBLASLT_ORDER_COL32 layout
b_new = np.vstack(np.hsplit(np.hstack((b_old.T.astype(B.dtype), np.zeros([m, L - l]))), L / 32))
b_new = b_new.reshape([m, L])

a = tvm.nd.array(a_new.astype(A.dtype), ctx)
b = tvm.nd.array(b_new.astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((m, N_out), dtype=C.dtype), ctx)
f(a, b, c)
# Transform output c from layout CUBLASLT_ORDER_COL32 to row major layout
c_out = c.asnumpy()
c_out = c_out.reshape([int(m * N_out / 32), 32])
c_out = np.hstack(np.vsplit(c_out, int(N_out / 32)))
c_out = c_out[:, :n]
c_out = c_out.T
tvm.testing.assert_allclose(
c_out, np.dot(a_old.astype(C.dtype), b_old.astype(C.dtype)), rtol=rtol)
verify()

def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
j = 16
n = 1024
Expand Down Expand Up @@ -73,11 +132,14 @@ def verify(target="cuda"):
verify()

def test_matmul_add():
verify_matmul_add('float', 'float')
verify_matmul_add('float', 'float', rtol=1e-3)
masahi marked this conversation as resolved.
Show resolved Hide resolved
verify_matmul_add('float16', 'float')
verify_matmul_add('float16', 'float16', rtol=1e-2)
verify_matmul_add('int8', 'int32')

def test_matmul_add_igemm():
verify_matmul_add_igemm('int8', 'int32')

def test_batch_matmul():
verify_batch_matmul('float', 'float')
verify_batch_matmul('float16', 'float')
Expand All @@ -86,4 +148,5 @@ def test_batch_matmul():
if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
test_matmul_add_igemm()