-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
It compiles but currently throws a runtime error.
- Loading branch information
Sam Hatfield
committed
Jun 25, 2019
1 parent
9e54650
commit 54da1d3
Showing
3 changed files
with
85 additions
and
14 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
#include <stdio.h> | ||
#include <cublas_v2.h> | ||
|
||
// Reports CUDA runtime failures.
// cudaErrCheck(expr) forwards the status produced by expr, together with the
// file and line of the call site, to cudaErrCheck_.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    // Nothing to report when the call succeeded.
    if (stat == cudaSuccess) {
        return;
    }

    // A failure is only logged to stderr; control returns to the caller.
    const char *description = cudaGetErrorString(stat);
    fprintf(stderr, "CUDA Error: %s %s %d\n", description, file, line);
}
|
||
// Reports cuBLAS failures.
// cublasErrCheck(expr) forwards the status produced by expr, together with the
// file and line of the call site, to cublasErrCheck_.
#define cublasErrCheck(stat) { cublasErrCheck_((stat), __FILE__, __LINE__); }
void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
    // Nothing to report when the call succeeded.
    if (stat == CUBLAS_STATUS_SUCCESS) {
        return;
    }

    // The numeric status code is logged to stderr; control returns to the caller.
    fprintf(stderr, "cuBLAS Error: %d %s %d\n", stat, file, line);
}
|
||
// Number of doubles spanned in memory by a column-major rows x cols matrix
// stored with leading dimension ld: (cols - 1) full columns of ld elements
// plus the `rows` used elements of the last column. Returns 0 for an empty
// matrix.
static size_t tcgemm_span_(int rows, int cols, int ld) {
    if (rows <= 0 || cols <= 0) {
        return 0;
    }
    return (size_t)ld * (size_t)(cols - 1) + (size_t)rows;
}

