From 1a4fbed25bb6f7c9e805f0207f3d043d1b15c02f Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 11 May 2022 16:13:53 +0800 Subject: [PATCH] Skeleton for the pipeline pass --- CMakeLists.txt | 2 + .../triton/Dialect/TritonGPU/CMakeLists.txt | 1 + .../TritonGPU/Transforms/CMakeLists.txt | 3 + .../Dialect/TritonGPU/Transforms/Passes.h | 14 + .../Dialect/TritonGPU/Transforms/Passes.td | 27 ++ .../TritonToTritonGPU/CMakeLists.txt | 2 +- .../TritonGPU/Transforms/CMakeLists.txt | 8 +- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 21 ++ rewrite-test/jit/matmul/matmul.mlir | 278 +++++++++++++++++- 9 files changed, 344 insertions(+), 12 deletions(-) create mode 100644 include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt create mode 100644 include/triton/Dialect/TritonGPU/Transforms/Passes.h create mode 100644 include/triton/Dialect/TritonGPU/Transforms/Passes.td create mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7055418c6..c8821fd77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,6 +182,8 @@ target_link_libraries(triton TritonTransforms TritonDriver TritonToTritonGPU + TritonGPUIR + TritonGPUTransforms # optimizations MLIRPass MLIRTransforms diff --git a/include/triton/Dialect/TritonGPU/CMakeLists.txt b/include/triton/Dialect/TritonGPU/CMakeLists.txt index f33061b2d..9f57627c3 100644 --- a/include/triton/Dialect/TritonGPU/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..6be94d1a8 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 000000000..870714504 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,14 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +std::unique_ptr createTritonGPUPipelinePass(); + +// /// Generate the code for registering passes. +// #define GEN_PASS_REGISTRATION +// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace mlir +#endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 000000000..490df2730 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,27 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::FuncOp"> { + let summary = "pipeline"; + + let description = [{ + scf.for() { + %a = load %a_ptr; + %b = load %b_ptr; + + %d = dot %a, %b, %c; + } + + => + + ... + }]; + + let constructor = "mlir::triton::gpu::createPipelinePass"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + +#endif diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 5c044d026..382ee7a1d 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -15,5 +15,5 @@ add_mlir_conversion_library(TritonToTritonGPU MLIRPass TritonIR TritonGPUIR - TritonGPUConversion + TritonGPUTransforms ) diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 5d089f297..ff3814a87 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,10 +1,12 @@ -add_mlir_dialect_library(TritonGPUConversion +add_mlir_dialect_library(TritonGPUTransforms + Pipeline.cpp TritonGPUConversion.cpp - # ADDITIONAL_HEADER_DIRS + DEPENDS + TritonGPUTransformsIncGen LINK_LIBS PUBLIC TritonIR TritonGPUIR - # MLIRTransformUtils + MLIRTransformUtils ) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp new file mode 100644 index 000000000..ecb9f8b9a --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -0,0 +1,21 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +struct PipelinePass : public TritonGPUPipelineBase { + void runOnOperation() override { + getOperation()->walk([&](scf::ForOp forOp) { + + }); + } +}; +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUPipelinePass() { + return std::make_unique(); +} diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir index b73890fdd..eb240cd28 100644 --- a/rewrite-test/jit/matmul/matmul.mlir +++ b/rewrite-test/jit/matmul/matmul.mlir @@ -1,8 +1,270 @@ -Traceback (most recent call last): - File "matmul.py", line 1, in - import triton - File "/home/da/miniconda3/envs/torch-src/lib/python3.7/site-packages/triton-2.0.0-py3.7-linux-x86_64.egg/triton/__init__.py", line 9, in - from .code_gen import cdiv, next_power_of_2, jit, autotune, heuristics, \ - File "/home/da/miniconda3/envs/torch-src/lib/python3.7/site-packages/triton-2.0.0-py3.7-linux-x86_64.egg/triton/code_gen.py", line 23, in - import triton._C.libtriton.triton as _triton -ImportError: /home/da/miniconda3/envs/torch-src/lib/python3.7/site-packages/triton-2.0.0-py3.7-linux-x86_64.egg/triton/_C/libtriton.so: undefined symbol: _ZN4mlir6triton5CatOp10getEffectsERN4llvm15SmallVectorImplINS_11SideEffects14EffectInstanceINS_13MemoryEffects6EffectEEEEE +module { + func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = call @"cdiv__i32__1cconstexpr[128]"(%arg3) : (i32) -> i32 + %2 = call @"cdiv__i32__1cconstexpr[128]"(%arg4) : (i32) -> i32 + %c8_i32 = arith.constant 8 : i32 + %3 = arith.muli %2, %c8_i32 : i32 + %4 = arith.divsi %0, %3 : i32 + %c8_i32_0 = arith.constant 8 : i32 + %5 = arith.muli %4, %c8_i32_0 : i32 + %6 = arith.subi %1, %5 : i32 + %7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32 + %8 = arith.remsi %0, %7 : i32 + %9 = arith.addi %5, %8 : i32 + %10 = arith.remsi %0, %3 : i32 + %11 = arith.divsi %10, %7 : i32 + %c128_i32 = arith.constant 128 : i32 + %12 = arith.muli %9, %c128_i32 : i32 + %13 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %14 = tt.broadcast %12 : (i32) -> tensor<128xi32> + %15 = arith.addi %14, %13 : tensor<128xi32> + %c128_i32_1 = arith.constant 128 : i32 + %16 = arith.muli %11, %c128_i32_1 : i32 + %17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %18 = tt.broadcast %16 : (i32) -> tensor<128xi32> + %19 = arith.addi %18, %17 : tensor<128xi32> + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %21 = tt.reshape %15 : (tensor<128xi32>) -> tensor<128x1xi32> + %22 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32> + %23 = arith.muli %21, %22 : tensor<128x1xi32> + %24 = tt.reshape %20 : (tensor<128xi32>) -> tensor<1x128xi32> + %c1_i32 = arith.constant 1 : i32 + %25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32> + %26 = arith.muli %24, %25 : tensor<1x128xi32> + %27 = tt.broadcast %23 : (tensor<128x1xi32>) -> tensor<128x128xi32> + %28 = tt.broadcast %26 : (tensor<1x128xi32>) -> tensor<128x128xi32> + %29 = arith.addi %27, %28 : tensor<128x128xi32> + %30 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + %31 = tt.getelementptr %30, %29, : tensor<128x128x!tt.ptr> + %32 = tt.reshape %20 : (tensor<128xi32>) -> tensor<128x1xi32> + %33 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32> + %34 = arith.muli %32, %33 : tensor<128x1xi32> + %35 = tt.reshape %19 : (tensor<128xi32>) -> tensor<1x128xi32> + %c1_i32_2 = arith.constant 1 : i32 + %36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x128xi32> + %37 = arith.muli %35, %36 : tensor<1x128xi32> + %38 = tt.broadcast %34 : (tensor<128x1xi32>) -> tensor<128x128xi32> + %39 = tt.broadcast %37 : (tensor<1x128xi32>) -> tensor<128x128xi32> + %40 = arith.addi %38, %39 : tensor<128x128xi32> + %41 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr> + %42 = tt.getelementptr %41, %40, : tensor<128x128x!tt.ptr> + %cst = arith.constant 0.000000e+00 : f32 + %43 = tt.broadcast %cst : (f32) -> tensor<128x128xf32> + %c0_i32 = arith.constant 0 : i32 + %c128_i32_3 = arith.constant 128 : i32 + %44 = arith.index_cast %c0_i32 : i32 to index + %45 = arith.index_cast %arg5 : i32 to index + %46 = arith.index_cast %c128_i32_3 : i32 to index + %47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr>) { + %cst_7 = arith.constant dense : tensor<128x128xi1> + %cst_8 = arith.constant dense<0.000000e+00> : tensor<128x128xf16> + %77 = tt.load %arg11, %cst_7, %cst_8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> + %cst_9 = arith.constant dense : tensor<128x128xi1> + %cst_10 = arith.constant dense<0.000000e+00> : tensor<128x128xf16> + %78 = tt.load %arg12, %cst_9, %cst_10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> + %cst_11 = arith.constant 0.000000e+00 : f32 + %79 = tt.broadcast %cst_11 : (f32) -> tensor<128x128xf32> + %80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32> + %81 = arith.addf %arg10, %80 : tensor<128x128xf32> + %c128_i32_12 = arith.constant 128 : i32 + %82 = tt.broadcast %c128_i32_12 : (i32) -> tensor<128x128xi32> + %83 = tt.getelementptr %arg11, %82, : tensor<128x128x!tt.ptr> + %c128_i32_13 = arith.constant 128 : i32 + %84 = arith.muli %arg7, %c128_i32_13 : i32 + %85 = tt.broadcast %84 : (i32) -> tensor<128x128xi32> + %86 = tt.getelementptr %arg12, %85, : tensor<128x128x!tt.ptr> + scf.yield %81, %83, %86 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr> + } + %48 = arith.truncf %47#0 : tensor<128x128xf32> to tensor<128x128xf16> + %c128_i32_4 = arith.constant 128 : i32 + %49 = arith.muli %9, %c128_i32_4 : i32 + %50 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %51 = tt.broadcast %49 : (i32) -> tensor<128xi32> + %52 = arith.addi %51, %50 : tensor<128xi32> + %c128_i32_5 = arith.constant 128 : i32 + %53 = arith.muli %11, %c128_i32_5 : i32 + %54 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %55 = tt.broadcast %53 : (i32) -> tensor<128xi32> + %56 = arith.addi %55, %54 : tensor<128xi32> + %57 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32> + %58 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32> + %59 = arith.muli %58, %57 : tensor<128x1xi32> + %60 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> + %61 = tt.getelementptr %60, %59, : tensor<128x1x!tt.ptr> + %62 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32> + %c1_i32_6 = arith.constant 1 : i32 + %63 = tt.broadcast %c1_i32_6 : (i32) -> tensor<1x128xi32> + %64 = arith.muli %62, %63 : tensor<1x128xi32> + %65 = tt.broadcast %61 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> + %66 = tt.broadcast %64 : (tensor<1x128xi32>) -> tensor<128x128xi32> + %67 = tt.getelementptr %65, %66, : tensor<128x128x!tt.ptr> + %68 = tt.reshape %52 : (tensor<128xi32>) -> tensor<128x1xi32> + %69 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32> + %70 = arith.cmpi slt, %68, %69 : tensor<128x1xi32> + %71 = tt.reshape %56 : (tensor<128xi32>) -> tensor<1x128xi32> + %72 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32> + %73 = arith.cmpi slt, %71, %72 : tensor<1x128xi32> + %74 = tt.broadcast %70 : (tensor<128x1xi1>) -> tensor<128x128xi1> + %75 = tt.broadcast %73 : (tensor<1x128xi1>) -> tensor<128x128xi1> + %76 = arith.andi %74, %75 : tensor<128x128xi1> + tt.store %67, %48, %76, : tensor<128x128xf16> + return + } + func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 { + %c128_i32 = arith.constant 128 : i32 + %0 = arith.addi %arg0, %c128_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %1 = arith.subi %0, %c1_i32 : i32 + %c128_i32_0 = arith.constant 128 : i32 + %2 = arith.divsi %1, %c128_i32_0 : i32 + return %2 : i32 + } + func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { + %c8_i32 = arith.constant 8 : i32 + %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 + %c8_i32_0 = arith.constant 8 : i32 + %1 = select %0, %arg0, %c8_i32_0 : i32 + return %1 : i32 + } +} +is yield legal? +scf.yield %80, %82, %85 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr> +is legal: 0 +converting for...: +%78:3 = scf.for %arg9 = %c0 to %77 step %c128 iter_args(%arg10 = %75, %arg11 = %51, %arg12 = %73) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>) { + %109 = tt.load <>, %cst_2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> + %110 = tt.load <>, %cst_2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16> + %111 = tt.dot %109, %110, <> {allowTF32 = true} : tensor<128x128xf16> * tensor<128x128xf16> -> tensor<128x128xf32> + %112 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32> + %113 = tt.getelementptr <>, %112, : tensor<128x128x!tt.ptr> + %114 = arith.muli %arg7, %c128_i32 : i32 + %115 = tt.broadcast %114 : (i32) -> tensor<128x128xi32> + %116 = tt.getelementptr <>, %115, : tensor<128x128x!tt.ptr> + scf.yield %111, %113, %116 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr> +} +converting dot... +%113 = tt.dot %109, %111, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> * tensor<128x128xf16, #triton_gpu<"coalesced encoding">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> +is yield legal? +scf.yield %114, %118, %123 : tensor<128x128xf32>, tensor<128x128x!tt.ptr>, tensor<128x128x!tt.ptr> +is legal: 0 +converting yield.... +is yield legal? +scf.yield %113, %117, %122 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> +is legal: 1 +module { + func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c128_13c128_14c128_15c8(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %cst_0 = arith.constant dense : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %cst_1 = arith.constant 0.000000e+00 : f32 + %c128_i32 = arith.constant 128 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c127_i32 : i32 + %4 = arith.divsi %3, %c128_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.cmpi slt, %8, %c8_i32 : i32 + %10 = select %9, %8, %c8_i32 : i32 + %11 = arith.remsi %0, %10 : i32 + %12 = arith.addi %7, %11 : i32 + %13 = arith.remsi %0, %5 : i32 + %14 = arith.divsi %13, %10 : i32 + %15 = arith.muli %12, %c128_i32 : i32 + %16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %17 = tt.broadcast %15 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %18 = arith.addi %17, %16 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %19 = arith.muli %14, %c128_i32 : i32 + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %21 = tt.broadcast %19 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %22 = arith.addi %21, %20 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %24 = tt.reshape %18 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %25 = tt.broadcast %arg6 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %26 = arith.muli %24, %25 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %27 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %29 = arith.muli %27, %28 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %30 = tt.broadcast %26 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %31 = tt.broadcast %29 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %32 = arith.addi %30, %31 : tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %33 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %34 = tt.getelementptr %33, %32, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %35 = tt.reshape %23 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %36 = tt.broadcast %arg7 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %37 = arith.muli %35, %36 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %38 = tt.reshape %22 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %40 = arith.muli %38, %39 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %41 = tt.broadcast %37 : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %42 = tt.broadcast %40 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %43 = arith.addi %41, %42 : tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %44 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %45 = tt.getelementptr %44, %43, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %46 = tt.broadcast %cst_1 : (f32) -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> + %47 = arith.index_cast %arg5 : i32 to index + %48:3 = scf.for %arg9 = %c0 to %47 step %c128 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>) { + %78 = tt.load %arg11, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %79 = tt.load %arg12, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %80 = tt.dot %78, %79, %arg10 {allowTF32 = true} : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> * tensor<128x128xf16, #triton_gpu<"coalesced encoding">> -> tensor<128x128xf32, #triton_gpu<"coalesced encoding">> + %81 = tt.broadcast %c128_i32 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %82 = tt.getelementptr %arg11, %81, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %83 = arith.muli %arg7, %c128_i32 : i32 + %84 = tt.broadcast %83 : (i32) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %85 = tt.getelementptr %arg12, %84, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + scf.yield %80, %82, %85 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">>, tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + } + %49 = arith.truncf %48#0 : tensor<128x128xf32, #triton_gpu<"coalesced encoding">> to tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + %50 = arith.muli %12, %c128_i32 : i32 + %51 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %52 = tt.broadcast %50 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %53 = arith.addi %52, %51 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %54 = arith.muli %14, %c128_i32 : i32 + %55 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %56 = tt.broadcast %54 : (i32) -> tensor<128xi32, #triton_gpu<"coalesced encoding">> + %57 = arith.addi %56, %55 : tensor<128xi32, #triton_gpu<"coalesced encoding">> + %58 = tt.reshape %53 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %59 = tt.broadcast %arg8 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %60 = arith.muli %59, %58 : tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %61 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> + %62 = tt.getelementptr %61, %60, : tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">> + %63 = tt.reshape %57 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %65 = arith.muli %63, %64 : tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %66 = tt.broadcast %62 : (tensor<128x1x!tt.ptr, #triton_gpu<"coalesced encoding">>) -> tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %67 = tt.broadcast %65 : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi32, #triton_gpu<"coalesced encoding">> + %68 = tt.getelementptr %66, %67, : tensor<128x128x!tt.ptr, #triton_gpu<"coalesced encoding">> + %69 = tt.reshape %53 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %70 = tt.broadcast %arg3 : (i32) -> tensor<128x1xi32, #triton_gpu<"coalesced encoding">> + %71 = "triton_gpu.cmpi"(%69, %70) {predicate = 2 : i64} : (tensor<128x1xi32, #triton_gpu<"coalesced encoding">>, tensor<128x1xi32, #triton_gpu<"coalesced encoding">>) -> tensor<128x1xi1, #triton_gpu<"coalesced encoding">> + %72 = tt.reshape %57 : (tensor<128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %73 = tt.broadcast %arg4 : (i32) -> tensor<1x128xi32, #triton_gpu<"coalesced encoding">> + %74 = "triton_gpu.cmpi"(%72, %73) {predicate = 2 : i64} : (tensor<1x128xi32, #triton_gpu<"coalesced encoding">>, tensor<1x128xi32, #triton_gpu<"coalesced encoding">>) -> tensor<1x128xi1, #triton_gpu<"coalesced encoding">> + %75 = tt.broadcast %71 : (tensor<128x1xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %76 = tt.broadcast %74 : (tensor<1x128xi1, #triton_gpu<"coalesced encoding">>) -> tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + %77 = arith.andi %75, %76 : tensor<128x128xi1, #triton_gpu<"coalesced encoding">> + tt.store %68, %49, %77, : tensor<128x128xf16, #triton_gpu<"coalesced encoding">> + return + } + func @"cdiv__i32__1cconstexpr[128]"(%arg0: i32) -> i32 { + %c127_i32 = arith.constant 127 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = arith.addi %arg0, %c127_i32 : i32 + %1 = arith.divsi %0, %c128_i32 : i32 + return %1 : i32 + } + func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 { + %c8_i32 = arith.constant 8 : i32 + %0 = arith.cmpi slt, %arg0, %c8_i32 : i32 + %1 = select %0, %arg0, %c8_i32 : i32 + return %1 : i32 + } +}