Skip to content

Commit

Permalink
Convert transa and transb parameters to ints
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Hatfield committed Jun 25, 2019
1 parent b81e954 commit 935b4d3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
7 changes: 2 additions & 5 deletions cublas_gemm_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ __global__ void double2half(half *out, const double *in, int n) {

// Performs matrix-matrix multiplication using Tensor Core.
extern "C" {
void tcgemm_c(char transa, char transb, int m, int n, int k, float alpha, void *a_p, int lda, void *b_p,
void tcgemm_c(int transa, int transb, int m, int n, int k, float alpha, void *a_p, int lda, void *b_p,
int ldb, float beta, void *c_p, int ldc) {

// Set up host-side arrays
Expand Down Expand Up @@ -70,14 +70,11 @@ extern "C" {

cudaDeviceSynchronize();

cublasOperation_t transa_op = (transa == 'N' || transa == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t transb_op = (transb == 'N' || transb == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;

// Perform GEMM with Tensor Core
cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
cublasErrCheck(
cublasGemmEx(
cublasHandle, transa_op, transb_op,
cublasHandle, (cublasOperation_t)transa, (cublasOperation_t)transb,
m, n, k,
&alpha,
a_d_16, CUDA_R_16F, lda,
Expand Down
11 changes: 8 additions & 3 deletions cublas_gemm_f.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ module cublas_gemm_f
!> Perform matrix-matrix multiplication using Tensor Core (interface for C
! function).
subroutine tcgemm_c(transa, transb, m, n, k, alpha, a_p, lda, b_p, ldb, beta, c_p, ldc) &
& bind(c, name="tcgemm_c")
& bind(c)
use iso_c_binding, only: c_char, c_int, c_float, c_ptr
character(kind=c_char), value :: transa, transb
integer(kind=c_int), value :: transa, transb
integer(kind=c_int), value :: m, n, k, lda, ldb, ldc
real(kind=c_float), value :: alpha, beta
type(c_ptr), value :: a_p, b_p, c_p
Expand All @@ -28,13 +28,18 @@ subroutine tcgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
real(8), target :: a(:,:), b(:,:)
real(4), target :: c(:,:)
type(c_ptr) :: a_p, b_p, c_p
integer :: transa_l, transb_l

! Copy data to C pointers
a_p = c_loc(a(1,1))
b_p = c_loc(b(1,1))
c_p = c_loc(c(1,1))

! TODO: Figure out how to pass single-character strings to C. Until then, map
! each transpose flag to an integer: 0 for "N"/"n" (no transpose), 1 otherwise.
transa_l = merge(0, 1, transa == "N" .or. transa == "n")
transb_l = merge(0, 1, transb == "N" .or. transb == "n")

! Call C function
call tcgemm_c("N", "N", n, n, n, alpha, a_p, n, b_p, n, beta, c_p, n)
call tcgemm_c(transa_l, transb_l, n, n, n, alpha, a_p, n, b_p, n, beta, c_p, n)
end subroutine
end module

0 comments on commit 935b4d3

Please sign in to comment.