Switch to Tensor Core ("tcgemm")
Sam Hatfield committed Jun 25, 2019
1 parent f462007 commit b81e954
Showing 4 changed files with 47 additions and 28 deletions.
55 changes: 36 additions & 19 deletions cublas_gemm_c.cu
@@ -17,16 +17,25 @@ void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
 }
 }
 
+// Converts from double-precision to half-precision (CUDA kernel)
+__global__ void double2half(half *out, const double *in, int n) {
+    int idx = blockDim.x * blockIdx.x + threadIdx.x;
+    if (idx < n) {
+        out[idx] = __float2half((float)(in[idx]));
+    }
+}
+
 // Performs matrix-matrix multiplication using Tensor Core.
 extern "C" {
-    void tcgemm_c(char transa, char transb, int m, int n, int k, double alpha, void *a_p, int lda, void *b_p,
-            int ldb, double beta, void *c_p, int ldc) {
+    void tcgemm_c(char transa, char transb, int m, int n, int k, float alpha, void *a_p, int lda, void *b_p,
+            int ldb, float beta, void *c_p, int ldc) {
 
     // Set up host-side arrays
-    double *a_h, *b_h, *c_h;
+    double *a_h, *b_h;
+    float *c_h;
     a_h = (double *)a_p;
     b_h = (double *)b_p;
-    c_h = (double *)c_p;
+    c_h = (float *)c_p;
 
     // =========================================================================
     // Compute GEMM using Tensor Core
@@ -39,48 +48,56 @@ extern "C" {
     cublasErrCheck(cublasCreate(&cublasHandle));
 
     // Set up device-side arrays
-    double *a_d, *b_d, *c_d;
+    double *a_d, *b_d;
+    half *a_d_16, *b_d_16;
+    float *c_d_32;
 
     // Allocate memory on device for all arrays
     // TODO: should the dimensions used below (m*k etc.) take into account transa, lda etc.?
     cudaErrCheck(cudaMalloc((void **)&a_d, m*k*sizeof(double)));
     cudaErrCheck(cudaMalloc((void **)&b_d, k*n*sizeof(double)));
-    cudaErrCheck(cudaMalloc((void **)&c_d, m*n*sizeof(double)));
+    cudaErrCheck(cudaMalloc((void**)&a_d_16, m*k*sizeof(half)));
+    cudaErrCheck(cudaMalloc((void**)&b_d_16, k*n*sizeof(half)));
+    cudaErrCheck(cudaMalloc((void**)&c_d_32, m*n*sizeof(float)));
 
     // Copy input arrays to device
     cudaErrCheck(cudaMemcpy(a_d, a_h, m*k*sizeof(double), cudaMemcpyHostToDevice));
     cudaErrCheck(cudaMemcpy(b_d, b_h, k*n*sizeof(double), cudaMemcpyHostToDevice));
 
+    // Convert arrays to half-precision
+    double2half<<<(int)((m*k)/256) + 1, 256 >>>(a_d_16, a_d, m*k);
+    double2half<<<(int)((k*n)/256) + 1, 256 >>>(b_d_16, b_d, k*n);
+
+    cudaDeviceSynchronize();
+
     cublasOperation_t transa_op = (transa == 'N' || transa == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
     cublasOperation_t transb_op = (transb == 'N' || transb == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
 
-    // Perform GEMM
+    // Perform GEMM with Tensor Core
+    cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
     cublasErrCheck(
         cublasGemmEx(
             cublasHandle, transa_op, transb_op,
             m, n, k,
             &alpha,
-            a_d, CUDA_R_64F, lda,
-            b_d, CUDA_R_64F, ldb,
+            a_d_16, CUDA_R_16F, lda,
+            b_d_16, CUDA_R_16F, ldb,
             &beta,
-            c_d, CUDA_R_64F, ldc,
-            CUDA_R_64F,
-            CUBLAS_GEMM_DEFAULT
+            c_d_32, CUDA_R_32F, ldc,
+            CUDA_R_32F,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP
         )
     );
 
     // Copy results back from device to host
-    cudaErrCheck(cudaMemcpy(c_h, c_d, m*n*sizeof(double), cudaMemcpyDeviceToHost));
+    cudaErrCheck(cudaMemcpy(c_h, c_d_32, m*n*sizeof(float), cudaMemcpyDeviceToHost));
     cudaDeviceSynchronize();
 
     // Free memory on device
     cudaErrCheck(cudaFree(a_d));
     cudaErrCheck(cudaFree(b_d));
-    cudaErrCheck(cudaFree(c_d));
-
-    // =========================================================================
-
-    // Set incoming C array pointer
-    //c_p = (void *)c_h;
+    cudaErrCheck(cudaFree(a_d_16));
+    cudaErrCheck(cudaFree(b_d_16));
+    cudaErrCheck(cudaFree(c_d_32));
     }
 }
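
An aside on the launch configuration above: the grid size (int)((m*k)/256) + 1 launches one spare block whenever the element count is an exact multiple of 256. That is harmless here because of the idx < n guard, but the usual ceiling-division idiom avoids it. Below is a minimal, self-contained sketch of the same conversion kernel with that idiom and an explicit launch-error check; the main driver and the array size are illustrative only, not part of this commit.

#include <cstdio>
#include <cuda_fp16.h>

// Same double -> half conversion as in the commit
__global__ void double2half(half *out, const double *in, int n) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx < n) {
        out[idx] = __float2half((float)(in[idx]));
    }
}

int main() {
    const int n = 1024;  // illustrative size
    double *in_d;
    half *out_d;
    cudaMalloc((void **)&in_d, n * sizeof(double));
    cudaMalloc((void **)&out_d, n * sizeof(half));
    cudaMemset(in_d, 0, n * sizeof(double));  // give the kernel defined input

    // Ceiling division: exactly enough blocks, even when n is a multiple of 256
    const int block = 256;
    const int grid = (n + block - 1) / block;
    double2half<<<grid, block>>>(out_d, in_d, n);

    // Launch errors are reported asynchronously; check explicitly
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "double2half launch failed: %s\n", cudaGetErrorString(err));
    }
    cudaDeviceSynchronize();

    cudaFree(in_d);
    cudaFree(out_d);
    return 0;
}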
11 changes: 6 additions & 5 deletions cublas_gemm_f.f90
@@ -8,10 +8,10 @@ module cublas_gemm_f
 ! function).
 subroutine tcgemm_c(transa, transb, m, n, k, alpha, a_p, lda, b_p, ldb, beta, c_p, ldc) &
     & bind(c, name="tcgemm_c")
-    use iso_c_binding, only: c_char, c_int, c_double, c_ptr
+    use iso_c_binding, only: c_char, c_int, c_float, c_ptr
     character(kind=c_char), value :: transa, transb
     integer(kind=c_int), value :: m, n, k, lda, ldb, ldc
-    real(kind=c_double), value :: alpha, beta
+    real(kind=c_float), value :: alpha, beta
     type(c_ptr), value :: a_p, b_p, c_p
 end subroutine
 end interface
@@ -24,8 +24,9 @@ subroutine tcgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)

 character :: transa, transb
 integer :: m, n, k, lda, ldb, ldc
-real(8) :: alpha, beta
-real(8), target :: a(:,:), b(:,:), c(:,:)
+real(4) :: alpha, beta
+real(8), target :: a(:,:), b(:,:)
+real(4), target :: c(:,:)
 type(c_ptr) :: a_p, b_p, c_p
 
 ! Copy data to C pointers
@@ -34,6 +35,6 @@ subroutine tcgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
 c_p = c_loc(c(1,1))
 
 ! Call C function
-call tcgemm_c("N", "N", n, n, n, 1.0d0, a_p, n, b_p, n, 0.0d0, c_p, n)
+call tcgemm_c("N", "N", n, n, n, alpha, a_p, n, b_p, n, beta, c_p, n)
 end subroutine
 end module
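
For reference, here is the C prototype that this bind(c) interface mirrors (it appears in cublas_gemm_c.cu above), with the kind correspondence spelled out. The Fortran scalar kinds have to change in lockstep with the C signature: passing real(kind=c_double) values to float parameters would be an ABI mismatch, which is why alpha and beta moved to c_float here.

/* Prototype from cublas_gemm_c.cu that the interface block above mirrors.
 * Kind correspondence (all arguments are passed by value):
 *   character(kind=c_char), value  ->  char
 *   integer(kind=c_int), value     ->  int
 *   real(kind=c_float), value      ->  float
 *   type(c_ptr), value             ->  void *
 */
void tcgemm_c(char transa, char transb, int m, int n, int k, float alpha,
              void *a_p, int lda, void *b_p, int ldb, float beta,
              void *c_p, int ldc);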
2 changes: 1 addition & 1 deletion makefile
@@ -3,7 +3,7 @@ FC = gfortran
 NVCC = nvcc
 
 matmul_test: matmul_test.o cublas_gemm_f.o cublas_gemm_c.o
-	$(FC) matmul_test.o cublas_gemm_f.o cublas_gemm_c.o -L$(CUDA)/lib64 -lcudart -lcublas -o matmul_test
+	$(FC) matmul_test.o cublas_gemm_f.o cublas_gemm_c.o -L$(CUDA)/lib64 -lcudart -lcublas -lstdc++ -o matmul_test
 
 matmul_test.o: cublas_gemm_f.o

7 changes: 4 additions & 3 deletions matmul_test.f90
@@ -7,7 +7,8 @@ program matmul_test
 integer, parameter :: n = 10
 
 ! Host matrices
-real(8), dimension(n,n) :: a1, b1, c1, a2, b2, c2
+real(8), dimension(n,n) :: a1, b1, c1, a2, b2
+real(4), dimension(n,n) :: c2
 
 integer :: i, j

@@ -34,9 +35,9 @@ program matmul_test
 ! =========================================================================
 
 ! Call Tensor Core GEMM routine
-call tcgemm("N", "N", n, n, n, 1.0d0, a2, n, b2, n, 0.0d0, c2, n)
+call tcgemm("N", "N", n, n, n, 1.0, a2, n, b2, n, 0.0, c2, n)
 
-write (*,"(A35,F13.10)") "C matrix Frobenius norm (device) = ", frob_norm(c2)
+write (*,"(A35,F13.10)") "C matrix Frobenius norm (device) = ", frob_norm(real(c2,8))
 
 contains
 ! Computes Frobenius norm of input matrix
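
To exercise the updated interface without the Fortran layer, a minimal C driver might look like the following. This is hypothetical test code, not part of the commit; it assumes the tcgemm_c prototype from cublas_gemm_c.cu and a link line like the one in the makefile. Note the mixed precision the commit introduces: double-precision inputs, single-precision output, float scalars.

#include <stdio.h>
#include <stdlib.h>

/* Defined in cublas_gemm_c.cu */
void tcgemm_c(char transa, char transb, int m, int n, int k, float alpha,
              void *a_p, int lda, void *b_p, int ldb, float beta,
              void *c_p, int ldc);

int main(void) {
    const int n = 10;
    double *a = (double *)malloc(n * n * sizeof(double));
    double *b = (double *)malloc(n * n * sizeof(double));
    float  *c = (float *)malloc(n * n * sizeof(float));  /* result is FP32 */
    for (int i = 0; i < n * n; i++) { a[i] = 1.0; b[i] = 1.0; }

    tcgemm_c('N', 'N', n, n, n, 1.0f, a, n, b, n, 0.0f, c, n);

    /* With all-ones inputs, every entry of C should be ~n */
    printf("c[0] = %f (expected %d)\n", c[0], n);

    free(a); free(b); free(c);
    return 0;
}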
