[BACKEND] Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp, and GEPOp, plus bug fixes for SplatOp, StoreOp, FuncOp (#60)
Add backend support for arith::AddIOp, arith::AddFOp, GetProgramIdOp, and GEPOp, plus bug fixes for SplatOp, StoreOp, FuncOp.

Co-authored-by: gzhu <gzhu@nvidia.com>
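The new basic_addf and basic_addi tests below expect two scalar LLVM ops per kernel: with 256 elements laid out by #blocked0 (sizePerThread = 1, 32 threads per warp, 4 warps), each thread owns 2 elements, so the tensor-level arith.addf / arith.addi lowers to 2 per-thread scalar operations. A rough per-thread sketch for addf follows; the SSA names and the two-element struct packing are illustrative assumptions, not taken from this commit:

// Sketch of the per-thread lowering of arith.addf on tensor<256xf32, #blocked0>
// (hypothetical value names; assumes each thread's 2 elements arrive packed in an LLVM struct)
%lhs0 = llvm.extractvalue %lhs[0] : !llvm.struct<(f32, f32)>
%rhs0 = llvm.extractvalue %rhs[0] : !llvm.struct<(f32, f32)>
%sum0 = llvm.fadd %lhs0, %rhs0 : f32
%lhs1 = llvm.extractvalue %lhs[1] : !llvm.struct<(f32, f32)>
%rhs1 = llvm.extractvalue %rhs[1] : !llvm.struct<(f32, f32)>
%sum1 = llvm.fadd %lhs1, %rhs1 : f32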
@@ -19,6 +19,8 @@ func @test_splat(%ptr: !tt.ptr<f32>) {
  return
}

// -----

func @test_store_splat(%ptr: !tt.ptr<f32>) {
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %a = arith.constant 1.0 : f32
@@ -27,9 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
  %vs = tt.splat %a : (f32) -> tensor<128xf32>
  %mask = tt.splat %true : (i1) -> tensor<128xi1>

  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
  // CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>

  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };",
  // CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>

  return
@@ -112,28 +112,81 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
  }
}

// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// module attributes {"triton_gpu.num-warps" = 4 : i32} {
//   func @debut_kernel(%lb : index, %A : !tt.ptr<f32>, %B : !tt.ptr<f32>, %C : !tt.ptr<f32>) {
//     %cst = arith.constant dense<true> : tensor<256xi1, #blocked0>
//     %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
//     %cst_1 = arith.constant dense<true> : tensor<1024x256xi1, #blocked1>
//     %cst_2 = arith.constant dense<true> : tensor<256x2048xi1, #blocked2>
//     %a_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
//     %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
//     %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1>
//     %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1>
//     %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
//     %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2>
//     %b_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<1024x256x!tt.ptr<f32>, #blocked1>
//     %c_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x2048x!tt.ptr<f32>, #blocked2>
//     tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1>
//     tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2>
//     return
//   }
// }
// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_addf
  func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
    // CHECK: llvm.fadd
    // CHECK: llvm.fadd
    %1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
    return
  }
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_addi
  func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
    // CHECK: llvm.add
    // CHECK: llvm.add
    %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
    return
  }
}

// -----

module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_program_id
  func @basic_program_id() {
    // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    return
  }
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_gep
  func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
    // CHECK: llvm.getelementptr
    // CHECK: llvm.getelementptr
    %0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
    return
  }
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK: basic_splat
  func @basic_splat(%ptr: !tt.ptr<f32>) {
    // CHECK: llvm.mlir.undef
    // CHECK: llvm.insertvalue
    // CHECK: llvm.insertvalue
    %0 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>,#blocked0>
    return
  }
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: basic_store
  func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
    // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
    // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
    // CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
    tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
    return
  }
}