trying out another change

Author: Phil Tillet
Date: 2022-12-27 21:51:51 -08:00
Parent: eefc9d1274
Commit: 7aba2a60d6

@@ -2,6 +2,7 @@
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
 #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
+#mma_s1 = #triton_gpu.slice<{dim = 1, parent = #mma}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
 #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -128,17 +129,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
       %122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
       scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>
     }
-    %58 = triton_gpu.convert_layout %57#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
-    %60 = triton_gpu.convert_layout %57#0 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
+    %203 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #mma_s1>
+    %206 = tt.splat %2 : (i32) -> tensor<128xi32, #mma_s1>
+    %209 = arith.addi %206, %203 : tensor<128xi32, #mma_s1>
     %61 = arith.muli %1, %arg21 : i32
     %62 = tt.addptr %arg4, %61 : !tt.ptr<f32>, i32
-    %63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
-    %64 = tt.addptr %63, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
+    %63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
+    %64 = tt.addptr %63, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
     %65 = tt.addptr %arg5, %61 : !tt.ptr<f32>, i32
-    %66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
-    %67 = tt.addptr %66, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
-    tt.store %64, %60 : tensor<128xf32, #blocked0>
-    tt.store %67, %58 : tensor<128xf32, #blocked0>
+    %66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
+    %67 = tt.addptr %66, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
+    tt.store %64, %57#0 : tensor<128xf32, #mma_s1>
+    tt.store %67, %57#2 : tensor<128xf32, #mma_s1>
     %68 = arith.muli %1, %arg17 : i32
     %69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
     %70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>