From 9feb256b7186efe6d86c3f7b11cf89f10e1f91e9 Mon Sep 17 00:00:00 2001
From: Yan Da
Date: Fri, 17 Jun 2022 16:19:47 +0800
Subject: [PATCH] op combine in Triton Dialect: broadcast(cst) -> cst

---
 lib/Conversion/CMakeLists.txt             |   1 +
 lib/Dialect/Triton/Transforms/Combine.cpp |  36 +++++-
 rewrite-test/jit/vecadd/vecadd-loop.py    |   2 +-
 rewrite-test/jit/vecadd/vecadd.mlir       | 130 +++++++---------
 4 files changed, 76 insertions(+), 93 deletions(-)

diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 5cbcea5da..a08349513 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -1 +1,2 @@
+# add_subdirectory(TritonGPUToLLVM)
 add_subdirectory(TritonToTritonGPU)
diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp
index aaed058b7..2fc073c05 100644
--- a/lib/Dialect/Triton/Transforms/Combine.cpp
+++ b/lib/Dialect/Triton/Transforms/Combine.cpp
@@ -10,7 +10,7 @@
 
 #include <memory>
 
-// using namespace mlir;
+using namespace mlir;
 
 namespace {
 // dot(a, b, 0) + c => dot(a, b, c)
@@ -114,6 +114,39 @@ public:
     return mlir::failure();
   }
 };
+
+// broadcast(cst) => cst
+// TODO: move this to .td file
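+//
+// For example, the pair
+//   %cst = arith.constant 0.000000e+00 : f32
+//   %0 = tt.broadcast %cst : (f32) -> tensor<256xf32>
+// folds into a single splat constant
+//   %0 = arith.constant dense<0.000000e+00> : tensor<256xf32>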
+class CombineBroadcastConstantOp : public mlir::RewritePattern {
+public:
+  CombineBroadcastConstantOp(mlir::MLIRContext *context)
+      : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1,
+                             context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
+      if (auto cst = broadcast.src().getDefiningOp<arith::ConstantOp>()) {
+        Attribute value = cst.getValue();
+        Type resType = broadcast.getResult().getType();
+        if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
+          if (!denseValue.isSplat())
+            return failure();
+          value = DenseElementsAttr::get(resType, denseValue.getSplatValue<Attribute>());
+        } else {
+          if (!value.isa<FloatAttr, IntegerAttr>())
+            return failure();
+          value = DenseElementsAttr::get(resType, value);
+        }
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, value, resType
+        );
+        return success();
+      }
+    }
+    return failure();
+  }
+};
 } // anonymous namespace
 
 #define GEN_PASS_CLASSES
@@ -129,6 +162,7 @@ public:
     patterns.add<CombineDotOp>(context);
     patterns.add<CombineSelectMaskedLoadOp>(context);
     patterns.add<CombineGEPOp>(context);
+    patterns.add<CombineBroadcastConstantOp>(context);
     // patterns.add(context);
 
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
diff --git a/rewrite-test/jit/vecadd/vecadd-loop.py b/rewrite-test/jit/vecadd/vecadd-loop.py
index 069875cc9..643af7360 100644
--- a/rewrite-test/jit/vecadd/vecadd-loop.py
+++ b/rewrite-test/jit/vecadd/vecadd-loop.py
@@ -47,7 +47,7 @@ z = torch.empty_like(x)
 # add_kernel[(1,)](x, y, z, size, 256)
 # print(add_kernel[(1,)].kernel.compile_to_ttir())
 mod, ctx = add_kernel.compile_to_ttir(
-    x, y, z, size, 128, 8, grid=(1,), num_stages=4)
+    x, y, z, size, 128, 8, grid=(1,), num_stages=1)
 mod.dump()
 
 # print(mod)
diff --git a/rewrite-test/jit/vecadd/vecadd.mlir b/rewrite-test/jit/vecadd/vecadd.mlir
index 4148ec28a..07a216925 100644
--- a/rewrite-test/jit/vecadd/vecadd.mlir
+++ b/rewrite-test/jit/vecadd/vecadd.mlir
@@ -9,9 +9,9 @@ module {
     %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
     %6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
     %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
-    %8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr<f32>>
+    %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>>
     %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
-    %10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr<f32>>
+    %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>>
     %cst = arith.constant 0.000000e+00 : f32
     %11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
     %c0_i32 = arith.constant 0 : i32
@@ -29,100 +29,48 @@
       %22 = arith.addf %19, %21 : tensor<256xf32>
       %23 = arith.addf %arg7, %22 : tensor<256xf32>
       %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
-      %25 = tt.getelementptr %arg8, %24, : tensor<256x!tt.ptr<f32>>
+      %25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
       %26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
-      %27 = tt.getelementptr %arg9, %26, : tensor<256x!tt.ptr<f32>>
+      %27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
       scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
     }
     %16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
-    %17 = tt.getelementptr %16, %4, : tensor<256x!tt.ptr<f32>>
+    %17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
     tt.store %17, %15#0, %6, : tensor<256xf32>
     return
   }
 }
-// module {
-//   func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
-//     %c64 = arith.constant 64 : index
-//     %c32 = arith.constant 32 : index
-//     %c0 = arith.constant 0 : index
-//     %cst = arith.constant 0.000000e+00 : f32
-//     %c256_i32 = arith.constant 256 : i32
-//     %0 = tt.get_program_id {axis = 0 : i32} : i32
-//     %1 = arith.muli %0, %c256_i32 : i32
-//     %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding">>, tensor<256xi32, #triton_gpu<"coalesced encoding">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %12 = arith.index_cast %arg4 : i32 to index
-//     %13 = arith.cmpi slt, %c0, %12 : index
-//     %14 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %15 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %16 = arith.andi %6, %15 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %17 = triton_gpu.copy_async %8, %16, %14 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %18 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %19 = tt.broadcast %13 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %23 = tt.getelementptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %25 = tt.getelementptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %26 = arith.cmpi slt, %c32, %12 : index
-//     %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %29 = arith.andi %6, %28 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %30 = triton_gpu.copy_async %23, %29, %27 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %31 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %32 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %36 = tt.getelementptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %38 = tt.getelementptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %39 = arith.cmpi slt, %c64, %12 : index
-//     %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %42 = arith.andi %6, %41 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %43 = triton_gpu.copy_async %36, %42, %40 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %44 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %45 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//     %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %49 = tt.getelementptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//     %51 = tt.getelementptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, index) {
-//       %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//       %58 = tt.getelementptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//       %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//       %60 = tt.getelementptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//       %61 = arith.addi %arg18, %c32 : index
-//       %62 = arith.cmpi slt, %61, %12 : index
-//       %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %64 = tt.broadcast %62 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//       %65 = arith.andi %64, %6 : tensor<256xi1, #triton_gpu<"coalesced encoding">>
-//       %66 = triton_gpu.copy_async %arg17, %65, %63 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">> -> tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//       %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//       %70 = tt.getelementptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//       %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding">>
-//       %72 = tt.getelementptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//       scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256xf32, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>, index
-//     }
-//     %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding">>
-//     tt.store %54, %52#0, %6, : tensor<256xf32, #triton_gpu<"coalesced encoding">>
-//     return
-//   }
-// }
+module {
+  func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %c256_i32 = arith.constant 256 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    %1 = arith.muli %0, %c256_i32 : i32
+    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %11 = arith.index_cast %arg4 : i32 to index
+    %12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) {
+      %15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      %21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+      scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    }
+    %13 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    return
+  }
+}