diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 00d0ebd1f..ff9daa08c 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -23,6 +23,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %c0_i32 = arith.constant 0 : i32
     %c2_i32 = arith.constant 2 : i32
     %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
     %cst = arith.constant dense<64> : tensor<128x64xi32, #blocked0>
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
     %c8_i32 = arith.constant 8 : i32
@@ -102,11 +103,11 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
     %66 = tensor.extract_slice %61[0, 0, 0] [1, 128, 64] [1, 1, 1] : tensor<3x128x64xf16, #shared> to tensor<128x64xf16, #shared>
     %67 = tensor.extract_slice %63[0, 0, 0] [1, 64, 256] [1, 1, 1] : tensor<3x64x256xf16, #shared> to tensor<64x256xf16, #shared>
     %68 = tensor.extract_slice %66[0, 0] [128, 16] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x16xf16, #shared>
-    %69 = triton_gpu.convert_layout %68 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
     %70 = tensor.extract_slice %67[0, 0] [16, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<16x256xf16, #shared>
-    %71 = triton_gpu.convert_layout %70 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
-    %72:14 = scf.for %arg9 = %c0 to %39 step %c64 iter_args(%arg10 = %cst_0, %arg11 = %45, %arg12 = %49, %arg13 = %61, %arg14 = %63, %arg15 = %66, %arg16 = %67, %arg17 = %64, %arg18 = %65, %arg19 = %c64, %arg20 = %c2_i32, %arg21 = %c1_i32, %arg22 = %69, %arg23 = %71) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>) {
-      %89 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true, transA = false, transB = false} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>
+    %72:14 = scf.for %arg9 = %c0 to %39 step %c64 iter_args(%arg10 = %cst_0, %arg11 = %45, %arg12 = %49, %arg13 = %61, %arg14 = %63, %arg15 = %66, %arg16 = %67, %arg17 = %64, %arg18 = %65, %arg19 = %c64, %arg20 = %c2_i32, %arg21 = %c1_i32, %arg22 = %68, %arg23 = %70) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #shared>, tensor<16x256xf16, #shared>) {
+      %69 = triton_gpu.convert_layout %arg22 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
+      %71 = triton_gpu.convert_layout %arg23 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
+      %89 = tt.dot %69, %71, %arg10 {allowTF32 = true, transA = false, transB = false} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>
       %90 = tensor.extract_slice %arg15[0, 16] [128, 32] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x32xf16, #shared>
       %91 = triton_gpu.convert_layout %90 : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
       %92 = tensor.extract_slice %arg16[16, 0] [32, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<32x256xf16, #shared>
@@ -136,10 +137,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {
       %115 = arith.addi %arg20, %c1_i32 : i32
       %116 = arith.addi %arg21, %c1_i32 : i32
       %117 = tensor.extract_slice %113[0, 0] [128, 16] [1, 1] : tensor<128x64xf16, #shared> to tensor<128x16xf16, #shared>
-      %118 = triton_gpu.convert_layout %117 : (tensor<128x16xf16, #shared>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
       %119 = tensor.extract_slice %114[0, 0] [16, 256] [1, 1] : tensor<64x256xf16, #shared> to tensor<16x256xf16, #shared>
-      %120 = triton_gpu.convert_layout %119 : (tensor<16x256xf16, #shared>) -> tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
-      scf.yield %99, %100, %101, %108, %110, %113, %114, %111, %112, %102, %115, %116, %118, %120 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, tensor<16x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
+      scf.yield %99, %100, %101, %108, %110, %113, %114, %111, %112, %102, %115, %116, %117, %119 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<3x128x64xf16, #shared>, tensor<3x64x256xf16, #shared>, tensor<128x64xf16, #shared>, tensor<64x256xf16, #shared>, tensor<128x64x!tt.ptr<f16>, #blocked0>, tensor<64x256x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<128x16xf16, #shared>, tensor<16x256xf16, #shared>
     }
     triton_gpu.async_wait {num = 0 : i32}
     %73 = triton_gpu.convert_layout %72#0 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked1>