[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
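The three hunks below update lit tests in the TritonGPU-to-LLVM conversion suite: the 1-D operands of tt.expand_dims now carry a #triton_gpu.slice encoding instead of an unrelated 1-D blocked encoding, so the operand type matches what the newly added type inference derives from the 2-D result. A minimal sketch of that pairing, assuming the rule is that expanding axis d of a tensor encoded as #triton_gpu.slice<{dim = d, parent = #p}> yields the parent encoding #p (the #parent/#sliced names and the standalone function are illustrative, not part of this commit):

#parent = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#sliced = #triton_gpu.slice<{dim = 1, parent = #parent}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @expand_dims_slice(%x: tensor<16xi32, #sliced>) {
    // Operand encoding is the dim-1 slice of the result encoding, so this verifies.
    %y = tt.expand_dims %x {axis = 1 : i32} : (tensor<16xi32, #sliced>) -> tensor<16x1xi32, #parent>
    return
  }
}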
@@ -348,15 +348,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_insert_slice_async_v4
   func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
-    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #block1>
-    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
-    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #block1>) -> tensor<1x64xi32, #block3>
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
+    %off1_ = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<64xi32, #slice3d0>) -> tensor<1x64xi32, #block3>
     %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x64xi32, #block2>
     %cst_scalar = arith.constant 64 : i32
     %cst = tt.splat %cst_scalar : (i32) -> tensor<16x64xi32, #block2>
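In this v4 hunk, #slice2d1 and #slice3d0 are defined as slices of the existing #block2 and #block3 result encodings, and each tt.make_range result is retyped accordingly. Note that tt.make_range must produce exactly end - start elements, a constraint the verifier added by this commit presumably enforces; a mismatched form for comparison (hypothetical input, not in the test):

%bad = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>   // hypothetical: 32 range values cannot fill 64 elements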
@@ -387,15 +389,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_insert_slice_async_v1
   func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #block0>
-    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block1>
-    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>
-    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block1>) -> tensor<1x32xi32, #block3>
+    %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
     %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<16x1xi32, #block2>) -> tensor<16x32xi32, #block2>
     %cst_scalar = arith.constant 32 : i32
     %cst = tt.splat %cst_scalar : (i32) -> tensor<16x32xi32, #block2>
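The v1 variant gets the same retyping at vec = 1. Under the stricter tt.expand_dims checking, the pre-commit form should now be rejected, since #block0 is not the dim-1 slice of #block2; for comparison (hypothetical ill-typed line, not in the test):

%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #block0>) -> tensor<16x1xi32, #block2>   // hypothetical: operand encoding would have to be #triton_gpu.slice<{dim = 1, parent = #block2}>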
@@ -429,15 +433,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #block0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [8], warpsPerCTA = [4], order = [0]}>
 #block2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #block3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 8], warpsPerCTA = [1, 4], order = [1, 0]}>
+#slice2d1 = #triton_gpu.slice<{dim = 1, parent=#block2}>
+#slice3d0 = #triton_gpu.slice<{dim = 0, parent=#block3}>
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_insert_slice_async_v1_multictas
   func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
-    %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
-    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #block0>
-    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #block0>) -> tensor<32x1xi32, #block2>
-    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #block0>) -> tensor<1x32xi32, #block3>
+    %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
+    %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
+    %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
+    %off1 = tt.expand_dims %off1_ {axis = 0 : i32} : (tensor<32xi32, #slice3d0>) -> tensor<1x32xi32, #block3>
     %broadcast_off0_scalar = tt.broadcast %off0 : (tensor<32x1xi32, #block2>) -> tensor<32x32xi32, #block2>
     %cst_scalar = arith.constant 32 : i32
     %cst = tt.splat %cst_scalar : (i32) -> tensor<32x32xi32, #block2>
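In the multi-CTA variant, both 32-element ranges previously shared the single 1-D #block0 encoding; they now carry the two distinct slice encodings matching their respective expansion axes. The trailing context lines hint at how the test continues: the broadcast row offsets get scaled by the row stride held in %cst. A plausible next line, consistent with the ops visible above but not shown in this hunk:

%broadcast_off0_ = arith.muli %broadcast_off0_scalar, %cst : tensor<32x32xi32, #block2>   // assumed continuation: row index times row stride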