// TODO: swizzle
//   (presumably about #shared0 / #shared1, which use vec = 1, maxPhase = 1 and
//   therefore no swizzling, unlike #shared2)
// TODO: move opIdx = 0 before opIdx = 1
//   (presumably the order of the dot-operand conversions in the inner loop,
//   e.g. %91 (opIdx = 1) is emitted before %93 (opIdx = 0))
// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
//   NOTE(review): in the IR below %137 converts %128 (not %127); the quoted
//   operand looks stale. %128 is already copied to #shared0 as %129.
// TODO: don't convert loaded value to mma for accumulation
//   (%135 is loaded in #blocked2 and converted to #mma1 as %136 only to serve
//   as the accumulator of %140)
//
// Pass pipeline recorded for producing this file:
// triton-opt unoptimized.ttgir -tritongpu-sink-conversions-from-shared -tritongpu-decompose-conversions-to-dot-operand -cse
//
// This appears to be the FlashAttention backward kernel from Triton's
// fused-attention tutorial (kernel name `_bwd_kernel`; the three inner-loop
// accumulations have the dv / dk / dq form of that algorithm). Role names such
// as "Q" or "dK" below are presumptions from that match -- TODO confirm
// against the Python launcher. Pointer element types are the ones implied by
// the loads/stores in the body.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}>
#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
  func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(
      %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // f16 tiles loaded in the inner loop (%90); presumably Q
      %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // f16 tile %59 (row stride %arg17); presumably K
      %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // f16 tile %65 (row stride %arg14); presumably V
      %arg3: f32,                                         // scale applied to %99 and %126; presumably sm_scale
      %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // never dereferenced; f16 is an assumption (presumably Out) -- TODO confirm
      %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // f16 tiles loaded in the inner loop (%108); presumably dO
      %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32},   // f32 tiles read-modify-written in the inner loop (%135 / tt.store); presumably dQ
      %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // receives %79#1 as f16 (row stride %arg17); presumably dK
      %arg8: !tt.ptr<f16> {tt.divisibility = 16 : i32},   // receives %79#0 as f16 (row stride %arg14); presumably dV
      %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32},   // never dereferenced; f32 is an assumption (presumably L) -- TODO confirm
      %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32},  // per-row f32 values subtracted before exp (%101); presumably M
      %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32},  // per-row f32 values negated into the %125 accumulator (%117); presumably D
      %arg12: i32 {tt.divisibility = 16 : i32},           // multiplies pid / %arg22 in the base offset %5
      %arg13: i32 {tt.divisibility = 16 : i32},           // multiplies pid % %arg22 in the base offset %5
      %arg14: i32 {tt.divisibility = 16 : i32},           // row stride of the %arg0/%arg2/%arg5/%arg6/%arg8 tiles
      %arg15: i32 {tt.divisibility = 16 : i32},           // unused in this body
      %arg16: i32 {tt.divisibility = 16 : i32},           // unused in this body
      %arg17: i32 {tt.divisibility = 16 : i32},           // row stride of the %arg1/%arg7 tiles
      %arg18: i32 {tt.divisibility = 16 : i32},           // unused in this body
      %arg19: i32 {tt.divisibility = 16 : i32},           // unused in this body
      %arg20: i32 {tt.divisibility = 16 : i32},           // unused in this body
      %arg21: i32,                                        // unused in this body
      %arg22: i32 {tt.divisibility = 16 : i32},           // divisor splitting the program id into (%1, %2)
      %arg23: i32 {tt.divisibility = 16 : i32},           // pid * %arg23 = element offset into %arg10 / %arg11
      %arg24: i32) {                                      // number of 128-row blocks (outer trip count; inner bound = %arg24 * 128)
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c128_i32 = arith.constant 128 : i32
    %c128 = arith.constant 128 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1>
    // 0xFF800000 is the f32 bit pattern of -inf; used as the masked-out value in %99.
    %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0>
    // Base offset shared by all 128x64 tile pointers: (pid / %arg22) * %arg12 + (pid % %arg22) * %arg13.
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    %1 = arith.divsi %0, %arg22 : i32
    %2 = arith.remsi %0, %arg22 : i32
    %3 = arith.muli %1, %arg12 : i32
    %4 = arith.muli %2, %arg13 : i32
    %5 = arith.addi %3, %4 : i32
    %6 = tt.addptr %arg0, %5 : !tt.ptr<f16>, i32
    %7 = tt.addptr %arg1, %5 : !tt.ptr<f16>, i32
    %8 = tt.addptr %arg2, %5 : !tt.ptr<f16>, i32
    %9 = tt.addptr %arg5, %5 : !tt.ptr<f16>, i32
    %10 = tt.addptr %arg6, %5 : !tt.ptr<f32>, i32
    %11 = tt.addptr %arg7, %5 : !tt.ptr<f16>, i32
    %12 = tt.addptr %arg8, %5 : !tt.ptr<f16>, i32
    %13 = arith.index_cast %arg24 : i32 to index
    %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0>
    %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
    %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
    %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1>
    %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2>
    // Column offsets 0..63, added unscaled (i.e. unit column stride).
    %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
    %23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
    %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
    %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2>
    %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
    %27 = tt.splat %6 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1>
    %29 = tt.splat %7 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %30 = tt.splat %8 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %31 = tt.splat %9 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %32 = tt.splat %10 : (!tt.ptr<f32>) -> tensor<128x64x!tt.ptr<f32>, #blocked2>
    // Per-row f32 vectors (%arg10, %arg11) are offset by pid * %arg23.
    %33 = arith.muli %0, %arg23 : i32
    %34 = tt.addptr %arg11, %33 : !tt.ptr<f32>, i32
    %35 = tt.addptr %arg10, %33 : !tt.ptr<f32>, i32
    %36 = arith.muli %arg24, %c128_i32 : i32
    %37 = arith.index_cast %36 : i32 to index
    %38 = tt.splat %35 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
    %39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0>
    %40 = tt.splat %34 : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>, #blocked0>
    // Pointer increment per inner iteration: 128 rows of stride %arg14.
    %41 = arith.muli %arg14, %c128_i32 : i32
    %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1>
    %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2>
    %44 = tt.splat %12 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    %45 = tt.splat %11 : (!tt.ptr<f16>) -> tensor<128x64x!tt.ptr<f16>, #blocked1>
    // Outer loop over 128-row column blocks; %47 = first row of this block.
    scf.for %arg25 = %c0 to %13 step %c1 {
      %46 = arith.index_cast %arg25 : index to i32
      %47 = arith.muli %46, %c128_i32 : i32
      %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
      %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
      %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
      %52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
      %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1>
      %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2>
      // %57: tile offsets with row stride %arg17 (used for %arg1 and %arg7).
      %55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1>
      %56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
      %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1>
      %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
      %60 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
      // %63: tile offsets with row stride %arg14 (used for %arg2, %arg0, %arg5, %arg8).
      %61 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1>
      %62 = tt.broadcast %61 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
      %63 = arith.addi %62, %24 : tensor<128x64xi32, #blocked1>
      %64 = tt.addptr %30, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %65 = tt.load %64 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
      %66 = triton_gpu.convert_layout %65 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0>
      %67 = arith.index_cast %47 : i32 to index
      %68 = tt.trans %60 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
      // %71: column index (%47 + j) of every element of the 128x128 score tile.
      %69 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
      %70 = tt.expand_dims %69 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0>
      %71 = tt.broadcast %70 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0>
      %72 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1>
      %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2>
      %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2>
      %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2>
      %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
      %77 = tt.addptr %27, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %78 = tt.addptr %31, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      // Inner loop over 128-row blocks starting at %47 (rows before the
      // current column block are skipped). Carries two f32 accumulators
      // (%arg27 -> %79#0, %arg28 -> %79#1) and the %arg6 / %arg0 / %arg5
      // tile pointers.
      %79:5 = scf.for %arg26 = %67 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>) {
        %86 = arith.index_cast %arg26 : index to i32
        %87 = tt.splat %86 : (i32) -> tensor<128xi32, #blocked0>
        %88 = tt.splat %86 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
        %89 = arith.addi %87, %14 : tensor<128xi32, #blocked0>
        // %94 = %90 * trans(%59): 128x128 f32 scores.
        %90 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
        %91 = triton_gpu.convert_layout %68 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
        %92 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
        %93 = triton_gpu.convert_layout %92 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
        %94 = tt.dot %93, %91, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
        // Mask: keep %94 where row index (%97 = %arg26 + i) vs column index
        // (%71) satisfies predicate 5, else -inf. Assuming triton_gpu.cmpi
        // uses arith.cmpi's encoding, 5 is `sge` (row >= column).
        %95 = arith.addi %88, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
        %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0>
        %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0>
        %98 = "triton_gpu.cmpi"(%97, %71) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0>
        %99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0>
        // %107 = exp(%99 * %arg3 - row value from %arg10), row-broadcast.
        %100 = tt.addptr %38, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
        %101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
        %102 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
        %103 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0>
        %104 = tt.expand_dims %102 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
        %105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
        %106 = arith.subf %103, %105 : tensor<128x128xf32, #mma0>
        %107 = math.exp %106 : tensor<128x128xf32, #mma0>
        // %115 = %arg27 + trans(f16(%107)) * %108  (dv-style accumulation).
        %108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
        %109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
        %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
        %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
        %112 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
        %113 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
        %114 = triton_gpu.convert_layout %113 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
        %115 = tt.dot %112, %114, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
        // %125 = %108 * trans(%65) - row value from %arg11 (the subtraction
        // is folded into the accumulator %121 = 0 - row value).
        %116 = tt.addptr %40, %89 : tensor<128x!tt.ptr<f32>, #blocked0>, tensor<128xi32, #blocked0>
        %117 = tt.load %116 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0>
        %118 = triton_gpu.convert_layout %117 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>
        %119 = tt.expand_dims %118 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0>
        %120 = tt.broadcast %119 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0>
        %121 = arith.subf %cst_1, %120 : tensor<128x128xf32, #mma0>
        %122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>>
        // NOTE(review): %123 repeats %113 (same operand %108, same #shared2
        // target) and %131 repeats %92; -cse did not merge them. Confirm
        // whether that is intentional before deduplicating.
        %123 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
        %124 = triton_gpu.convert_layout %123 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>>
        %125 = tt.dot %124, %122, %121 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0>
        // %128 = f16(%107 * %125 * %arg3)  (ds-style term).
        %126 = arith.mulf %107, %125 : tensor<128x128xf32, #mma0>
        %127 = arith.mulf %126, %39 : tensor<128x128xf32, #mma0>
        %128 = arith.truncf %127 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0>
        // %134 = %arg28 + trans(%128) * %90  (dk-style accumulation).
        %129 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0>
        %130 = tt.trans %129 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1>
        %131 = triton_gpu.convert_layout %90 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
        %132 = triton_gpu.convert_layout %131 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
        %133 = triton_gpu.convert_layout %130 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
        %134 = tt.dot %133, %132, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
        // Read-modify-write of the %arg6 tile: *%arg29 += %128 * %59  (dq-style update).
        %135 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2>
        %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1>
        %137 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
        // NOTE(review): %138/%139 depend only on the outer-loop value %59 yet
        // are recomputed every inner iteration; candidate for hoisting.
        %138 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared2>
        %139 = triton_gpu.convert_layout %138 : (tensor<128x64xf16, #shared2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>>
        %140 = tt.dot %137, %139, %136 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1>
        %141 = triton_gpu.convert_layout %140 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2>
        tt.store %arg29, %141 : tensor<128x64xf32, #blocked2>
        // Advance the %arg6 / %arg0 / %arg5 tile pointers by 128 rows.
        %142 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
        %143 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
        %144 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
        scf.yield %115, %134, %142, %143, %144 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64x!tt.ptr<f16>, #blocked1>
      }
      // Write back the accumulators as f16: %79#0 to %arg8 (offsets %63),
      // %79#1 to %arg7 (offsets %57).
      %80 = arith.truncf %79#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
      %81 = tt.addptr %44, %63 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %82 = triton_gpu.convert_layout %80 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
      tt.store %81, %82 : tensor<128x64xf16, #blocked1>
      %83 = arith.truncf %79#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1>
      %84 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %85 = triton_gpu.convert_layout %83 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1>
      tt.store %84, %85 : tensor<128x64xf16, #blocked1>
    }
    return
  }
}