From e3916c3a464330f3110a024be69ae76e46b6df70 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Mon, 16 May 2022 19:16:01 +0800 Subject: [PATCH] TritonGPU combiner --- include/triton/Dialect/Triton/IR/TritonOps.td | 58 +-------- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 2 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 14 ++- .../Dialect/TritonGPU/Transforms/Passes.h | 6 + .../Dialect/TritonGPU/Transforms/Passes.td | 19 +++ .../TritonGPU/Transforms/CMakeLists.txt | 6 + python/src/triton.cc | 3 + python/triton/code_gen.py | 1 + rewrite-test/jit/matmul/matmul.mlir | 118 +++++++++--------- 9 files changed, 109 insertions(+), 118 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index ef0126f3a..f137bf1ef 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1,9 +1,9 @@ -#ifndef Triton_OPS -#define Triton_OPS +#ifndef TRITON_OPS +#define TRITON_OPS include "triton/Dialect/Triton/IR/TritonDialect.td" include "triton/Dialect/Triton/IR/TritonTypes.td" -include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface @@ -64,25 +64,6 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect, // // Load/Store Ops // -def TT_CacheModifierAttr : I32EnumAttr< - "CacheModifier", "", - [ - I32EnumAttrCase<"NONE", 1, "none">, - I32EnumAttrCase<"CA", 2, "ca">, - I32EnumAttrCase<"CG", 3, "cg">, - ]> { - let cppNamespace = "::mlir::triton"; -} -def TT_EvictionPolicyAttr : I32EnumAttr< - "EvictionPolicy", "", - [ - I32EnumAttrCase<"NORMAL", 1, "normal">, - I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, - I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> - ]> { - let cppNamespace = "::mlir::triton"; -} - def TT_LoadOp : TT_Op<"load", 
[SameOperandsAndResultShape, MemoryEffects<[MemRead]>, @@ -221,45 +202,12 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, // let hasCanonicalizer = 1; } -// reduction -def TT_RedOpAttr : I32EnumAttr< - /*name*/"RedOp", /*summary*/"", - /*case*/ - [ - I32EnumAttrCase, - I32EnumAttrCase<"MAX", 2, "max">, - I32EnumAttrCase<"MIN", 3, "min">, - I32EnumAttrCase<"FADD", 4, "fadd">, - I32EnumAttrCase<"FMAX", 5, "fmax">, - I32EnumAttrCase<"FMIN", 6, "fmin">, - I32EnumAttrCase<"XOR", 7, "xor"> - ]> { - let cppNamespace = "::mlir::triton"; -} - def TT_ReduceOp : TT_Op<"reduce"> { let summary = "reduce"; let arguments = (ins TT_RedOpAttr:$reduce_op, TT_Type:$operand, I32Attr:$axis); } -// atomic -def TT_AtomicRMWAttr : I32EnumAttr< - "RMWOp", "", - [ - I32EnumAttrCase<"AND", 1, "and">, - I32EnumAttrCase<"OR", 2, "or">, - I32EnumAttrCase<"XOR", 3, "xor">, - I32EnumAttrCase<"ADD", 4, "add">, - I32EnumAttrCase<"FADD", 5, "fadd">, - I32EnumAttrCase<"MAX", 6, "max">, - I32EnumAttrCase<"MIN", 7, "min">, - I32EnumAttrCase<"UMAX", 8, "umax">, - I32EnumAttrCase<"UMIN", 9, "umin"> - ]> { - let cppNamespace = "::mlir::triton"; -} - def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> { let summary = "atomic rmw"; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 6368c0056..8fd1cd661 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1,7 +1,7 @@ #ifndef TRITONGPU_ATTRDEFS #define TRITONGPU_ATTRDEFS -include "TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" // include "mlir/IR/TensorEncoding.td" class TritonGPU_Attr traits = []> diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 8dd25c94b..ed1a78b85 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -5,6 +5,7 
@@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType @@ -33,7 +34,18 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { let arguments = (ins I32Attr:$num); } -// def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {} +def TTG_CopyAsyncOp : TTG_Op<"copy_async", + [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "copy async"; + + let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, + TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, + BoolAttr:$isVolatile); + + let results = (outs TT_Type:$result); + + // let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)"; +} // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU. def TTG_CmpIOp : TTG_Op<"cmpi"> { diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 870714504..1fa150a60 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -6,6 +6,12 @@ namespace mlir { std::unique_ptr<Pass> createTritonGPUPipelinePass(); +namespace triton { +namespace gpu { +std::unique_ptr<Pass> createCombineOpsPass(); +} // namespace gpu +} // namespace triton + // /// Generate the code for registering passes. 
// #define GEN_PASS_REGISTRATION // #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 827e74485..540bedd72 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -26,4 +26,23 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { "mlir::arith::ArithmeticDialect"]; } +def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { + let summary = "combine triton gpu ops"; + + let description = [{ + convert_layout(load(%ptr, %mask, %other), #SMEM_LAYOUT) => + copy_async(%ptr, %mask, %other), barrier + + convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) => + convert_layout(%src, #LAYOUT_1) + + convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT + }]; + + let constructor = "mlir::triton::gpu::createCombineOpsPass"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index ff3814a87..b803bd30b 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,9 +1,15 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonGPUCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonGPUCombineIncGen) + add_mlir_dialect_library(TritonGPUTransforms + Combine.cpp Pipeline.cpp TritonGPUConversion.cpp DEPENDS TritonGPUTransformsIncGen + TritonGPUCombineIncGen LINK_LIBS PUBLIC TritonIR diff --git a/python/src/triton.cc b/python/src/triton.cc index 4c16bf6d7..3b7c0d7ae 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1344,6 +1344,9 @@ void init_triton_ir(py::module &&m) { .def("add_tritongpu_pipeline_pass", [](mlir::PassManager &self) { 
self.addPass(mlir::createTritonGPUPipelinePass()); }) + .def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::gpu::createCombineOpsPass()); + }) ; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 2ed06528a..67c9b8b0c 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1316,6 +1316,7 @@ class JITFunction: pm.add_convert_triton_to_tritongpu_pass() pm.add_tritongpu_pipeline_pass() pm.add_canonicalizer_pass() + pm.add_triton_gpu_combine_pass() pm.run(mod) return mod diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir index a93754229..fd2268ead 100644 --- a/rewrite-test/jit/matmul/matmul.mlir +++ b/rewrite-test/jit/matmul/matmul.mlir @@ -187,68 +187,64 @@ module { %45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> %46 = tt.broadcast %cst : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> %47 = arith.index_cast %arg5 : i32 to index - %48 = tt.load %34, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %49 = tt.load %45, %cst_0, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %50 = "triton_gpu.convert_layout"(%48) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %51 = "triton_gpu.convert_layout"(%49) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %52 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %53 = tt.getelementptr %34, %52, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %54 = arith.muli %arg7, %c128_i32 : i32 - %55 = tt.broadcast %54 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %56 = tt.getelementptr 
%45, %55, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %57:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %50, %arg14 = %51, %arg15 = %56, %arg16 = %53, %arg17 = %c0) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { - %87 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> - %88 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %89 = tt.getelementptr %arg11, %88, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %90 = arith.muli %arg7, %c128_i32 : i32 - %91 = tt.broadcast %90 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %92 = tt.getelementptr %arg12, %91, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %93 = arith.addi %arg17, %c128 : index - %94 = arith.cmpi slt, %93, %47 : index - %95 = tt.broadcast %94 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %96 = tt.load %arg16, %95, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %97 = tt.broadcast %94 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %98 = arith.andi %97, %95 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %99 = tt.load %arg15, %98, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %100 = 
"triton_gpu.convert_layout"(%96) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %101 = "triton_gpu.convert_layout"(%99) : (tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> - %102 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %103 = tt.getelementptr %arg16, %102, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %104 = arith.muli %arg7, %c128_i32 : i32 - %105 = tt.broadcast %104 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %106 = tt.getelementptr %arg15, %105, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - scf.yield %87, %89, %92, %100, %101, %106, %103, %93 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index + %48 = "triton_gpu.copy_async"(%34, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> + %49 = "triton_gpu.copy_async"(%45, %cst_0, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> + %50 = tt.broadcast %c128_i32 : (i32) -> 
tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %51 = tt.getelementptr %34, %50, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %52 = arith.muli %arg7, %c128_i32 : i32 + %53 = tt.broadcast %52 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %54 = tt.getelementptr %45, %53, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %55:8 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45, %arg13 = %48, %arg14 = %49, %arg15 = %51, %arg16 = %54, %arg17 = %c0) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index) { + %85 = tt.dot %arg13, %arg14, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> * tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> + %86 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %87 = tt.getelementptr %arg11, %86, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %88 = arith.muli %arg7, %c128_i32 : i32 + %89 = tt.broadcast %88 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %90 = tt.getelementptr %arg12, %89, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %91 = arith.addi %arg17, %c128 : index + %92 = arith.cmpi slt, %91, %47 : index + %93 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %94 = tt.broadcast %92 : (i1) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %95 = arith.andi %94, %93 : tensor<128x128xi1, #triton_gpu<"coalesced 
encoding">> + %96 = "triton_gpu.copy_async"(%arg15, %93, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> + %97 = "triton_gpu.copy_async"(%arg16, %95, %cst_1) {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : (tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xi1, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">> + %98 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %99 = tt.getelementptr %arg15, %98, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %100 = arith.muli %arg7, %c128_i32 : i32 + %101 = tt.broadcast %100 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %102 = tt.getelementptr %arg16, %101, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + scf.yield %85, %87, %90, %96, %97, %99, %102, %91 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128xf16, #triton_gpu<"shared (memory) encoding<>">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, index } - %58 = arith.truncf %57#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">> - %59 = arith.muli %12, %c128_i32 : i32 - %60 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %61 = tt.broadcast %59 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced 
encoding">> - %62 = arith.addi %61, %60 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %63 = arith.muli %14, %c128_i32 : i32 - %64 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %65 = tt.broadcast %63 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> - %66 = arith.addi %65, %64 : tensor<128xi32, #triton_gpu<"coalesced encoding">> - %67 = tt.reshape %62 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %68 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %69 = arith.muli %68, %67 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %70 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> - %71 = tt.getelementptr %70, %69, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> - %72 = tt.reshape %66 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %73 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %74 = arith.muli %72, %73 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %75 = tt.broadcast %71 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %76 = tt.broadcast %74 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> - %77 = tt.getelementptr %75, %76, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> - %78 = tt.reshape %62 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %79 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> - %80 = "triton_gpu.cmpi"(%78, %79) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) 
-> tensor<128x1xi1, #triton_gpu<"coalesced encoding">> - %81 = tt.reshape %66 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %82 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> - %83 = "triton_gpu.cmpi"(%81, %82) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">> - %84 = tt.broadcast %80 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %85 = tt.broadcast %83 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - %86 = arith.andi %84, %85 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> - tt.store %77, %58, %86, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %56 = arith.truncf %55#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %57 = arith.muli %12, %c128_i32 : i32 + %58 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %59 = tt.broadcast %57 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %60 = arith.addi %59, %58 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %61 = arith.muli %14, %c128_i32 : i32 + %62 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %63 = tt.broadcast %61 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %64 = arith.addi %63, %62 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %65 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %66 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %67 = arith.muli %66, %65 : tensor<128x1xi32, 
#triton_gpu<"coalesced encoding">> + %68 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> + %69 = tt.getelementptr %68, %67, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> + %70 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %71 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %72 = arith.muli %70, %71 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %73 = tt.broadcast %69 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %74 = tt.broadcast %72 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %75 = tt.getelementptr %73, %74, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %76 = tt.reshape %60 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %77 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %78 = "triton_gpu.cmpi"(%76, %77) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding">> + %79 = tt.reshape %64 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %80 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %81 = "triton_gpu.cmpi"(%79, %80) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">> + %82 = tt.broadcast %78 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %83 = tt.broadcast %81 : 
(tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %84 = arith.andi %82, %83 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + tt.store %75, %56, %84, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> return } func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 {