This repository has been archived by the owner on May 27, 2021. It is now read-only.

Implement wrappers for WMMA LLVM intrinsics #494

Merged
merged 91 commits into JuliaGPU:master from wmma-wrapper
Feb 3, 2020

Conversation

thomasfaingnaert
Member

@thomasfaingnaert thomasfaingnaert commented Nov 9, 2019

This PR adds low-level wrappers around the LLVM WMMA intrinsics.
There is a one-to-one mapping between Julia functions and the LLVM intrinsics, which means that the function names can be very long.
The return types are the Julia types that correspond closest to the return type of the LLVM intrinsic (e.g. [8 x <2 x half>] becomes NTuple{8, NTuple{2, VecElement{Float16}}}).
In essence, these wrappers return the SSA nodes returned by the LLVM intrinsic.

Once this PR is finalised, I will start on a higher level API, similar to how WMMA is used in CUDA C++.

I added all intrinsics available in LLVM 6, PTX 6.0, SM 70, with the following exceptions:

  • The load/store intrinsics have a version without a stride parameter. In that case, the stride is derived from the datatype of the arguments and the WMMA shape. The same behaviour can be achieved by explicitly specifying that stride, so I decided to leave the strideless version out.
  • The MMA intrinsic can use saturation arithmetic. However, this is deprecated for floating point operations starting from PTX 6.4, so I decided not to add it.
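
For reference, the implicit stride that the strideless variants would use can be sketched as follows. This is a hypothetical helper, not part of this PR; it just encodes the rule that the derived stride equals the leading dimension of the tile for the given WMMA shape and layout:

```julia
# Hypothetical helper: the stride the strideless load/store intrinsics would
# derive, i.e. the leading dimension of the tile in memory.
function default_stride(M, N, K, matrix::Symbol, layout::Symbol)
    rows, cols = matrix === :a ? (M, K) :
                 matrix === :b ? (K, N) :
                                 (M, N)   # :c and :d accumulator tiles
    return layout === :col ? rows : cols
end

default_stride(16, 16, 16, :a, :col)  # 16 for the m16n16k16 shape
```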

Example usage:

Julia code
using CUDAnative
using CuArrays
using Test

# Generate input matrices
a     = rand(Float16, (16, 16))
a_dev = CuArray(a)
b     = rand(Float16, (16, 16))
b_dev = CuArray(b)
c     = rand(Float32, (16, 16))
c_dev = CuArray(c)

# Allocate space for result
d_dev = similar(c_dev)

# Matrix multiply-accumulate kernel (D = A * B + C)
function kernel(a_dev, b_dev, c_dev, d_dev)
    a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16)
    b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16)
    c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16)

    d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag)

    llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16)
    return
end

@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev)
@test a * b + c ≈ Array(d_dev) rtol=0.01

This will be compiled to the following LLVM IR:

LLVM IR
%src_ptr.i.i = inttoptr i64 %.fca.1.extract15 to i8*
%ret.llvm.i.i = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.load.a.sync.col.m16n16k16.stride.f16(i8* %src_ptr.i.i, i32 16)
%ret.llvm.0.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 0
%ret.llvm.1.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 1
%ret.llvm.2.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 2
%ret.llvm.3.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 3
%ret.llvm.4.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 4
%ret.llvm.5.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 5
%ret.llvm.6.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 6
%ret.llvm.7.i.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i.i, 7
%src_ptr.i5.i = inttoptr i64 %.fca.1.extract9 to i8*
%ret.llvm.i6.i = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.load.b.sync.col.m16n16k16.stride.f16(i8* %src_ptr.i5.i, i32 16)
%ret.llvm.0.i7.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 0
%ret.llvm.1.i8.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 1
%ret.llvm.2.i9.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 2
%ret.llvm.3.i10.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 3
%ret.llvm.4.i11.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 4
%ret.llvm.5.i12.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 5
%ret.llvm.6.i13.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 6
%ret.llvm.7.i14.i = extractvalue { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } %ret.llvm.i6.i, 7
%src_ptr.i31.i = inttoptr i64 %.fca.1.extract3 to i8*
%ret.llvm.i32.i = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.load.c.sync.col.m16n16k16.stride.f32(i8* %src_ptr.i31.i, i32 16)
%ret.llvm.0.i33.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 0
%ret.llvm.1.i34.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 1
%ret.llvm.2.i35.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 2
%ret.llvm.3.i36.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 3
%ret.llvm.4.i37.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 4
%ret.llvm.5.i38.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 5
%ret.llvm.6.i39.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 6
%ret.llvm.7.i40.i = extractvalue { float, float, float, float, float, float, float, float } %ret.llvm.i32.i, 7
%d.llvm.i.i = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f32(<2 x half> %ret.llvm.0.i.i, <2 x half> %ret.llvm.1.i.i, <2 x half> %ret.llvm.2.i.i, <2 x half> %ret.llvm.3.i.i, <2 x half> %ret.llvm.4.i.i, <2 x half> %ret.llvm.5.i.i, <2 x half> %ret.llvm.6.i.i, <2 x half> %ret.llvm.7.i.i, <2 x half> %ret.llvm.0.i7.i, <2 x half> %ret.llvm.1.i8.i, <2 x half> %ret.llvm.2.i9.i, <2 x half> %ret.llvm.3.i10.i, <2 x half> %ret.llvm.4.i11.i, <2 x half> %ret.llvm.5.i12.i, <2 x half> %ret.llvm.6.i13.i, <2 x half> %ret.llvm.7.i14.i, float %ret.llvm.0.i33.i, float %ret.llvm.1.i34.i, float %ret.llvm.2.i35.i, float %ret.llvm.3.i36.i, float %ret.llvm.4.i37.i, float %ret.llvm.5.i38.i, float %ret.llvm.6.i39.i, float %ret.llvm.7.i40.i)
%d.llvm.0.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 0
%d.llvm.1.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 1
%d.llvm.2.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 2
%d.llvm.3.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 3
%d.llvm.4.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 4
%d.llvm.5.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 5
%d.llvm.6.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 6
%d.llvm.7.i.i = extractvalue { float, float, float, float, float, float, float, float } %d.llvm.i.i, 7
%dst_ptr.i.i = inttoptr i64 %.fca.1.extract to i8*
call void @llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f32(i8* %dst_ptr.i.i, float %d.llvm.0.i.i, float %d.llvm.1.i.i, float %d.llvm.2.i.i, float %d.llvm.3.i.i, float %d.llvm.4.i.i, float %d.llvm.5.i.i, float %d.llvm.6.i.i, float %d.llvm.7.i.i, i32 16)
ret void

Note that all the bitcast, extractvalue and insertvalue instructions (necessary to generate correct LLVM IR for use with llvmcall) are optimised away. The remaining extractvalue is necessary to convert the struct return type to separate arguments.

Finally, the NVPTX backend generates the following PTX code:

PTX code
LBB1_8:                                 // %julia_kernel_3.exit
	mov.u32 	%r1, 16;
	wmma.load.a.sync.col.m16n16k16.f16 	{%hh1, %hh2, %hh3, %hh4, %hh5, %hh6, %hh7, %hh8}, [%rd1], %r1;
	wmma.load.b.sync.col.m16n16k16.f16 	{%hh9, %hh10, %hh11, %hh12, %hh13, %hh14, %hh15, %hh16}, [%rd2], %r1;
	wmma.load.c.sync.col.m16n16k16.f32 	{%f1, %f2, %f3, %f4, %f5, %f6, %f7, %f8}, [%rd3], %r1;
	wmma.mma.sync.col.col.m16n16k16.f32.f32
		{%f9, %f10, %f11, %f12, %f13, %f14, %f15, %f16},
		{%hh1, %hh2, %hh3, %hh4, %hh5, %hh6, %hh7, %hh8},
		{%hh9, %hh10, %hh11, %hh12, %hh13, %hh14, %hh15, %hh16},
		{%f1, %f2, %f3, %f4, %f5, %f6, %f7, %f8};
	wmma.store.d.sync.col.m16n16k16.f32 	[%rd4], {%f9, %f10, %f11, %f12, %f13, %f14, %f15, %f16}, %r1;
	ret;

TODO/Questions:

  • I should probably add documentation for this, or should I leave this for the higher-level API?
  • Would you prefer to have the non-stride versions anyway, or can I leave these out?
  • The loads and stores are tested with the default address space (generic), using global arrays. Should I add tests for the intrinsic versions with global and shared address spaces as well?

@vchuravy
Member

Thanks! This is a really nice first PR.

Should I add tests for the intrinsic versions with global and shared address spaces as well?

I would say definitely for the shared AS.

I will leave the detailed review up to Tim, but from my perspective it would be nice to shorten the Julia functions a bit,
it seems like we could drop the llvm prefix and the eltype suffix since the eltype is inferable from the Ptr types?

Secondly (and more a matter of style) I prefer using LLVM.jl to generate the functions instead of doing string interpolation.
As an example see

@generated function unsafe_cached_load(p::DevicePtr{T,AS.Global}, i::Integer=1,

@maleadt
Member

maleadt commented Nov 13, 2019

Looks good! Appreciate the documentation.

Did you look into ccall(..., llvmcall, ...) for automatic type conversions (i.e. without hard-coding type maps as you do now)? It's very much possible that the conversions don't match what the WMMA intrinsics need, but at first sight it looks pretty good:

julia> foo(x) = ccall("llvm.donothing", llvmcall, Nothing, (NTuple{8, NTuple{2, VecElement{Float16}}},), x)
foo (generic function with 2 methods)

julia> code_llvm(foo, Tuple{NTuple{8, NTuple{2, VecElement{Float16}}}}; optimize=false)

;  @ REPL[19]:1 within `foo'
define void @julia_foo_16071([8 x <2 x i16>] addrspace(11)* nocapture nonnull readonly dereferenceable(32)) {
top:
  %1 = call %jl_value_t*** @julia.ptls_states()
  %2 = bitcast %jl_value_t*** %1 to %jl_value_t addrspace(10)**
  %3 = getelementptr inbounds %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i64 4
  %4 = bitcast %jl_value_t addrspace(10)** %3 to i64**
  %5 = load i64*, i64** %4
  %6 = load [8 x <2 x i16>], [8 x <2 x i16>] addrspace(11)* %0, align 4
  call void @llvm.donothing([8 x <2 x i16>] %6) [ "jl_roots"([8 x <2 x i16>] addrspace(11)* %0) ]
  ret void
}

The i16 instead of half might be a problem, although LLVM sometimes does not care (e.g. when passing i64's to intrinsics expecting an i8*). Be sure to use an assertions build when verifying that.

@vchuravy has there been any movement on JuliaLang/julia#26381? That still seems wanted, because IEEEFloat16.jl won't integrate with above ccall semantics.

@vchuravy
Member

No movement on the f16 front. We can't use f16 if the LLVM backend doesn't support it, so we can't enable it universally and would need to special-case it for the target backend, which is messy.

@thomasfaingnaert
Member Author

@maleadt

Did you look into ccall(..., llvmcall, ...) for automatic type conversions (i.e. without hard-coding type maps as you do now)? It's very much possible that the conversions don't match what the WMMA intrinsics need, but at first sight it looks pretty good:
The i16 instead of half might be a problem, although LLVM sometimes does not care (e.g. when passing i64's to intrinsics expecting a i8*). Be sure to use an assertions build when verifying that.

It doesn't seem like passing i64 instead of an i8* works, at least when I enable LLVM assertions:

$ JULIA_LLVM_ARGS="--version" jl
LLVM (http://llvm.org/):
  LLVM version 6.0.1
  Optimized build with assertions.
  Default target: x86_64--linux-gnu
  Host CPU: ivybridge

The following snippet:

using CuArrays
using CUDAnative

d     = rand(Float32, (16, 16))
d_dev = CuArray(d)

function kernel(d_dev)
    ccall("extern llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f32", llvmcall, Nothing, (Int64, Float32, Float32, Float32, Float32, Float32, Float32, Float32, Float32, Int32), pointer(d_dev), 1, 2, 3, 4, 5, 6, 7, 8, 16)
    return
end

@cuda threads=32 kernel(d_dev)

Array(d_dev)

gives an assertion failure:

Assertion failure (piped through c++filt)
julia: /home/tfaingna/src/julia/deps/srccache/llvm-6.0.1/include/llvm/Support/Casting.h:255: typename llvm::cast_retty<X, Y*>::ret_type llvm::cast(Y*) [with X = llvm::PointerType; Y = llvm::Type; typename llvm::cast_retty<X, Y*>::ret_type = llvm::PointerType*]: Assertion `isa<X>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

signal (6): Aborted
in expression starting at REPL[6]:1
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7fdacc66f40e)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
llvm::cast_retty<llvm::PointerType, llvm::Type*>::ret_type llvm::cast<llvm::PointerType, llvm::Type>(llvm::Type*) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGBuilder::visitTargetIntrinsic(llvm::CallInst const&, unsigned int) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGBuilder::visitIntrinsicCall(llvm::CallInst const&, unsigned int) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGBuilder::visitCall(llvm::CallInst const&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGBuilder::visit(llvm::Instruction const&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGISel::SelectBasicBlock(llvm::ilist_iterator<llvm::ilist_detail::node_options<llvm::Instruction, true, false, void>, false, true>, llvm::ilist_iterator<llvm::ilist_detail::node_options<llvm::Instruction, true, false, void>, false, true>, bool&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) [clone .part.973] at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::MachineFunctionPass::runOnFunction(llvm::Function&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::FPPassManager::runOnFunction(llvm::Function&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::FPPassManager::runOnModule(llvm::Module&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
llvm::legacy::PassManagerImpl::run(llvm::Module&) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
LLVMTargetMachineEmit(LLVMOpaqueTargetMachine*, LLVMOpaqueModule*, llvm::raw_pwrite_stream&, LLVMCodeGenFileType, char**) at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
LLVMTargetMachineEmitToMemoryBuffer at /home/tfaingna/src/julia/usr/bin/../lib/libLLVM-6.0.so (unknown line)
macro expansion at /home/tfaingna/.julia/packages/LLVM/ICZSf/src/base.jl:18 [inlined]
LLVMTargetMachineEmitToMemoryBuffer at /home/tfaingna/.julia/packages/LLVM/ICZSf/lib/6.0/libLLVM_h.jl:2726 [inlined]
emit at /home/tfaingna/.julia/packages/LLVM/ICZSf/src/targetmachine.jl:42
mcgen at /home/tfaingna/.julia/dev/CUDAnative/src/compiler/mcgen.jl:87
#codegen#194 at /home/tfaingna/.julia/packages/TimerOutputs/Tf7lx/src/TimerOutput.jl:214
unknown function (ip: 0x7fda46efb85d)
#codegen at ./none:0 [inlined]
#compile#193 at /home/tfaingna/.julia/dev/CUDAnative/src/compiler/driver.jl:47
#compile at ./none:0 [inlined]
#compile#192 at /home/tfaingna/.julia/dev/CUDAnative/src/compiler/driver.jl:28 [inlined]
#compile at ./none:0 [inlined]
#compile at ./none:0 [inlined]
macro expansion at /home/tfaingna/.julia/dev/CUDAnative/src/execution.jl:403 [inlined]
#cufunction#236 at /home/tfaingna/.julia/dev/CUDAnative/src/execution.jl:368
jl_apply_generic at /home/tfaingna/src/julia/src/gf.c:2197
cufunction at /home/tfaingna/.julia/dev/CUDAnative/src/execution.jl:368
jl_apply_generic at /home/tfaingna/src/julia/src/gf.c:2197
do_call at /home/tfaingna/src/julia/src/interpreter.c:323
eval_value at /home/tfaingna/src/julia/src/interpreter.c:411
eval_body at /home/tfaingna/src/julia/src/interpreter.c:635
jl_interpret_toplevel_thunk_callback at /home/tfaingna/src/julia/src/interpreter.c:884
unknown function (ip: 0xfffffffffffffffe)
unknown function (ip: 0x7fda71ba278f)
unknown function (ip: 0xa)
jl_interpret_toplevel_thunk at /home/tfaingna/src/julia/src/interpreter.c:893
jl_toplevel_eval_flex at /home/tfaingna/src/julia/src/toplevel.c:815
jl_toplevel_eval_flex at /home/tfaingna/src/julia/src/toplevel.c:764
jl_toplevel_eval_in at /home/tfaingna/src/julia/src/toplevel.c:844
eval at ./boot.jl:330
jl_apply_generic at /home/tfaingna/src/julia/src/gf.c:2191
eval_user_input at /home/tfaingna/src/julia/usr/share/julia/stdlib/v1.2/REPL/src/REPL.jl:86
run_backend at /home/tfaingna/.julia/packages/Revise/Mlh6Z/src/Revise.jl:1033
#85 at ./task.jl:268
jl_apply_generic at /home/tfaingna/src/julia/src/gf.c:2197
jl_apply at /home/tfaingna/src/julia/src/julia.h:1614 [inlined]
start_task at /home/tfaingna/src/julia/src/task.c:596
unknown function (ip: 0xffffffffffffffff)
Allocations: 59443583 (Pool: 59433572; Big: 10011); GC: 128
Aborted

Everything works fine when disabling LLVM assertions.

Changing the type from Int64 to Ptr{Int8} also doesn't help, as they both seem to be converted to LLVM's i64.

@maleadt
Member

maleadt commented Nov 14, 2019

Ha, so this isn't just broken for half/i16, but also for how we emit pointers (as literal integers). This was changed in JuliaLang/julia@2bb430e. It might be worth it to look into this and fix it (i.e. make sure we pass actual pointers to LLVM intrinsics when using ccall(..., llvmcall, ...) by adding an extra inttoptr), because we could then also special-case passing Float16 by bitcasting to half. That would significantly simplify this PR, I think.

@thomasfaingnaert
Member Author

It might be worth it to look into this and fix it (i.e. make sure we pass actual pointers to LLVM intrinsics when using ccall(..., llvmcall, ...) by adding an extra inttoptr), because we could then also special-case passing Float16 by bitcasting to half. That would significantly simplify this PR, I think.

Would it make sense to implement this in https://github.com/JuliaLang/julia (so it works outside of the context of GPUs), or just using the compiler hooks in CUDAnative? Judging from the discussion at JuliaLang/julia#23367, anonymising pointers and other types was fully intended, even though it broke llvmcall, but things may have changed now that 1.0 is released?

@maleadt
Member

maleadt commented Nov 14, 2019

Definitely, I meant this to be a fix in the Julia compiler. Anonymization was indeed intended; we should just fix our ABI when interfacing with LLVM intrinsics.

@thomasfaingnaert
Member Author

thomasfaingnaert commented Nov 15, 2019

@maleadt While adding tests for the shared address space, I stumbled on two issues:

  1. @cuStaticSharedMem aligns on 16-byte boundaries, whereas the PTX WMMA instructions expect all memory addresses to be multiples of 32. Should we change the default alignment to 32 bytes, make the alignment a parameter to @cuStaticSharedMem, or just make the user perform a bitwise AND with -32?

  2. The ptr field of @cuStaticSharedMem's return value has type DevicePtr{T, AS.Shared}, but the bitpattern stored within it still corresponds with generic addressing (i.e. the base address of the shared window is not subtracted yet). I need to manually addrspacecast it to address space 3. Is this intended?
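
The bitwise-AND workaround from point 1 would look roughly like this (hypothetical helper; it rounds a pointer down to the previous 32-byte boundary, so the user would need to over-allocate accordingly):

```julia
# Hypothetical: round an address down to the previous 32-byte boundary.
# -32 is ...11100000 in two's complement, so AND-ing clears the low 5 bits.
align32(p::Ptr{T}) where {T} = reinterpret(Ptr{T}, reinterpret(Int, p) & -32)
```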

@thomasfaingnaert
Member Author

bors try

bors bot added a commit that referenced this pull request Feb 1, 2020
@bors
Contributor

bors bot commented Feb 1, 2020

try

Build succeeded

@thomasfaingnaert
Member Author

I addressed your comments. The main changes are:

  • Implement indexing for Fragments: when called with Fragment arguments, getindex, setindex!, firstindex and lastindex simply forward to the corresponding operation on the x member
  • Move all documentation to docstrings, making it available in the REPL (I did have to add dummy functions to avoid the "Not found" error; not sure if there is a better way to do that)
  • Remove the WMMA prefix from types and functions, and move everything into the WMMA submodule
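
The Fragment indexing forwarding amounts to the following pattern (sketch with a stand-in struct, not the exact definitions from this PR):

```julia
# Stand-in definition; the real Fragment type lives in the WMMA submodule.
struct Fragment{T, N}
    x::NTuple{N, T}
end

# Indexing simply delegates to the wrapped tuple `x`.
Base.getindex(f::Fragment, i::Integer) = f.x[i]
Base.firstindex(f::Fragment)           = firstindex(f.x)
Base.lastindex(f::Fragment)            = lastindex(f.x)
```

setindex! would need a mutable backing store, so it is omitted from this sketch.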

Feel free to comment on anything you'd still like to change.

bors try

bors bot added a commit that referenced this pull request Feb 1, 2020
@bors
Contributor

bors bot commented Feb 1, 2020

try

Build succeeded

Finally, it is important to note that the resultant ``D`` matrix can be used as a ``C`` matrix for a subsequent multiply-accumulate.
This is useful if one needs to calculate a sum of the form ``\sum_{i=0}^{n} A_i B_i``, where ``A_i`` and ``B_i`` are matrices of the correct dimension.
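
As an illustration of that reuse (device-side sketch using the wrapper names from this PR; ``n``, ``a_devs`` and ``b_devs`` are hypothetical, standing in for the number of terms and the arrays of input matrices):

```julia
# Accumulate D = C + Σᵢ Aᵢ * Bᵢ by feeding each MMA result back in as the
# next C fragment.
acc = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16)
for i in 1:n
    a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_devs[i]), 16)
    b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_devs[i]), 16)
    acc    = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, acc)
end
llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), acc, 16)
```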

## LLVM Intrinsics
From a user perspective I would want to read about the C-like high-level API first, and only then, if interested, about the intrinsics.

@vchuravy
Member

vchuravy commented Feb 2, 2020

Looks very nice!

@maleadt
Member

maleadt commented Feb 3, 2020

Really nice! Let's go ahead and merge this.
bors r+

bors bot added a commit that referenced this pull request Feb 3, 2020
494: Implement wrappers for WMMA LLVM intrinsics r=maleadt a=thomasfaingnaert

Co-authored-by: Thomas Faingnaert <[email protected]>
@bors
Contributor

bors bot commented Feb 3, 2020

Build succeeded

@bors bors bot merged commit 93c77bc into JuliaGPU:master Feb 3, 2020
@thomasfaingnaert thomasfaingnaert deleted the wmma-wrapper branch February 3, 2020 12:38
@thomasfaingnaert
Member Author

@maleadt Just a heads up, julia:nightly is failing on master, which leads to this PR failing too

@maleadt
Member

maleadt commented Feb 3, 2020

Yeah, no worries. JuliaLang/julia#34611 broke a bunch of packages on Julia#master.

@maleadt
Member

maleadt commented Feb 3, 2020

Forgot to squash merge so I went ahead and force-pushed master (I know, I know, but otherwise bisecting is broken).
