[Triton-MLIR][BACKEND] Refine/add codegen for get_promgram_id and get_num_programs Op (#877)
This commit is contained in:
@@ -849,3 +849,45 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
%blockidx = tt.get_program_id {axis=0:i32} : i32
|
||||
%blockidy = tt.get_program_id {axis=1:i32} : i32
|
||||
%blockidz = tt.get_program_id {axis=2:i32} : i32
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.ctaid.z
|
||||
%v0 = arith.addi %blockidx, %blockidy : i32
|
||||
%v1 = arith.addi %v0, %blockidz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
|
||||
// CHECK: nvvm.read.ptx.sreg.ntid.x
|
||||
// CHECK: nvvm.read.ptx.sreg.ntid.y
|
||||
// CHECK: nvvm.read.ptx.sreg.ntid.z
|
||||
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
|
||||
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
|
||||
%blockdimz = tt.get_num_programs {axis=2:i32} : i32
|
||||
%v0 = arith.addi %blockdimx, %blockdimy : i32
|
||||
%v1 = arith.addi %v0, %blockdimz : i32
|
||||
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
|
||||
tt.store %a, %0 : tensor<32xi32, #blocked0>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user