Split up GPU initialization into separate subroutine
Sam Hatfield committed Jul 29, 2019
1 parent badf17b commit bd325f8
Showing 3 changed files with 24 additions and 8 deletions.
cublas_gemm_c.cu: 19 changes (12 additions & 7 deletions)
@@ -25,6 +25,18 @@ __global__ void double2half(half *out, const double *in, int n) {
}
}

+cublasHandle_t cublasHandle;
+
+// Sets up GPU and cuBLAS
+extern "C" {
+void init_gpu_c() {
+cudaSetDevice(0);
+cudaDeviceReset();
+cublasErrCheck(cublasCreate(&cublasHandle));
+cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
+}
+}
+
// Performs matrix-matrix multiplication using Tensor Core.
extern "C" {
void tcgemm_c(int transa, int transb, int m, int n, int k, float alpha, void *a_p, int lda, void *b_p,
@@ -41,12 +53,6 @@ extern "C" {
// Compute GEMM using Tensor Core
// =========================================================================

-// Set up GPU and cuBLAS
-cublasHandle_t cublasHandle;
-cudaSetDevice(0);
-cudaDeviceReset();
-cublasErrCheck(cublasCreate(&cublasHandle));
-
// Set up device-side arrays
double *a_d, *b_d;
half *a_d_16, *b_d_16;
@@ -71,7 +77,6 @@ extern "C" {
cudaDeviceSynchronize();

// Perform GEMM with Tensor Core
-cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
cublasErrCheck(
cublasGemmEx(
cublasHandle, (cublasOperation_t)transa, (cublasOperation_t)transb,
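
A note on cublasErrCheck: the helper wrapped around each cuBLAS call above is defined elsewhere in cublas_gemm_c.cu (or an included header), outside the hunks shown, so its exact definition is not visible here. A typical checker of this kind, assumed for illustration only, looks like the sketch below.

    #include <cstdio>
    #include <cublas_v2.h>

    // Assumed sketch of an error-check helper like the cublasErrCheck used above;
    // the macro records the call site so a failing cuBLAS call reports file and line.
    static void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
        if (stat != CUBLAS_STATUS_SUCCESS) {
            fprintf(stderr, "cuBLAS error %d at %s:%d\n", (int)stat, file, line);
        }
    }
    #define cublasErrCheck(stat) cublasErrCheck_((stat), __FILE__, __LINE__)

    // Minimal usage: wrap any call that returns a cublasStatus_t.
    int main() {
        cublasHandle_t handle;
        cublasErrCheck(cublasCreate(&handle));
        cublasErrCheck(cublasDestroy(handle));
        return 0;
    }
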
cublas_gemm_f.f90: 9 changes (9 additions & 0 deletions)
@@ -16,7 +16,16 @@ subroutine tcgemm_c(transa, transb, m, n, k, alpha, a_p, lda, b_p, ldb, beta, c_
end subroutine
end interface

+interface
+subroutine init_gpu_c() bind(c)
+end subroutine
+end interface
+
contains
+subroutine init_gpu
+call init_gpu_c
+end subroutine
+
!> Perform matrix-matrix multiplication using Tensor Core (wrapper for C
! function).
subroutine tcgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
matmul_test.f90: 4 changes (3 additions & 1 deletion)
@@ -1,5 +1,5 @@
program matmul_test
-use cublas_gemm_f, only: tcgemm
+use cublas_gemm_f, only: init_gpu, tcgemm

implicit none

@@ -39,6 +39,8 @@ program matmul_test
! Device DGEMM (with transpose)
! =========================================================================

+call init_gpu
+
! Call Tensor Core GEMM routine
call cpu_time(tick)
call tcgemm("N", "T", m, m, n, 1.0, a2, m, b2, m, 0.0, c2, m)
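
Taken together, the three files now follow an initialise-once pattern: matmul_test calls init_gpu a single time, and every subsequent tcgemm call reuses the handle and math mode set up there. The sketch below illustrates that pattern in plain CUDA C++ with hypothetical names (handle, init_gpu_once, gemm_fp16) and an illustrative cublasGemmEx call (FP16 inputs, FP32 accumulation, tensor-op algorithm, CUDA 10-era constants); it is not the repository's code.

    #include <cstdio>
    #include <cuda_runtime.h>
    #include <cuda_fp16.h>
    #include <cublas_v2.h>

    // Schematic of the init-once-then-reuse structure; names are hypothetical.
    static cublasHandle_t handle;

    static void init_gpu_once() {
        // One-time setup: pick a device, create the handle, enable tensor-op math.
        cudaSetDevice(0);
        cublasCreate(&handle);
        cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
    }

    static void gemm_fp16(int n) {
        // FP16 inputs, FP32 accumulation/output, tensor-op algorithm.
        half *a, *b;
        float *c;
        cudaMalloc(&a, n * n * sizeof(half));
        cudaMalloc(&b, n * n * sizeof(half));
        cudaMalloc(&c, n * n * sizeof(float));
        cudaMemset(a, 0, n * n * sizeof(half));   // all-zero bits are a valid FP16 zero
        cudaMemset(b, 0, n * n * sizeof(half));

        float alpha = 1.0f, beta = 0.0f;
        cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                     &alpha, a, CUDA_R_16F, n, b, CUDA_R_16F, n,
                     &beta, c, CUDA_R_32F, n,
                     CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
        cudaDeviceSynchronize();

        cudaFree(a);
        cudaFree(b);
        cudaFree(c);
    }

    int main() {
        init_gpu_once();                   // initialise once up front...
        for (int i = 0; i < 3; i++) {
            gemm_fp16(256);                // ...then reuse the handle for each GEMM
        }
        cublasDestroy(handle);
        printf("done\n");
        return 0;
    }

Error checking is omitted here for brevity.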
