diff --git a/python/flash-attention.ttgir b/python/flash-attention.ttgir
index 7ce40fbc7..6ff6b8da0 100644
--- a/python/flash-attention.ttgir
+++ b/python/flash-attention.ttgir
@@ -2,6 +2,7 @@
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
 #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}>
+#mma_s1 = #triton_gpu.slice<{dim = 1, parent = #mma}>
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
 #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -128,17 +129,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
       %122 = tt.addptr %arg27, %47 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1>
       scf.yield %105, %120, %96, %121, %122 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128x64xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x128x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>
     }
-    %58 = triton_gpu.convert_layout %57#2 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
-    %60 = triton_gpu.convert_layout %57#0 : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf32, #blocked0>
+    %203 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #mma_s1>
+    %206 = tt.splat %2 : (i32) -> tensor<128xi32, #mma_s1>
+    %209 = arith.addi %206, %203 : tensor<128xi32, #mma_s1>
     %61 = arith.muli %1, %arg21 : i32
     %62 = tt.addptr %arg4, %61 : !tt.ptr, i32
-    %63 = tt.splat %62 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0>
-    %64 = tt.addptr %63, %9 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0>
+    %63 = tt.splat %62 : (!tt.ptr) -> tensor<128x!tt.ptr, #mma_s1>
+    %64 = tt.addptr %63, %209 : tensor<128x!tt.ptr, #mma_s1>, tensor<128xi32, #mma_s1>
     %65 = tt.addptr %arg5, %61 : !tt.ptr, i32
-    %66 = tt.splat %65 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0>
-    %67 = tt.addptr %66, %9 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0>
-    tt.store %64, %60 : tensor<128xf32, #blocked0>
-    tt.store %67, %58 : tensor<128xf32, #blocked0>
+    %66 = tt.splat %65 : (!tt.ptr) -> tensor<128x!tt.ptr, #mma_s1>
+    %67 = tt.addptr %66, %209 : tensor<128x!tt.ptr, #mma_s1>, tensor<128xi32, #mma_s1>
+    tt.store %64, %57#0 : tensor<128xf32, #mma_s1>
+    tt.store %67, %57#2 : tensor<128xf32, #mma_s1>
     %68 = arith.muli %1, %arg17 : i32
     %69 = tt.splat %arg18 : (i32) -> tensor<128x1xi32, #blocked1>
     %70 = tt.splat %68 : (i32) -> tensor<128x1xi32, #blocked1>