diff --git a/python/bwd.ttgir b/python/bwd.ttgir index c79d41631..e91e74f60 100644 --- a/python/bwd.ttgir +++ b/python/bwd.ttgir @@ -95,13 +95,13 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %79:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { - %91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %86 = arith.index_cast %arg26 : index to i32 %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0> %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0> %90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %91 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> %93 = tt.dot %92, %91, %cst_0 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> %94 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> %95 = tt.expand_dims %94 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>