// Double-precision GEMM on the GPU through cuBLAS:
//
//     C = alpha * op(A) * op(B) + beta * C
//
// with op(A) of size m x k, op(B) of size k x n and C of size m x n, all
// column-major (the cuBLAS convention).
//
//   transa, transb : 'N'/'n' selects op(X) = X; any other character selects
//                    the transpose.
//   m, n, k        : problem dimensions as above.
//   alpha, beta    : scalars, passed by value.
//   a_p, b_p       : host pointers to the input matrices (doubles), read only.
//   lda, ldb, ldc  : leading dimensions of A, B and C as stored on the host.
//   c_p            : host pointer to C (doubles); read as input and
//                    overwritten in place with the result.
//
// Errors from CUDA or cuBLAS are reported through cudaErrCheck /
// cublasErrCheck; after the first failure the remaining steps are skipped,
// device resources are released and C on the host is left unmodified.
//
// NOTE(review): the original comment said this uses "Tensor Core". With
// CUDA_R_64F data and CUBLAS_GEMM_DEFAULT the choice of kernel is left to
// cuBLAS and the hardware - TODO confirm whether that is what is intended.
extern "C" {
void tcgemm_c(char transa, char transb, int m, int n, int k, double alpha, void *a_p, int lda, void *b_p,
              int ldb, double beta, void *c_p, int ldc) {

// Local helpers: report a failure via the file's error-check macros and jump
// to the common cleanup path so nothing runs on top of a failed step.
#define TCGEMM_CUDA_TRY(call) \
    { cudaError_t s_ = (call); cudaErrCheck(s_); if (s_ != cudaSuccess) goto cleanup; }
#define TCGEMM_CUBLAS_TRY(call) \
    { cublasStatus_t s_ = (call); cublasErrCheck(s_); if (s_ != CUBLAS_STATUS_SUCCESS) goto cleanup; }

    // Reject nonsensical sizes before they are turned into byte counts.
    if (m < 0 || n < 0 || k < 0) {
        fprintf(stderr, "tcgemm_c: negative dimension (m=%d, n=%d, k=%d)\n", m, n, k);
        return;
    }
    // An empty C means there is nothing to compute.
    if (m == 0 || n == 0) {
        return;
    }

    // Host-side views of the caller's arrays.
    const double *a_h = (const double *)a_p;
    const double *b_h = (const double *)b_p;
    double *c_h = (double *)c_p;

    const cublasOperation_t transa_int = (transa == 'N' || transa == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;
    const cublasOperation_t transb_int = (transb == 'N' || transb == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T;

    // Shapes of A and B *as stored*: a transposed operand is stored with its
    // dimensions swapped relative to op(X).
    const int a_rows = (transa_int == CUBLAS_OP_N) ? m : k;
    const int a_cols = (transa_int == CUBLAS_OP_N) ? k : m;
    const int b_rows = (transb_int == CUBLAS_OP_N) ? k : n;
    const int b_cols = (transb_int == CUBLAS_OP_N) ? n : k;

    // Byte counts honour the leading dimensions so the device copies have the
    // same layout as the host arrays and lda/ldb/ldc can be passed unchanged.
    // size_t arithmetic avoids int overflow for large matrices.
    const size_t a_bytes = tcgemm_span_(a_rows, a_cols, lda) * sizeof(double);
    const size_t b_bytes = tcgemm_span_(b_rows, b_cols, ldb) * sizeof(double);
    const size_t c_bytes = tcgemm_span_(m, n, ldc) * sizeof(double);

    // Everything the cleanup path touches is declared (and nulled) up front so
    // that `goto cleanup` never jumps across an initialisation.
    cublasHandle_t cublasHandle = NULL;
    double *a_d = NULL;
    double *b_d = NULL;
    double *c_d = NULL;

    // Set up GPU and cuBLAS. cudaDeviceReset() is intentionally not called:
    // it would destroy every CUDA resource owned by the calling process.
    TCGEMM_CUDA_TRY(cudaSetDevice(0));
    TCGEMM_CUBLAS_TRY(cublasCreate(&cublasHandle));

    // Allocate device buffers (at least one element each so that k == 0 still
    // yields valid pointers to hand to cuBLAS).
    TCGEMM_CUDA_TRY(cudaMalloc((void **)&a_d, a_bytes ? a_bytes : sizeof(double)));
    TCGEMM_CUDA_TRY(cudaMalloc((void **)&b_d, b_bytes ? b_bytes : sizeof(double)));
    TCGEMM_CUDA_TRY(cudaMalloc((void **)&c_d, c_bytes ? c_bytes : sizeof(double)));

    // Copy inputs to the device. C is uploaded too: it is an input whenever
    // beta != 0, and uploading it keeps any ldc > m padding intact when the
    // whole span is copied back.
    TCGEMM_CUDA_TRY(cudaMemcpy(a_d, a_h, a_bytes, cudaMemcpyHostToDevice));
    TCGEMM_CUDA_TRY(cudaMemcpy(b_d, b_h, b_bytes, cudaMemcpyHostToDevice));
    TCGEMM_CUDA_TRY(cudaMemcpy(c_d, c_h, c_bytes, cudaMemcpyHostToDevice));

    // Perform GEMM on the *device* buffers; alpha and beta are host scalars.
    TCGEMM_CUBLAS_TRY(
        cublasGemmEx(
            cublasHandle, transa_int, transb_int,
            m, n, k,
            &alpha,
            a_d, CUDA_R_64F, lda,
            b_d, CUDA_R_64F, ldb,
            &beta,
            c_d, CUDA_R_64F, ldc,
            CUDA_R_64F,
            CUBLAS_GEMM_DEFAULT
        )
    );

    // Copy the result back: destination is the host array, source the device
    // buffer. This blocking copy also waits for the GEMM to finish.
    TCGEMM_CUDA_TRY(cudaMemcpy(c_h, c_d, c_bytes, cudaMemcpyDeviceToHost));

cleanup:
    // Release device resources on every path; cudaFree(NULL) is a no-op.
    cudaErrCheck(cudaFree(a_d));
    cudaErrCheck(cudaFree(b_d));
    cudaErrCheck(cudaFree(c_d));
    if (cublasHandle != NULL) {
        cublasErrCheck(cublasDestroy(cublasHandle));
    }

#undef TCGEMM_CUDA_TRY
#undef TCGEMM_CUBLAS_TRY
}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
# Default compilers | ||
FC = gfortran | ||
NVCC = gcc | ||
NVCC = nvcc | ||
|
||
matmul_test: matmul_test.o cublas_gemm_f.o cublas_gemm_c.o | ||
$(FC) matmul_test.o cublas_gemm_f.o cublas_gemm_c.o -o matmul_test #-L$(CUDA)/lib64 -lcudart -lcublas -o matmul_test | ||
$(FC) matmul_test.o cublas_gemm_f.o cublas_gemm_c.o -L$(CUDA)/lib64 -lcudart -lcublas -o matmul_test | ||
|
||
matmul_test.o: cublas_gemm_f.o | ||
|
||
%.o: %.f90 | ||
$(FC) -c $< -o $@ | ||
|
||
%.o: %.c | ||
$(NVCC) -c $< -std=c99 | ||
%.o: %.cu | ||
$(NVCC) -c $< | ||
|
||
.PHONY: clean | ||
clean: | ||
rm -f *.o matmul_test | ||
rm -f *.o *.mod matmul_test |