From 74585fb9707c39241f90e3f521e01a2b0d08eee0 Mon Sep 17 00:00:00 2001
From: Yan Da
Date: Wed, 27 Apr 2022 13:45:56 +0800
Subject: [PATCH] Add Triton CombineOps

---
 include/triton/ir/TritonOps.td      |   6 +
 include/triton/transforms/Passes.h  |  12 +
 lib/CMakeLists.txt                  |   1 +
 lib/codegen/pass.cc                 |  16 +-
 lib/ir/Ops.cpp                      |  13 +
 lib/transforms/CombineOps.cpp       |  43 ++++
 python/src/triton.cc                |   9 +-
 rewrite-test/jit/matmul/matmul.mlir | 360 +++++++++++++++++++---------
 rewrite-test/jit/matmul/matmul.py   |  13 +-
 9 files changed, 339 insertions(+), 134 deletions(-)
 create mode 100644 include/triton/transforms/Passes.h
 create mode 100644 lib/transforms/CombineOps.cpp

diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td
index b96ae5116..f664fba9a 100644
--- a/include/triton/ir/TritonOps.td
+++ b/include/triton/ir/TritonOps.td
@@ -131,6 +131,7 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
 
 def TT_LoadOp : TT_Op<"load",
                      [SameOperandsAndResultShape,
+                      MemoryEffects<[MemRead]>,
                       TypesMatchWith<"infer ptr type from result type",
                                      "result", "ptr",
                                      "getPointerTypeFromTensor($_self)">,
@@ -161,6 +162,7 @@ def TT_LoadOp : TT_Op<"load",
 
 def TT_StoreOp : TT_Op<"store",
                       [SameOperandsShape,
+                       MemoryEffects<[MemWrite]>,
                        TypesMatchWith<"infer ptr type from value type",
                                       "value", "ptr",
                                       "getPointerTypeFromTensor($_self)">,
@@ -214,6 +216,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
   let results = (outs TT_Type:$result);
 
   let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+
+  let hasFolder = 1;
 }
 
 def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
@@ -259,6 +263,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
   let results = (outs TT_FpIntTensor:$d);
 
   let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
+
+  // let hasCanonicalizer = 1;
 }
 
 // reduction
diff --git a/include/triton/transforms/Passes.h b/include/triton/transforms/Passes.h
new file mode 100644
index 000000000..69b1593a3
--- /dev/null
+++ b/include/triton/transforms/Passes.h
@@ -0,0 +1,12 @@
+#ifndef TRITON_TRANSFORMS_PASSES_H_
+#define TRITON_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace triton {
+std::unique_ptr<Pass> createCombineOpsPass();
+} // namespace triton
+} // namespace mlir
+
+// Registration
+// #define GEN_PASS_REGISTRATION
+// #include "triton/transforms/Passes.h.inc"
+
+#endif // TRITON_TRANSFORMS_PASSES_H_
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index f50dbb34e..eade8021b 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,3 +1,4 @@
 # add_subdirectory(codegen)
 add_subdirectory(driver)
 add_subdirectory(ir)
+# add_subdirectory(transforms)
diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 8921d6c84..5c93e10e6 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -55,16 +55,12 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   dce.run(ir);
   disassociate.run(ir);
   dce.run(ir);
-  align.run(ir);
-  axes.run(ir);
-  layouts.run(ir);
+  align.run(ir); axes.run(ir); layouts.run(ir);
   peephole.run(ir);
   dce.run(ir);
   if (target->is_gpu())
     cts.run(ir);
-  align.run(ir);
-  axes.run(ir);
-  layouts.run(ir);
+  align.run(ir); axes.run(ir); layouts.run(ir);
   coalesce.run(ir);
   dce.run(ir);
   align.run(ir);
@@ -72,14 +68,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
   if (target->is_gpu())
     cts.run(ir);
   dce.run(ir);
-  align.run(ir);
-  axes.run(ir);
-  layouts.run(ir);
+  align.run(ir); axes.run(ir); layouts.run(ir);
   peephole.run(ir);
   dce.run(ir);
-  align.run(ir);
-  axes.run(ir);
-  layouts.run(ir);
+  align.run(ir); axes.run(ir); layouts.run(ir);
   swizzle.run(ir);
   liveness.run(ir);
   allocation.run(ir);
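Note that lib/transforms is only added to the build as a commented-out subdirectory above, so createCombineOpsPass() declared in Passes.h is not reachable from any pipeline yet. A minimal sketch of how the pass could be driven once the target is enabled; the pipeline composition and the function name are illustrative assumptions, not part of this patch:

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"    // createInlinerPass, createCanonicalizerPass
    #include "triton/transforms/Passes.h"  // createCombineOpsPass

    // Mirrors the inline + canonicalize pipeline that matmul.py builds through
    // the Python bindings, with the new combine pass appended at the end.
    mlir::LogicalResult runTritonPipeline(mlir::ModuleOp mod) {
      mlir::PassManager pm(mod.getContext());
      pm.addPass(mlir::createInlinerPass());       // inline the cdiv/minimum helpers
      pm.addPass(mlir::createCanonicalizerPass()); // exercises BroadcastOp::fold (added below)
      pm.addPass(mlir::triton::createCombineOpsPass());
      return pm.run(mod);
    }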
diff --git a/lib/ir/Ops.cpp b/lib/ir/Ops.cpp
index 32d9f53ee..41975bcd4 100644
--- a/lib/ir/Ops.cpp
+++ b/lib/ir/Ops.cpp
@@ -95,5 +95,18 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
   state.addTypes({resultType});
 }
 
+//-- DotOp --
+
+//-- BroadcastOp --
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  // Fold broadcast(constant) into a splat constant of the result shape.
+  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
+  if (!constOperand)
+    return {};
+
+  auto shapedType = getType().cast<ShapedType>();
+
+  return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
+}
+
 } // namespace triton
 } // namespace mlir
diff --git a/lib/transforms/CombineOps.cpp b/lib/transforms/CombineOps.cpp
new file mode 100644
index 000000000..014d035bc
--- /dev/null
+++ b/lib/transforms/CombineOps.cpp
@@ -0,0 +1,43 @@
+#include "triton/transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+// Placeholder for the dot/add fusion; matchAndRewrite is not implemented yet.
+struct CombineDotOp : public RewritePattern {
+  explicit CombineDotOp(MLIRContext *context)
+      : RewritePattern(/*rootName=*/FuncOp::getOperationName(),
+                       /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // TODO: fold dot(a, b, 0) + c into dot(a, b, c).
+    return failure();
+  }
+};
+
+// TODO: declare this pass in a Passes.td (?)
+struct CombineOpsPass
+    : public PassWrapper<CombineOpsPass, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    RewritePatternSet patterns(context);
+
+    patterns.add<CombineDotOp>(context);
+    // TODO: populate the remaining combine patterns.
+
+    // TODO: should we use applyPartialConversion(...) instead?
+    if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // anonymous namespace
+
+std::unique_ptr<Pass> mlir::triton::createCombineOpsPass() {
+  return std::make_unique<CombineOpsPass>();
+}
diff --git a/python/src/triton.cc b/python/src/triton.cc
index fe56340e4..c6c7b3862 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1254,7 +1254,7 @@ void init_triton_ir(py::module &&m) {
       .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
         auto loc = self.getUnknownLoc();
         if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
-          return self.create<mlir::triton::BroadcastOp>(
+          return self.createOrFold<mlir::triton::BroadcastOp>(
             loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
           );
         throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
@@ -1323,12 +1323,15 @@ void init_triton_ir(py::module &&m) {
 
   py::class_<mlir::PassManager>(m, "pass_manager")
       .def(py::init<mlir::MLIRContext *>())
-      .def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
-        self.run(mod.getOperation());
+      .def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
+        return mlir::succeeded(self.run(mod.getOperation()));
       })
       .def("add_inliner_pass", [](mlir::PassManager &self) {
         self.addPass(mlir::createInlinerPass());
       })
+      .def("add_canonicalizer_pass", [](mlir::PassManager &self) {
+        self.addPass(mlir::createCanonicalizerPass());
+      })
   ;
 }
diff --git a/rewrite-test/jit/matmul/matmul.mlir b/rewrite-test/jit/matmul/matmul.mlir
index 87a81abb7..8ce9154bf 100644
--- a/rewrite-test/jit/matmul/matmul.mlir
+++ b/rewrite-test/jit/matmul/matmul.mlir
@@ -1,127 +1,251 @@
 module {
-  func @matmul_kernel(%arg0: !triton.ptr<f16>, %arg1: !triton.ptr<f16>, %arg2: !triton.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
-    %0 = triton.get_program_id {axis = 0 : i32} : i32
-    %c64_i32 = arith.constant 64 : i32
-    %1 = arith.addi %arg3, %c64_i32 : i32
-    %c1_i32 = arith.constant 1 : i32
-    %2 = arith.subi %1, %c1_i32 : i32
-    %c64_i32_0 = arith.constant 64 : i32
-    %3 = arith.divsi %2, %c64_i32_0 : i32
-    %c64_i32_1 = arith.constant 64 : i32
-    %4 = arith.addi %arg4, %c64_i32_1 : i32
-    %c1_i32_2 = arith.constant 1 : i32
-    %5 = arith.subi %4, %c1_i32_2 : i32
-    %c64_i32_3 = arith.constant 64 : i32
-    %6 = arith.divsi %5, %c64_i32_3 : i32
+  func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    %1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
+    %2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
     %c8_i32 = arith.constant 8 : i32
-    %7 = arith.muli %6, %c8_i32 : i32
-    %8 = arith.divsi %0, %7 : i32
-    %c8_i32_4 = arith.constant 8 : i32
-    %9 = arith.muli %8, %c8_i32_4 : i32
-    %10 = arith.subi %3, %9 : i32
-    %c8_i32_5 = arith.constant 8 : i32
-    %11 = arith.cmpi slt, %10, %c8_i32_5 : i32
-    %c8_i32_6 = arith.constant 8 : i32
-    %12 = select %11, %10, %c8_i32_6 : i32
-    %13 = arith.remsi %0, %12 : i32
-    %14 = arith.addi %9, %13 : i32
-    %15 = arith.remsi %0, %7 : i32
-    %16 = arith.divsi %15, %12 : i32
-    %c64_i32_7 = arith.constant 64 : i32
-    %17 = arith.muli %14, %c64_i32_7 : i32
-    %18 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
-    %19 = triton.broadcast %17 : (i32) -> tensor<64xi32>
-    %20 = arith.addi %19, %18 : tensor<64xi32>
-    %c64_i32_8 = arith.constant 64 : i32
-    %21 = arith.muli %16, %c64_i32_8 : i32
-    %22 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
-    %23 = triton.broadcast %21 : (i32) -> tensor<64xi32>
-    %24 = arith.addi %23, %22 : tensor<64xi32>
-    %25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
-    %26 = triton.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32>
-    %27 = triton.broadcast %arg6 : (i32) -> tensor<64x1xi32>
-    %28 = arith.muli %26, %27 : tensor<64x1xi32>
-    %29 = triton.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32>
-    %c1_i32_9 = arith.constant 1 : i32
-    %30 = triton.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32>
-    %31 = arith.muli %29, %30 : tensor<1x32xi32>
-    %32 = triton.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32>
-    %33 = triton.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32>
-    %34 = arith.addi %32, %33 : tensor<64x32xi32>
-    %35 = triton.broadcast %arg0 : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
-    %36 = triton.getelementptr %35, %34, : tensor<64x32x!triton.ptr<f16>>
-    %37 = triton.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32>
-    %38 = triton.broadcast %arg7 : (i32) -> tensor<32x1xi32>
-    %39 = arith.muli %37, %38 : tensor<32x1xi32>
-    %40 = triton.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32>
-    %c1_i32_10 = arith.constant 1 : i32
-    %41 = triton.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32>
-    %42 = arith.muli %40, %41 : tensor<1x64xi32>
-    %43 = triton.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
-    %44 = triton.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32>
-    %45 = arith.addi %43, %44 : tensor<32x64xi32>
-    %46 = triton.broadcast %arg1 : (!triton.ptr<f16>) -> tensor<32x64x!triton.ptr<f16>>
-    %47 = triton.getelementptr %46, %45, : tensor<32x64x!triton.ptr<f16>>
+    %3 = arith.muli %2, %c8_i32 : i32
+    %4 = arith.divsi %0, %3 : i32
+    %c8_i32_0 = arith.constant 8 : i32
+    %5 = arith.muli %4, %c8_i32_0 : i32
+    %6 = arith.subi %1, %5 : i32
+    %7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
+    %8 = arith.remsi %0, %7 : i32
+    %9 = arith.addi %5, %8 : i32
+    %10 = arith.remsi %0, %3 : i32
+    %11 = arith.divsi %10, %7 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %12 = arith.muli %9, %c64_i32 : i32
+    %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
+    %15 = arith.addi %14, %13 : tensor<64xi32>
+    %c64_i32_1 = arith.constant 64 : i32
+    %16 = arith.muli %11, %c64_i32_1 : i32
+    %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
+    %19 = arith.addi %18, %17 : tensor<64xi32>
+    %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
+    %23 = arith.muli %21, %22 : tensor<64x1xi32>
+    %24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
+    %c1_i32 = arith.constant 1 : i32
+    %25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
+    %26 = arith.muli %24, %25 : tensor<1x32xi32>
+    %27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
+    %28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
+    %29 = arith.addi %27, %28 : tensor<64x32xi32>
+    %30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
+    %31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
+    %32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
+    %33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
+    %34 = arith.muli %32, %33 : tensor<32x1xi32>
+    %35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %c1_i32_2 = arith.constant 1 : i32
+    %36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
+    %37 = arith.muli %35, %36 : tensor<1x64xi32>
+    %38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
+    %39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
+    %40 = arith.addi %38, %39 : tensor<32x64xi32>
+    %41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
+    %42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
     %cst = arith.constant 0.000000e+00 : f32
-    %48 = triton.broadcast %cst : (f32) -> tensor<64x64xf32>
+    %43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
-    %49 = arith.index_cast %c0_i32 : i32 to index
-    %50 = arith.index_cast %arg5 : i32 to index
-    %51 = arith.index_cast %c32_i32 : i32 to index
-    %52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>) {
-      %cst_14 = arith.constant dense<true> : tensor<64x32xi1>
-      %cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
-      %82 = triton.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
-      %cst_16 = arith.constant dense<true> : tensor<32x64xi1>
-      %cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
-      %83 = triton.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
-      %cst_18 = arith.constant 0.000000e+00 : f32
-      %84 = triton.broadcast %cst_18 : (f32) -> tensor<64x64xf32>
-      %85 = triton.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
-      %86 = arith.addf %arg10, %85 : tensor<64x64xf32>
-      %c32_i32_19 = arith.constant 32 : i32
-      %87 = triton.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32>
-      %88 = triton.getelementptr %arg11, %87, : tensor<64x32x!triton.ptr<f16>>
-      %c32_i32_20 = arith.constant 32 : i32
-      %89 = arith.muli %arg7, %c32_i32_20 : i32
-      %90 = triton.broadcast %89 : (i32) -> tensor<32x64xi32>
-      %91 = triton.getelementptr %arg12, %90, : tensor<32x64x!triton.ptr<f16>>
-      scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>
+    %44 = arith.index_cast %c0_i32 : i32 to index
+    %45 = arith.index_cast %arg5 : i32 to index
+    %46 = arith.index_cast %c32_i32 : i32 to index
+    %47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
+      %cst_6 = arith.constant dense<true> : tensor<64x32xi1>
+      %cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
+      %77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
+      %cst_8 = arith.constant dense<true> : tensor<32x64xi1>
+      %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
+      %78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
+      %cst_10 = arith.constant 0.000000e+00 : f32
+      %79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
+      %80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+      %81 = arith.addf %arg10, %80 : tensor<64x64xf32>
+      %c32_i32_11 = arith.constant 32 : i32
+      %82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
+      %83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
+      %c32_i32_12 = arith.constant 32 : i32
+      %84 = arith.muli %arg7, %c32_i32_12 : i32
+      %85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
+      %86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
+      scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
     }
-    %53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16>
-    %c64_i32_11 = arith.constant 64 : i32
-    %54 = arith.muli %14, %c64_i32_11 : i32
-    %55 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
-    %56 = triton.broadcast %54 : (i32) -> tensor<64xi32>
-    %57 = arith.addi %56, %55 : tensor<64xi32>
-    %c64_i32_12 = arith.constant 64 : i32
-    %58 = arith.muli %16, %c64_i32_12 : i32
-    %59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
-    %60 = triton.broadcast %58 : (i32) -> tensor<64xi32>
-    %61 = arith.addi %60, %59 : tensor<64xi32>
-    %62 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
-    %63 = triton.broadcast %arg8 : (i32) -> tensor<64x1xi32>
-    %64 = arith.muli %63, %62 : tensor<64x1xi32>
-    %65 = triton.broadcast %arg2 : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
-    %66 = triton.getelementptr %65, %64, : tensor<64x1x!triton.ptr<f16>>
-    %67 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
-    %c1_i32_13 = arith.constant 1 : i32
-    %68 = triton.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32>
-    %69 = arith.muli %67, %68 : tensor<1x64xi32>
-    %70 = triton.broadcast %66 : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
-    %71 = triton.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32>
-    %72 = triton.getelementptr %70, %71, : tensor<64x64x!triton.ptr<f16>>
-    %73 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
-    %74 = triton.broadcast %arg3 : (i32) -> tensor<64x1xi32>
-    %75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
-    %76 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
-    %77 = triton.broadcast %arg4 : (i32) -> tensor<1x64xi32>
-    %78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
-    %79 = triton.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1>
-    %80 = triton.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1>
-    %81 = arith.andi %79, %80 : tensor<64x64xi1>
-    triton.store %72, %53, %81, : tensor<64x64xf16>
+    %48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
+    %c64_i32_3 = arith.constant 64 : i32
+    %49 = arith.muli %9, %c64_i32_3 : i32
+    %50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
+    %52 = arith.addi %51, %50 : tensor<64xi32>
+    %c64_i32_4 = arith.constant 64 : i32
+    %53 = arith.muli %11, %c64_i32_4 : i32
+    %54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
+    %56 = arith.addi %55, %54 : tensor<64xi32>
+    %57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
+    %59 = arith.muli %58, %57 : tensor<64x1xi32>
+    %60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
+    %61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
+    %62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %c1_i32_5 = arith.constant 1 : i32
+    %63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
+    %64 = arith.muli %62, %63 : tensor<1x64xi32>
+    %65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
+    %66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
+    %67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
+    %68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
+    %70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
+    %71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
+    %73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
+    %74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
+    %75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
+    %76 = arith.andi %74, %75 : tensor<64x64xi1>
+    tt.store %67, %48, %76, : tensor<64x64xf16>
     return
   }
+  func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
+    %c64_i32 = arith.constant 64 : i32
+    %0 = arith.addi %arg0, %c64_i32 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %1 = arith.subi %0, %c1_i32 : i32
+    %c64_i32_0 = arith.constant 64 : i32
+    %2 = arith.divsi %1, %c64_i32_0 : i32
+    return %2 : i32
+  }
+  func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
+    %c8_i32 = arith.constant 8 : i32
+    %0 = arith.cmpi slt, %arg0, %c8_i32 : i32
+    %c8_i32_0 = arith.constant 8 : i32
+    %1 = select %0, %arg0, %c8_i32_0 : i32
+    return %1 : i32
+  }
 }
+module {
+  func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
+    %c1_i32 = arith.constant 1 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
+    %cst_0 = arith.constant dense<true> : tensor<32x64xi1>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
+    %cst_2 = arith.constant dense<true> : tensor<64x32xi1>
+    %c32 = arith.constant 32 : index
+    %c0 = arith.constant 0 : index
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %c64_i32 = arith.constant 64 : i32
+    %c63_i32 = arith.constant 63 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %0 = tt.get_program_id {axis = 0 : i32} : i32
+    %1 = arith.addi %arg3, %c63_i32 : i32
+    %2 = arith.divsi %1, %c64_i32 : i32
+    %3 = arith.addi %arg4, %c63_i32 : i32
+    %4 = arith.divsi %3, %c64_i32 : i32
+    %5 = arith.muli %4, %c8_i32 : i32
+    %6 = arith.divsi %0, %5 : i32
+    %7 = arith.muli %6, %c8_i32 : i32
+    %8 = arith.subi %2, %7 : i32
+    %9 = arith.cmpi slt, %8, %c8_i32 : i32
+    %10 = select %9, %8, %c8_i32 : i32
+    %11 = arith.remsi %0, %10 : i32
+    %12 = arith.addi %7, %11 : i32
+    %13 = arith.remsi %0, %5 : i32
+    %14 = arith.divsi %13, %10 : i32
+    %15 = arith.muli %12, %c64_i32 : i32
+    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
+    %18 = arith.addi %17, %16 : tensor<64xi32>
+    %19 = arith.muli %14, %c64_i32 : i32
+    %20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
+    %22 = arith.addi %21, %20 : tensor<64xi32>
+    %23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
+    %26 = arith.muli %24, %25 : tensor<64x1xi32>
+    %27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
+    %28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
+    %29 = arith.muli %27, %28 : tensor<1x32xi32>
+    %30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
+    %31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
+    %32 = arith.addi %30, %31 : tensor<64x32xi32>
+    %33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
+    %34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
+    %35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
+    %36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
+    %37 = arith.muli %35, %36 : tensor<32x1xi32>
+    %38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
+    %40 = arith.muli %38, %39 : tensor<1x64xi32>
+    %41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
+    %42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
+    %43 = arith.addi %41, %42 : tensor<32x64xi32>
+    %44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
+    %45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
+    %46 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
+    %47 = arith.index_cast %arg5 : i32 to index
+    %48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
+      %78 = tt.load %arg11, %cst_2, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
+      %79 = tt.load %arg12, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
+      %80 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
+      %81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
+      %82 = arith.addf %arg10, %81 : tensor<64x64xf32>
+      %83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
+      %84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
+      %85 = arith.muli %arg7, %c32_i32 : i32
+      %86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
+      %87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
+      scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
+    }
+    %49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
+    %50 = arith.muli %12, %c64_i32 : i32
+    %51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
+    %53 = arith.addi %52, %51 : tensor<64xi32>
+    %54 = arith.muli %14, %c64_i32 : i32
+    %55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
+    %56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
+    %57 = arith.addi %56, %55 : tensor<64xi32>
+    %58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
+    %60 = arith.muli %59, %58 : tensor<64x1xi32>
+    %61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
+    %62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
+    %63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
+    %65 = arith.muli %63, %64 : tensor<1x64xi32>
+    %66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
+    %67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
+    %68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
+    %69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
+    %70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
+    %71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
+    %72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
+    %73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
+    %74 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
+    %75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
+    %76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
+    %77 = arith.andi %75, %76 : tensor<64x64xi1>
+    tt.store %68, %49, %77, : tensor<64x64xf16>
+    return
+  }
+  func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
+    %c63_i32 = arith.constant 63 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = arith.addi %arg0, %c63_i32 : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    return %1 : i32
+  }
+  func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
+    %c8_i32 = arith.constant 8 : i32
+    %0 = arith.cmpi slt, %arg0, %c8_i32 : i32
+    %1 = select %0, %arg0, %c8_i32 : i32
+    return %1 : i32
+  }
+}
diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py
index 8f0700aab..69897a8d1 100644
--- a/rewrite-test/jit/matmul/matmul.py
+++ b/rewrite-test/jit/matmul/matmul.py
@@ -1,5 +1,7 @@
 import triton
 import triton.language as tl
+import triton._C.libtriton.triton as _triton
+
 import torch
 
 
@@ -91,5 +93,14 @@ mod, ctx = matmul_kernel.compile_to_ttir(
     64, 64, 32, 8, grid=(2,)
 )
+
+assert mod.verify()
+mod.dump()
+
+pm = _triton.ir.pass_manager(ctx)
+pm.add_inliner_pass()
+pm.add_canonicalizer_pass()
+pm.run(mod)
+
+assert mod.verify()
 mod.dump()
-mod.verify()
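CombineDotOp above is registered but its matchAndRewrite is an empty stub still rooted on FuncOp. For reference, a sketch of the dot/add fusion it is stubbed out for, rooted on arith::AddFOp instead; the generic Operation operand accesses, the generated DotOp builder signature, and the pattern name below are assumptions for illustration, not code from this commit:

    // Rewrites  %d = tt.dot %a, %b, %zero ; %r = arith.addf %d, %c
    // into      %r = tt.dot %a, %b, %c    when %zero is a splat-0 constant.
    struct CombineDotAddPattern : public mlir::OpRewritePattern<mlir::arith::AddFOp> {
      using OpRewritePattern::OpRewritePattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::arith::AddFOp addOp,
                      mlir::PatternRewriter &rewriter) const override {
        auto dotOp = addOp.getOperand(0).getDefiningOp<mlir::triton::DotOp>();
        if (!dotOp || !dotOp->hasOneUse())
          return mlir::failure();
        // Only fire when the existing accumulator is a splat zero constant.
        auto acc = dotOp->getOperand(2).getDefiningOp<mlir::arith::ConstantOp>();
        if (!acc)
          return mlir::failure();
        auto splat = acc.getValue().dyn_cast<mlir::SplatElementsAttr>();
        if (!splat || !splat.getSplatValue<mlir::APFloat>().isZero())
          return mlir::failure();
        // Reuse the dot inputs, substituting the addf operand as accumulator.
        rewriter.replaceOpWithNewOp<mlir::triton::DotOp>(
            addOp, addOp.getType(), dotOp->getOperand(0), dotOp->getOperand(1),
            addOp.getOperand(1),
            dotOp->getAttrOfType<mlir::BoolAttr>("allowTF32"));
        return mlir::success();
      }
    };

Rooting the pattern on the arith.addf keeps the match local to the accumulator chain that the canonicalized loop body above ends with (%81 = tt.dot ..., %82 = arith.addf %arg10, %81), which is exactly the shape this pass is meant to collapse into a single accumulating tt.dot.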