From 29a1e20b585ce0d5f4950bc1819792730f9b9ca5 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Wed, 4 Jan 2023 17:02:31 -0800 Subject: [PATCH] tweak convert + trans --- python/bwd.ttgir | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/bwd.ttgir b/python/bwd.ttgir index 5ac66c050..5ac495943 100644 --- a/python/bwd.ttgir +++ b/python/bwd.ttgir @@ -81,13 +81,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %65 = arith.index_cast %47 : i32 to index - %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> + %67 = tt.trans %66 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0> %68 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> - %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> - %72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> + %72 = tt.trans %71 : (tensor<128x64xf16, #shared1>) -> tensor<64x128xf16, #shared0> %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> @@ -102,7 +102,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %900 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %90 = triton_gpu.convert_layout %900 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared1> %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared0>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %93 = tt.dot %92, %91, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %94 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> @@ -132,7 +132,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %118 = tt.broadcast %117 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> %119 = arith.subf %cst_0, %118 : tensor<128x128xf32, #mma0> %120 = triton_gpu.convert_layout %107 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> - %121 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %121 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared0>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %122 = tt.dot %120, %121, %119 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %123 = arith.mulf %106, %122 : tensor<128x128xf32, #mma0> %124 = arith.mulf %123, %39 : tensor<128x128xf32, #mma0> @@ -145,7 +145,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %131 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> %133 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %134 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared0>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %134 = triton_gpu.convert_layout %66 : (tensor<128x64xf16, #shared1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> %135 = tt.dot %133, %134, %132 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> tt.store %arg29, %136 : tensor<128x64xf32, #blocked2>