trying out another change
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
|
||||||
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
|
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
|
||||||
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
|
#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
|
||||||
|
#mma_s1 = #triton_gpu.slice<{dim = 1, parent = #mma}>
|
||||||
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
|
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
|
||||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
|
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
|
||||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||||
@@ -128,17 +129,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
%122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
%122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||||
scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>
|
scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr<f16>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>
|
||||||
}
|
}
|
||||||
%58 = triton_gpu.convert_layout %57#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
|
%203 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #mma_s1>
|
||||||
%60 = triton_gpu.convert_layout %57#0 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
|
%206 = tt.splat %2 : (i32) -> tensor<128xi32, #mma_s1>
|
||||||
|
%209 = arith.addi %206, %203 : tensor<128xi32, #mma_s1>
|
||||||
%61 = arith.muli %1, %arg21 : i32
|
%61 = arith.muli %1, %arg21 : i32
|
||||||
%62 = tt.addptr %arg4, %61 : !tt.ptr<f32>, i32
|
%62 = tt.addptr %arg4, %61 : !tt.ptr<f32>, i32
|
||||||
%63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
|
%63 = tt.splat %62 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
|
||||||
%64 = tt.addptr %63, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
%64 = tt.addptr %63, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
|
||||||
%65 = tt.addptr %arg5, %61 : !tt.ptr<f32>, i32
|
%65 = tt.addptr %arg5, %61 : !tt.ptr<f32>, i32
|
||||||
%66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
|
%66 = tt.splat %65 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #mma_s1>
|
||||||
%67 = tt.addptr %66, %9 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
|
%67 = tt.addptr %66, %209 : tensor<128x!tt.ptr<f32>, #mma_s1>, tensor<128xi32, #mma_s1>
|
||||||
tt.store %64, %60 : tensor<128xf32, #blocked0>
|
tt.store %64, %57#0 : tensor<128xf32, #mma_s1>
|
||||||
tt.store %67, %58 : tensor<128xf32, #blocked0>
|
tt.store %67, %57#2 : tensor<128xf32, #mma_s1>
|
||||||
%68 = arith.muli %1, %arg17 : i32
|
%68 = arith.muli %1, %arg17 : i32
|
||||||
%69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
|
%69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
|
||||||
%70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>
|
%70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>
|
||||||
|
Reference in New Issue
Block a user