Commit c697434

Update with batched gemm for openacc ccsd implementation

omarkahmed committed Jul 8, 2024
1 parent 80cff68 commit c697434
Showing 3 changed files with 323 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/ccsd/GNUmakefile
@@ -148,6 +148,9 @@ ifdef USE_OPENACC_TRPDRV
endif
endif

ifdef USE_BATCHDGEMM_TRPDRV
FOPTIONS += -DUSE_BATCHDGEMM_TRPDRV
endif

ifeq ($(ARMCI_NETWORK),MPI-PR)
LIB_DEFINES += -DACC_STRIPS
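The new block mirrors the USE_OPENACC_TRPDRV handling above: the batched path is off by default and is compiled in only when the variable is defined at build time, e.g. make USE_BATCHDGEMM_TRPDRV=1 (or the equivalent environment export, since GNU make's ifdef sees both), which adds -DUSE_BATCHDGEMM_TRPDRV to FOPTIONS and selects the cublasDgemmBatched_v2 code path in ccsd_trpdrv_openacc.F below.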
143 changes: 138 additions & 5 deletions src/ccsd/ccsd_trpdrv_openacc.F
@@ -110,9 +110,22 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
integer(INT32) :: shi
type(cublasHandle) :: handle(8)
integer(kind=cuda_stream_kind) :: stream(8)

integer(INT32) :: nv4, no4 ! cublasDgemm requires 32-bit integers
integer(INT32), parameter :: cu_op_n = CUBLAS_OP_N
integer(INT32), parameter :: cu_op_t = CUBLAS_OP_T
#ifdef USE_BATCHDGEMM_TRPDRV
integer(INT32), parameter :: batch1=4, batch2=4, batch3=8
type(c_devptr), device, dimension(batch1) :: a_array1
type(c_devptr), device, dimension(batch1) :: b_array1
type(c_devptr), device, dimension(batch1) :: c_array1
type(c_devptr), device, dimension(batch2) :: a_array2
type(c_devptr), device, dimension(batch2) :: b_array2
type(c_devptr), device, dimension(batch2) :: c_array2
type(c_devptr), device, dimension(batch3) :: a_array3
type(c_devptr), device, dimension(batch3) :: b_array3
type(c_devptr), device, dimension(batch3) :: c_array3
#endif
!
nodes = ga_nnodes()
me = ga_nodeid()
@@ -133,6 +146,7 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
err = cublasSetStream(handle(shi), stream(shi))
if (err.ne.0) call errquit('cublasSetStream',err,UNKNOWN_ERR)
end do

!
! device-only temp arrays
! produced by DGEMM, consumed by TENGY
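The comment above records the data-flow design: the f1n/f2n/f3n/f4n and f1t/f2t/f3t/f4t work arrays live only in GPU memory, so DGEMM output feeds the TENGY kernels without a host round trip. A minimal standalone sketch of that pattern (illustrative only, not the commit's code; assumes nvfortran with -cuda, and the array name is hypothetical):

      program device_scratch
      use cudafor
      implicit none
      integer, parameter :: n = 64
      real(8), device, allocatable :: f1n_d(:,:)
      integer :: err
      ! allocated in GPU memory only; no host mirror to keep in sync
      allocate(f1n_d(n,n))
      ! assignment to a device array executes on the device
      f1n_d = 0.0d0
      err = cudaDeviceSynchronize()
      deallocate(f1n_d)
      end program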
@@ -404,7 +418,6 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
if (err.ne.0) then
call errquit('cudaMemcpyAsync',err,UNKNOWN_ERR)
endif

! arrays and thus copies contribute to more than one CUBLAS call
! but the copies on streams 1:4 and 5:8 are separable.
do shi=1,4
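The point of the comment above is that host-to-device copies issued on different streams can run concurrently, so the transfers feeding the first four cuBLAS handles need not wait for those feeding the other four. A reduced sketch of the mechanism (illustrative, not the commit's code; assumes nvfortran with -cuda; all names are hypothetical):

      program stream_overlap
      use cudafor
      implicit none
      integer, parameter :: n = 2**20
      real(8), allocatable, pinned :: h1(:), h2(:)
      real(8), allocatable, device :: d1(:), d2(:)
      integer(kind=cuda_stream_kind) :: s1, s2
      integer :: err
      allocate(h1(n), h2(n), d1(n), d2(n))
      h1 = 1.0d0
      h2 = 2.0d0
      err = cudaStreamCreate(s1)
      err = cudaStreamCreate(s2)
      ! independent copies on separate streams may proceed concurrently
      err = cudaMemcpyAsync(d1, h1, n, s1)
      err = cudaMemcpyAsync(d2, h2, n, s2)
      err = cudaDeviceSynchronize()
      end program

Pinned host buffers are what let the copies be truly asynchronous; a pageable source would force the runtime to synchronize.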
@@ -443,7 +456,127 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
else
call qexit('accwait',0)
endif
#ifdef USE_BATCHDGEMM_TRPDRV

! the first two batches can be executed simultaneously

a_array1(1) = c_devloc(xJia(:))
a_array1(2) = c_devloc(xKia(:))
a_array1(3) = c_devloc(xJka(1+(k-klo)*lnvv:))
a_array1(4) = c_devloc(xKka(1+(k-klo)*lnvv:))
a_array2(1) = c_devloc(xJia(:))
a_array2(2) = c_devloc(xKia(:))
a_array2(3) = c_devloc(xJka(1+(k-klo)*lnvv:))
a_array2(4) = c_devloc(xKka(1+(k-klo)*lnvv:))

b_array1(1) = c_devloc(xTkj(1+(k-klo)*lnvv:))
b_array1(2) = c_devloc(xTkj(1+(k-klo)*lnvv:))
b_array1(3) = c_devloc(xTij(:))
b_array1(4) = c_devloc(xTij(:))

b_array2(1) = c_devloc(xTkj(1+(k-klo)*lnvv:))
b_array2(2) = c_devloc(xTkj(1+(k-klo)*lnvv:))
b_array2(3) = c_devloc(xTij(:))
b_array2(4) = c_devloc(xTij(:))

c_array1(1) = c_devloc(f1n(:,:))
c_array1(2) = c_devloc(f2n(:,:))
c_array1(3) = c_devloc(f1t(:,:))
c_array1(4) = c_devloc(f2t(:,:))

c_array2(1) = c_devloc(f3n(:,:))
c_array2(2) = c_devloc(f4n(:,:))
c_array2(3) = c_devloc(f3t(:,:))
c_array2(4) = c_devloc(f4t(:,:))

! the third batch must run afterwards: with beta = 1 it
! accumulates into the f arrays written by the first two

a_array3(1) = c_devloc(xTia(:))
a_array3(2) = c_devloc(xXia(:))
a_array3(3) = c_devloc(xTia(:))
a_array3(4) = c_devloc(xXia(:))
a_array3(5) = c_devloc(xTka(1+(k-klo)*lnov:))
a_array3(6) = c_devloc(xXka(1+(k-klo)*lnov:))
a_array3(7) = c_devloc(xTka(1+(k-klo)*lnov:))
a_array3(8) = c_devloc(xXka(1+(k-klo)*lnov:))

b_array3(1) = c_devloc(xKkj(1+(k-klo)*lnov:))
b_array3(2) = c_devloc(xKkj(1+(k-klo)*lnov:))
b_array3(3) = c_devloc(xJkj(1+(k-klo)*lnov:))
b_array3(4) = c_devloc(xJkj(1+(k-klo)*lnov:))
b_array3(5) = c_devloc(xKij(:))
b_array3(6) = c_devloc(xKij(:))
b_array3(7) = c_devloc(xJij(:))
b_array3(8) = c_devloc(xJij(:))

c_array3(1) = c_devloc(f1n(:,:))
c_array3(2) = c_devloc(f2n(:,:))
c_array3(3) = c_devloc(f3n(:,:))
c_array3(4) = c_devloc(f4n(:,:))
c_array3(5) = c_devloc(f1t(:,:))
c_array3(6) = c_devloc(f2t(:,:))
c_array3(7) = c_devloc(f3t(:,:))
c_array3(8) = c_devloc(f4t(:,:))

err = cublasDgemmBatched_v2(
& handle(1),
& CUBLAS_OP_N,
& CUBLAS_OP_T,
& nv4, nv4, nv4,
& 1.0d0,
& a_array1, nv4,
& b_array1, nv4,
& 0.0d0,
& c_array1, nv4,
& batch1)
if (err.ne.0) then
call errquit('cublasDgemmBatched_v2',err,
& UNKNOWN_ERR)
endif

err = cublasDgemmBatched_v2(
& handle(2),
& CUBLAS_OP_N,
& CUBLAS_OP_N,
& nv4, nv4, nv4,
& 1.0d0,
& a_array2, nv4,
& b_array2, nv4,
& 0.0d0,
& c_array2, nv4,
& batch2)

if (err.ne.0) then
call errquit('cublasDgemmBatched_v2',err,
& UNKNOWN_ERR)
endif

err = cudaDeviceSynchronize()
if (err.ne.0) then
call errquit('cudaDeviceSync',err,UNKNOWN_ERR)
endif

err = cublasDgemmBatched_v2(
& handle(1),
& CUBLAS_OP_N,
& CUBLAS_OP_N,
& nv4, nv4, no4,
& -1.0d0,
& a_array3, nv4,
& b_array3, no4,
& 1.0d0,
& c_array3, nv4,
& batch3)
if (err.ne.0) then
call errquit('cublasDgemmBatched_v2',
& err,UNKNOWN_ERR)
endif


#else
err = cublasDgemm_v2(handle(1),
& cu_op_n,cu_op_t,
& nv4,nv4,nv4,1.0d0,
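This hunk is the heart of the commit: where the original implementation (the #else branch begun above) issues sixteen separate cublasDgemm_v2 calls across eight handles (the "8 pairs of DGEMM" of the flop comment below), the batched path gathers device pointers to the A, B, and C matrices into c_devptr arrays and launches each group of equally-shaped multiplies with a single cublasDgemmBatched_v2 call, deferring the third batch behind a device synchronize because it accumulates into the f arrays written by the first two. A self-contained sketch of the calling pattern (illustrative only; assumes nvfortran's cudafor and cublas_v2 modules, compiled with -cuda -cudalib=cublas; n, nb, pa/pb/pc are hypothetical names):

      program batched_gemm_sketch
      use cudafor
      use cublas_v2
      implicit none
      integer, parameter :: n = 32, nb = 4
      real(8), device :: a(n,n,nb), b(n,n,nb), c(n,n,nb)
      type(c_devptr), device :: pa(nb), pb(nb), pc(nb)
      type(cublasHandle) :: h
      integer :: i, err
      a = 1.0d0
      b = 2.0d0
      c = 0.0d0
      err = cublasCreate(h)
      do i = 1, nb
        ! gather one device pointer per member of the batch
        pa(i) = c_devloc(a(:,:,i))
        pb(i) = c_devloc(b(:,:,i))
        pc(i) = c_devloc(c(:,:,i))
      end do
      ! one launch performs nb independent n-by-n DGEMMs: C = A*B
      err = cublasDgemmBatched_v2(h, CUBLAS_OP_N, CUBLAS_OP_N,
     &      n, n, n, 1.0d0, pa, n, pb, n, 0.0d0, pc, n, nb)
      err = cudaDeviceSynchronize()
      err = cublasDestroy(h)
      end program

Batching trades sixteen individual launches for three, which helps when the nvir-sized GEMMs are too small to saturate the GPU one at a time; the arithmetic performed is unchanged.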
@@ -588,12 +721,11 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
if (err.ne.0) then
call errquit('cublasDgemm_v2',err,UNKNOWN_ERR)
endif

#endif
err = cudaDeviceSynchronize()
if (err.ne.0) then
call errquit('cudaDeviceSync',err,UNKNOWN_ERR)
endif

! 8 pairs of DGEMM w/ VVV and VVO cost, 2 for FMA
dgemm_flops = 8*nvir*nvir*(nocc+nvir)*2
agg_flops = agg_flops + dgemm_flops
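As a concrete check of the formula above (illustrative values, not from this commit): with nocc = 100 and nvir = 500, dgemm_flops = 8 * 500 * 500 * (100 + 500) * 2 = 2.4e9 flops for the eight DGEMM pairs issued above; batching changes how they are launched, not this count.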
@@ -761,6 +893,7 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
deallocate( Tij, Tkj, Tia, Tka, Xia, Xka,
& Jia, Jka, Kia, Kka, Jij, Jkj, Kij, Kkj,
& Dja, Djka, Djia, stat=alloc_error)

if (alloc_error.ne.0) call errquit('free TKJKD',1,MA_ERR)
deallocate( xTij, xTkj, xTia, xTka, xXia, xXka,
& xJia, xJka, xKia, xKka, xJij, xJkj, xKij, xKkj,
[diff for the third changed file did not load]
