From 600bcefb12cbbfe613cdae909aff9af4c9ae1842 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 6 Jan 2023 20:27:49 -0800 Subject: [PATCH] more optimizations --- .../DecomposeConversionsToDotOperand.cpp | 27 ++- .../Transforms/SinkConversionsFromShared.cpp | 36 ++++ python/tutorials/06-fused-attention.py | 55 +++--- python/unoptimized.ttgir | 172 ++++++++++++++++++ 4 files changed, 262 insertions(+), 28 deletions(-) create mode 100644 python/unoptimized.ttgir diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp index 494a5af0f..d7d571019 100644 --- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp @@ -27,7 +27,32 @@ class TritonGPUDecomposeConversionsToDotOperandPass public: TritonGPUDecomposeConversionsToDotOperandPass() = default; - void runOnOperation() override { return; } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + auto srcBlocked = + srcType.getEncoding().dyn_cast(); + auto dstDotOp = + dstType.getEncoding().dyn_cast(); + if (srcBlocked && dstDotOp) { + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), + getOrder(srcBlocked), srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getOperand()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); + } }; std::unique_ptr diff --git a/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp b/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp index eba2d7274..62bcae01a 100644 --- a/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp +++ b/lib/Dialect/TritonGPU/Transforms/SinkConversionsFromShared.cpp @@ -21,6 +21,17 @@ using namespace mlir; +static inline bool willIncreaseRegisterPressure(triton::gpu::ConvertLayoutOp op) { + auto srcType = op.getOperand().getType().cast(); + auto dstType = op.getResult().getType().cast(); + auto srcEncoding = srcType.getEncoding(); + auto dstEncoding = dstType.getEncoding(); + if(srcEncoding.isa()) + return true; + if(dstEncoding.isa()) + return true; + return false; +} class TritonGPUSinkConversionsFromSharedPass : public TritonGPUSinkConversionsFromSharedBase { @@ -28,6 +39,31 @@ public: TritonGPUSinkConversionsFromSharedPass() = default; void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + // Move convert(load) immediately after dependent load + m.walk([&](triton::gpu::ConvertLayoutOp op){ + auto load = dyn_cast(op.getOperand().getDefiningOp()); + if(!load) + return; + op->moveAfter(load); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + m.walk([&](triton::gpu::ConvertLayoutOp op){ + if(!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if(std::distance(user_begin, user_end) != 1) + return; + opToMove.insert({op, *user_begin}); + }); + for(auto &kv: opToMove) + kv.first->moveBefore(kv.second); + + return; } }; diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 79fb70923..97b2482b0 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -191,6 +191,7 @@ def _bwd_kernel( tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) +_bwd_kernel = triton.compile("./being-optimized.ttgir", num_warps=8) # _bwd_kernel = triton.compile("./bwd.ttgir", num_warps=8) # _fwd_kernel = triton.compile("./fails.ptx", num_warps=4, shared=18432) @@ -258,34 +259,34 @@ class _attention(torch.autograd.Function): BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) - # _bwd_kernel[(ctx.grid[1],1,1)]( - # q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, - # o.data_ptr(), do_scaled.data_ptr(), - # dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), - # l.data_ptr(), m.data_ptr(), - # delta.data_ptr(), - # q.stride(0), q.stride(1), q.stride(2), - # k.stride(0), k.stride(1), k.stride(2), - # v.stride(0), v.stride(1), v.stride(2), - # q.shape[0], q.shape[1], q.shape[2], - # ctx.grid[0] - # ) - - pgm = _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), + _bwd_kernel[(ctx.grid[1],1,1)]( + q.data_ptr(), k.data_ptr(), v.data_ptr(), ctx.sm_scale, + o.data_ptr(), do_scaled.data_ptr(), + dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), + l.data_ptr(), m.data_ptr(), + delta.data_ptr(), + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, + ctx.grid[0] ) + + # pgm = _bwd_kernel[(ctx.grid[1],)]( + # q, k, v, ctx.sm_scale, + # o, do_scaled, + # dq, dk, dv, + # l, m, + # delta, + # q.stride(0), q.stride(1), q.stride(2), q.stride(3), + # k.stride(0), k.stride(1), k.stride(2), k.stride(3), + # v.stride(0), v.stride(1), v.stride(2), v.stride(3), + # q.shape[0], q.shape[1], q.shape[2], + # ctx.grid[0], + # BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + # BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, + # num_stages=1, + # ) # print(pgm.asm["ttgir"]) # exit() return dq, dk, dv, None @@ -380,4 +381,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f return ms -# bench_flash_attention.run(save_path='.', print_data=True) +bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/unoptimized.ttgir b/python/unoptimized.ttgir new file mode 100644 index 000000000..567e7a573 --- /dev/null +++ b/python/unoptimized.ttgir @@ -0,0 +1,172 @@ +// TODO: swizzle +// TODO: move opIdx = 0 before opIdx = 1 +// TODO: reuse %128 in %137 = triton_gpu.convert_layout %127 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> +// TODO: don't convert loaded value to mma for accumulation +// triton-opt unoptimized.ttgir -tritongpu-sink-conversions-from-shared -tritongpu-decompose-conversions-to-dot-operand -cse + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma0 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1]}> +#mma1 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 8 : i32} { + func public @_bwd_kernel_0d1d2d34d5d6d7d8d9d10d11d12d13d14d15c16d17d18d19c20d21d22d23c2425d26d27(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128_i32 = arith.constant 128 : i32 + %c128 = arith.constant 128 : index + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma1> + %cst_0 = arith.constant dense<0xFF800000> : tensor<128x128xf32, #mma0> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma0> + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.divsi %0, %arg22 : i32 + %2 = arith.remsi %0, %arg22 : i32 + %3 = arith.muli %1, %arg12 : i32 + %4 = arith.muli %2, %arg13 : i32 + %5 = arith.addi %3, %4 : i32 + %6 = tt.addptr %arg0, %5 : !tt.ptr, i32 + %7 = tt.addptr %arg1, %5 : !tt.ptr, i32 + %8 = tt.addptr %arg2, %5 : !tt.ptr, i32 + %9 = tt.addptr %arg5, %5 : !tt.ptr, i32 + %10 = tt.addptr %arg6, %5 : !tt.ptr, i32 + %11 = tt.addptr %arg7, %5 : !tt.ptr, i32 + %12 = tt.addptr %arg8, %5 : !tt.ptr, i32 + %13 = arith.index_cast %arg24 : i32 to index + %14 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %19 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked1> + %20 = tt.splat %arg14 : (i32) -> tensor<128x1xi32, #blocked2> + %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %23 = tt.expand_dims %21 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> + %24 = tt.broadcast %23 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %25 = tt.expand_dims %22 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>) -> tensor<1x64xi32, #blocked2> + %26 = tt.broadcast %25 : (tensor<1x64xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> + %27 = tt.splat %6 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %28 = tt.splat %arg17 : (i32) -> tensor<128x1xi32, #blocked1> + %29 = tt.splat %7 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %30 = tt.splat %8 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %31 = tt.splat %9 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %32 = tt.splat %10 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked2> + %33 = arith.muli %0, %arg23 : i32 + %34 = tt.addptr %arg11, %33 : !tt.ptr, i32 + %35 = tt.addptr %arg10, %33 : !tt.ptr, i32 + %36 = arith.muli %arg24, %c128_i32 : i32 + %37 = arith.index_cast %36 : i32 to index + %38 = tt.splat %35 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0> + %39 = tt.splat %arg3 : (f32) -> tensor<128x128xf32, #mma0> + %40 = tt.splat %34 : (!tt.ptr) -> tensor<128x!tt.ptr, #blocked0> + %41 = arith.muli %arg14, %c128_i32 : i32 + %42 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked1> + %43 = tt.splat %41 : (i32) -> tensor<128x64xi32, #blocked2> + %44 = tt.splat %12 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + %45 = tt.splat %11 : (!tt.ptr) -> tensor<128x64x!tt.ptr, #blocked1> + scf.for %arg25 = %c0 to %13 step %c1 { + %46 = arith.index_cast %arg25 : index to i32 + %47 = arith.muli %46, %c128_i32 : i32 + %48 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %49 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %50 = tt.splat %47 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %51 = arith.addi %48, %15 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %52 = arith.addi %49, %16 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %53 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<128x1xi32, #blocked1> + %54 = tt.expand_dims %52 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<128x1xi32, #blocked2> + %55 = arith.muli %53, %28 : tensor<128x1xi32, #blocked1> + %56 = tt.broadcast %55 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %57 = arith.addi %56, %24 : tensor<128x64xi32, #blocked1> + %58 = tt.addptr %29, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %59 = tt.load %58 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %60 = arith.muli %53, %19 : tensor<128x1xi32, #blocked1> + %61 = tt.broadcast %60 : (tensor<128x1xi32, #blocked1>) -> tensor<128x64xi32, #blocked1> + %62 = arith.addi %61, %24 : tensor<128x64xi32, #blocked1> + %63 = tt.addptr %30, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %64 = tt.load %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %65 = arith.index_cast %47 : i32 to index + %66 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %67 = tt.trans %66 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %68 = arith.addi %50, %17 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>> + %69 = tt.expand_dims %68 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #mma0}>>) -> tensor<1x128xi32, #mma0> + %70 = tt.broadcast %69 : (tensor<1x128xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %71 = triton_gpu.convert_layout %64 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared0> + %72 = tt.trans %71 : (tensor<128x64xf16, #shared0>) -> tensor<64x128xf16, #shared1> + %73 = arith.muli %54, %20 : tensor<128x1xi32, #blocked2> + %74 = tt.broadcast %73 : (tensor<128x1xi32, #blocked2>) -> tensor<128x64xi32, #blocked2> + %75 = arith.addi %74, %26 : tensor<128x64xi32, #blocked2> + %76 = tt.addptr %32, %75 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %77 = tt.addptr %27, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %78 = tt.addptr %31, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %79 = triton_gpu.convert_layout %59 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %80:5 = scf.for %arg26 = %65 to %37 step %c128 iter_args(%arg27 = %cst, %arg28 = %cst, %arg29 = %76, %arg30 = %77, %arg31 = %78) -> (tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1>) { + %87 = arith.index_cast %arg26 : index to i32 + %88 = tt.splat %87 : (i32) -> tensor<128xi32, #blocked0> + %89 = tt.splat %87 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %90 = arith.addi %88, %14 : tensor<128xi32, #blocked0> + %91 = tt.load %arg30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %92 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %93 = triton_gpu.convert_layout %67 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %94 = tt.dot %92, %93, %cst_1 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %95 = arith.addi %89, %18 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %96 = tt.expand_dims %95 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xi32, #mma0> + %97 = tt.broadcast %96 : (tensor<128x1xi32, #mma0>) -> tensor<128x128xi32, #mma0> + %98 = "triton_gpu.cmpi"(%97, %70) {predicate = 5 : i64} : (tensor<128x128xi32, #mma0>, tensor<128x128xi32, #mma0>) -> tensor<128x128xi1, #mma0> + %99 = "triton_gpu.select"(%98, %94, %cst_0) : (tensor<128x128xi1, #mma0>, tensor<128x128xf32, #mma0>, tensor<128x128xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %100 = tt.addptr %38, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %101 = tt.load %100 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %102 = arith.mulf %99, %39 : tensor<128x128xf32, #mma0> + %103 = triton_gpu.convert_layout %101 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %104 = tt.expand_dims %103 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %105 = tt.broadcast %104 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %106 = arith.subf %102, %105 : tensor<128x128xf32, #mma0> + %107 = math.exp %106 : tensor<128x128xf32, #mma0> + %108 = tt.load %arg31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1> + %109 = arith.truncf %107 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %110 = triton_gpu.convert_layout %109 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> + %111 = tt.trans %110 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> + %112 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %113 = triton_gpu.convert_layout %111 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %114 = tt.dot %113, %112, %arg27 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %115 = tt.addptr %40, %90 : tensor<128x!tt.ptr, #blocked0>, tensor<128xi32, #blocked0> + %116 = tt.load %115 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xf32, #blocked0> + %117 = triton_gpu.convert_layout %116 : (tensor<128xf32, #blocked0>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>> + %118 = tt.expand_dims %117 {axis = 1 : i32} : (tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma0}>>) -> tensor<128x1xf32, #mma0> + %119 = tt.broadcast %118 : (tensor<128x1xf32, #mma0>) -> tensor<128x128xf32, #mma0> + %120 = arith.subf %cst_1, %119 : tensor<128x128xf32, #mma0> + %121 = triton_gpu.convert_layout %108 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> + %122 = triton_gpu.convert_layout %72 : (tensor<64x128xf16, #shared1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> + %123 = tt.dot %121, %122, %120 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma0}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma0}>> -> tensor<128x128xf32, #mma0> + %124 = arith.mulf %107, %123 : tensor<128x128xf32, #mma0> + %125 = arith.mulf %124, %39 : tensor<128x128xf32, #mma0> + %126 = arith.truncf %125 : tensor<128x128xf32, #mma0> to tensor<128x128xf16, #mma0> + %127 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #shared0> + %128 = tt.trans %127 : (tensor<128x128xf16, #shared0>) -> tensor<128x128xf16, #shared1> + %129 = triton_gpu.convert_layout %91 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> + %130 = triton_gpu.convert_layout %128 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %131 = tt.dot %130, %129, %arg28 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %132 = tt.load %arg29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf32, #blocked2> + %133 = triton_gpu.convert_layout %132 : (tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #mma1> + %134 = triton_gpu.convert_layout %126 : (tensor<128x128xf16, #mma0>) -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %135 = tt.dot %134, %79, %133 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1}>> -> tensor<128x64xf32, #mma1> + %136 = triton_gpu.convert_layout %135 : (tensor<128x64xf32, #mma1>) -> tensor<128x64xf32, #blocked2> + tt.store %arg29, %136 : tensor<128x64xf32, #blocked2> + %137 = tt.addptr %arg29, %43 : tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64xi32, #blocked2> + %138 = tt.addptr %arg30, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %139 = tt.addptr %arg31, %42 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %114, %131, %137, %138, %139 : tensor<128x64xf32, #mma1>, tensor<128x64xf32, #mma1>, tensor<128x64x!tt.ptr, #blocked2>, tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64x!tt.ptr, #blocked1> + } + %81 = arith.truncf %80#0 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %82 = tt.addptr %44, %62 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %83 = triton_gpu.convert_layout %81 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %82, %83 : tensor<128x64xf16, #blocked1> + %84 = arith.truncf %80#1 : tensor<128x64xf32, #mma1> to tensor<128x64xf16, #mma1> + %85 = tt.addptr %45, %57 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %86 = triton_gpu.convert_layout %84 : (tensor<128x64xf16, #mma1>) -> tensor<128x64xf16, #blocked1> + tt.store %85, %86 : tensor<128x64xf16, #blocked1> + } + return + } +} \ No newline at end of file