[BACKEND] add dot conversion (mma version=2) (#672)
LLVM Conversion for Dot op. Due to the lack of `convert_layout`, the dot conversion currently only supports the following combination of operands:

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout (but only splat-like; the generic cases are left to `convert_layout`)

This PR focuses on the `mma.16816`-related logic, leaving the other cases to a follow-up PR.

Co-authored-by: Philippe Tillet <phil@openai.com>
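Editorial note on the shapes involved (a back-of-the-envelope check, not text from the PR): `mma.16816` is the m16n8k16 tensor-core shape (M = 16, N = 8, K = 16), and the new lit test below lowers a 16x16x16 `tt.dot` with f16 inputs and an f32 accumulator, so the expected instruction mix works out as

    result (16x16, f32):     16 columns / 8 columns per m16n8k16 tile = 2  ->  two mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
    each input (16x16, f16): one ldmatrix.x4 loads four 8x8 tiles = a full 16x16 operand  ->  one ldmatrix.sync.aligned.m8n8.x4 for `$a` and one for `$b`

which is exactly the pattern asserted by the CHECK lines in the new test.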
@@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}}
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
@@ -422,6 +422,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
%D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0>
return
}
}
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -429,4 +456,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
// }
// }
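Note on how these checks are exercised (editorial; the RUN line lives at the top of the test file, outside the hunks shown here): this is a lit/FileCheck test, so it is presumably driven by something along the lines of

    triton-opt <test file> -split-input-file --convert-triton-gpu-to-llvm | FileCheck <test file>

with `// -----` separating the independent inputs and the `// CHECK`, `// CHECK-SAME`, and `// CHECK-LABEL` lines matched against the lowered LLVM/NVVM output. The exact pass flag and invocation are assumptions inferred from the conversion being added, not something visible in this diff.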