Add Triton CombineOps

This commit is contained in:
Yan Da
2022-04-27 13:45:56 +08:00
parent 9e304cf79d
commit 74585fb970
9 changed files with 339 additions and 134 deletions

View File

@@ -131,6 +131,7 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr",
"getPointerTypeFromTensor($_self)">,
@@ -161,6 +162,7 @@ def TT_LoadOp : TT_Op<"load",
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeFromTensor($_self)">,
@@ -214,6 +216,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
@@ -259,6 +263,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
// let hasCanonicalizer = 1;
}
// reduction

View File

@@ -0,0 +1,12 @@
#ifndef TRITON_TRANSFORMS_PASSES_H_
#define TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
std::unique_ptr<Pass> createCombineOpsPass();
// // Registration
// #define GEN_PASS_REGISTRATION
// #include
#endif // TRITON_TRANSFORMS_PASSES_H_

View File

@@ -1,3 +1,4 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(ir)
# add_subdirectory(transforms)

View File

@@ -55,16 +55,12 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
dce.run(ir);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
coalesce.run(ir);
dce.run(ir);
align.run(ir);
@@ -72,14 +68,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
if (target->is_gpu())
cts.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
peephole.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
align.run(ir); axes.run(ir); layouts.run(ir);
swizzle.run(ir);
liveness.run(ir);
allocation.run(ir);

View File

@@ -95,5 +95,18 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
state.addTypes({resultType});
}
//-- DotOp --
//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,43 @@
#include "triton/transforms/Passes.h"
#include <memory>
using namespace mlir;
namespace {
// <patterns>
struct CombineDotOp : public RewritePattern {
explicit CombineDotOp(MLIRContext *context)
: RewritePattern(/*rootName*/FuncOp::getOperationName(), /*Benefit*/1, context);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
//
}
};
// </patterns>
/// Passes.td (?)
struct CombineOpsPass { // : public mlir::OperationPass<FuncOp>
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
RewritePatternSet patterns(context);
patterns.add<CombineDotOp>();
patterns.add<CombineSelectMaskedLoadOp>();
patterns.add<CombineGEPOp>();
patterns.add<CombineReduceOp>();
// TODO: populate xxx Patter(?)
// TODO: should be use applyPartialConversion(...) ?
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
};
};
} // anonymous namespace
std::unique_ptr<Pass> mlir::triton::createCombineOpsPass() {
return std::make_unique<CombineOpsPass>();
}

View File

@@ -1254,7 +1254,7 @@ void init_triton_ir(py::module &&m) {
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
return self.create<mlir::triton::BroadcastOp>(
return self.createOrFold<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
);
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
@@ -1323,12 +1323,15 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
self.run(mod.getOperation());
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createCanonicalizerPass());
})
;
}

View File

@@ -1,127 +1,251 @@
module {
func @matmul_kernel(%arg0: !triton.ptr<f16>, %arg1: !triton.ptr<f16>, %arg2: !triton.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = triton.get_program_id {axis = 0 : i32} : i32
%c64_i32 = arith.constant 64 : i32
%1 = arith.addi %arg3, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%2 = arith.subi %1, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%3 = arith.divsi %2, %c64_i32_0 : i32
%c64_i32_1 = arith.constant 64 : i32
%4 = arith.addi %arg4, %c64_i32_1 : i32
%c1_i32_2 = arith.constant 1 : i32
%5 = arith.subi %4, %c1_i32_2 : i32
%c64_i32_3 = arith.constant 64 : i32
%6 = arith.divsi %5, %c64_i32_3 : i32
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
%c8_i32 = arith.constant 8 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.divsi %0, %7 : i32
%c8_i32_4 = arith.constant 8 : i32
%9 = arith.muli %8, %c8_i32_4 : i32
%10 = arith.subi %3, %9 : i32
%c8_i32_5 = arith.constant 8 : i32
%11 = arith.cmpi slt, %10, %c8_i32_5 : i32
%c8_i32_6 = arith.constant 8 : i32
%12 = select %11, %10, %c8_i32_6 : i32
%13 = arith.remsi %0, %12 : i32
%14 = arith.addi %9, %13 : i32
%15 = arith.remsi %0, %7 : i32
%16 = arith.divsi %15, %12 : i32
%c64_i32_7 = arith.constant 64 : i32
%17 = arith.muli %14, %c64_i32_7 : i32
%18 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%19 = triton.broadcast %17 : (i32) -> tensor<64xi32>
%20 = arith.addi %19, %18 : tensor<64xi32>
%c64_i32_8 = arith.constant 64 : i32
%21 = arith.muli %16, %c64_i32_8 : i32
%22 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%23 = triton.broadcast %21 : (i32) -> tensor<64xi32>
%24 = arith.addi %23, %22 : tensor<64xi32>
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%26 = triton.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32>
%27 = triton.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%28 = arith.muli %26, %27 : tensor<64x1xi32>
%29 = triton.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32_9 = arith.constant 1 : i32
%30 = triton.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32>
%31 = arith.muli %29, %30 : tensor<1x32xi32>
%32 = triton.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%33 = triton.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%34 = arith.addi %32, %33 : tensor<64x32xi32>
%35 = triton.broadcast %arg0 : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
%36 = triton.getelementptr %35, %34, : tensor<64x32x!triton.ptr<f16>>
%37 = triton.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32>
%38 = triton.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%39 = arith.muli %37, %38 : tensor<32x1xi32>
%40 = triton.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_10 = arith.constant 1 : i32
%41 = triton.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32>
%42 = arith.muli %40, %41 : tensor<1x64xi32>
%43 = triton.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%44 = triton.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%45 = arith.addi %43, %44 : tensor<32x64xi32>
%46 = triton.broadcast %arg1 : (!triton.ptr<f16>) -> tensor<32x64x!triton.ptr<f16>>
%47 = triton.getelementptr %46, %45, : tensor<32x64x!triton.ptr<f16>>
%3 = arith.muli %2, %c8_i32 : i32
%4 = arith.divsi %0, %3 : i32
%c8_i32_0 = arith.constant 8 : i32
%5 = arith.muli %4, %c8_i32_0 : i32
%6 = arith.subi %1, %5 : i32
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
%8 = arith.remsi %0, %7 : i32
%9 = arith.addi %5, %8 : i32
%10 = arith.remsi %0, %3 : i32
%11 = arith.divsi %10, %7 : i32
%c64_i32 = arith.constant 64 : i32
%12 = arith.muli %9, %c64_i32 : i32
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
%15 = arith.addi %14, %13 : tensor<64xi32>
%c64_i32_1 = arith.constant 64 : i32
%16 = arith.muli %11, %c64_i32_1 : i32
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
%19 = arith.addi %18, %17 : tensor<64xi32>
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%23 = arith.muli %21, %22 : tensor<64x1xi32>
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
%c1_i32 = arith.constant 1 : i32
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%26 = arith.muli %24, %25 : tensor<1x32xi32>
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%29 = arith.addi %27, %28 : tensor<64x32xi32>
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%34 = arith.muli %32, %33 : tensor<32x1xi32>
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_2 = arith.constant 1 : i32
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
%37 = arith.muli %35, %36 : tensor<1x64xi32>
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%40 = arith.addi %38, %39 : tensor<32x64xi32>
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
%cst = arith.constant 0.000000e+00 : f32
%48 = triton.broadcast %cst : (f32) -> tensor<64x64xf32>
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
%49 = arith.index_cast %c0_i32 : i32 to index
%50 = arith.index_cast %arg5 : i32 to index
%51 = arith.index_cast %c32_i32 : i32 to index
%52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>) {
%cst_14 = arith.constant dense<true> : tensor<64x32xi1>
%cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%82 = triton.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_16 = arith.constant dense<true> : tensor<32x64xi1>
%cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%83 = triton.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_18 = arith.constant 0.000000e+00 : f32
%84 = triton.broadcast %cst_18 : (f32) -> tensor<64x64xf32>
%85 = triton.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%86 = arith.addf %arg10, %85 : tensor<64x64xf32>
%c32_i32_19 = arith.constant 32 : i32
%87 = triton.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32>
%88 = triton.getelementptr %arg11, %87, : tensor<64x32x!triton.ptr<f16>>
%c32_i32_20 = arith.constant 32 : i32
%89 = arith.muli %arg7, %c32_i32_20 : i32
%90 = triton.broadcast %89 : (i32) -> tensor<32x64xi32>
%91 = triton.getelementptr %arg12, %90, : tensor<32x64x!triton.ptr<f16>>
scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>
%44 = arith.index_cast %c0_i32 : i32 to index
%45 = arith.index_cast %arg5 : i32 to index
%46 = arith.index_cast %c32_i32 : i32 to index
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%cst_10 = arith.constant 0.000000e+00 : f32
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
%c32_i32_11 = arith.constant 32 : i32
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
%c32_i32_12 = arith.constant 32 : i32
%84 = arith.muli %arg7, %c32_i32_12 : i32
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_11 = arith.constant 64 : i32
%54 = arith.muli %14, %c64_i32_11 : i32
%55 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = triton.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%c64_i32_12 = arith.constant 64 : i32
%58 = arith.muli %16, %c64_i32_12 : i32
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%60 = triton.broadcast %58 : (i32) -> tensor<64xi32>
%61 = arith.addi %60, %59 : tensor<64xi32>
%62 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
%63 = triton.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%64 = arith.muli %63, %62 : tensor<64x1xi32>
%65 = triton.broadcast %arg2 : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
%66 = triton.getelementptr %65, %64, : tensor<64x1x!triton.ptr<f16>>
%67 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_13 = arith.constant 1 : i32
%68 = triton.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32>
%69 = arith.muli %67, %68 : tensor<1x64xi32>
%70 = triton.broadcast %66 : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
%71 = triton.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%72 = triton.getelementptr %70, %71, : tensor<64x64x!triton.ptr<f16>>
%73 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
%74 = triton.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
%76 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
%77 = triton.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
%79 = triton.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%80 = triton.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%81 = arith.andi %79, %80 : tensor<64x64xi1>
triton.store %72, %53, %81, : tensor<64x64xf16>
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
%c64_i32_3 = arith.constant 64 : i32
%49 = arith.muli %9, %c64_i32_3 : i32
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
%52 = arith.addi %51, %50 : tensor<64xi32>
%c64_i32_4 = arith.constant 64 : i32
%53 = arith.muli %11, %c64_i32_4 : i32
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
%56 = arith.addi %55, %54 : tensor<64xi32>
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%59 = arith.muli %58, %57 : tensor<64x1xi32>
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%c1_i32_5 = arith.constant 1 : i32
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
%64 = arith.muli %62, %63 : tensor<1x64xi32>
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%76 = arith.andi %74, %75 : tensor<64x64xi1>
tt.store %67, %48, %76, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c64_i32 = arith.constant 64 : i32
%0 = arith.addi %arg0, %c64_i32 : i32
%c1_i32 = arith.constant 1 : i32
%1 = arith.subi %0, %c1_i32 : i32
%c64_i32_0 = arith.constant 64 : i32
%2 = arith.divsi %1, %c64_i32_0 : i32
return %2 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%c8_i32_0 = arith.constant 8 : i32
%1 = select %0, %arg0, %c8_i32_0 : i32
return %1 : i32
}
}
module {
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
%c1_i32 = arith.constant 1 : i32
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
%cst_0 = arith.constant dense<true> : tensor<32x64xi1>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
%cst_2 = arith.constant dense<true> : tensor<64x32xi1>
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst_3 = arith.constant 0.000000e+00 : f32
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.addi %arg3, %c63_i32 : i32
%2 = arith.divsi %1, %c64_i32 : i32
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
%5 = arith.muli %4, %c8_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c64_i32 : i32
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
%18 = arith.addi %17, %16 : tensor<64xi32>
%19 = arith.muli %14, %c64_i32 : i32
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
%22 = arith.addi %21, %20 : tensor<64xi32>
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
%26 = arith.muli %24, %25 : tensor<64x1xi32>
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
%29 = arith.muli %27, %28 : tensor<1x32xi32>
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
%32 = arith.addi %30, %31 : tensor<64x32xi32>
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
%37 = arith.muli %35, %36 : tensor<32x1xi32>
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%40 = arith.muli %38, %39 : tensor<1x64xi32>
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%43 = arith.addi %41, %42 : tensor<32x64xi32>
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
%46 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
%47 = arith.index_cast %arg5 : i32 to index
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
%78 = tt.load %arg11, %cst_2, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
%79 = tt.load %arg12, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
%80 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
%85 = arith.muli %arg7, %c32_i32 : i32
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
%50 = arith.muli %12, %c64_i32 : i32
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
%53 = arith.addi %52, %51 : tensor<64xi32>
%54 = arith.muli %14, %c64_i32 : i32
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
%57 = arith.addi %56, %55 : tensor<64xi32>
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
%60 = arith.muli %59, %58 : tensor<64x1xi32>
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
%65 = arith.muli %63, %64 : tensor<1x64xi32>
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
%77 = arith.andi %75, %76 : tensor<64x64xi1>
tt.store %68, %49, %77, : tensor<64x64xf16>
return
}
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
%c63_i32 = arith.constant 63 : i32
%c64_i32 = arith.constant 64 : i32
%0 = arith.addi %arg0, %c63_i32 : i32
%1 = arith.divsi %0, %c64_i32 : i32
return %1 : i32
}
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
%c8_i32 = arith.constant 8 : i32
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
%1 = select %0, %arg0, %c8_i32 : i32
return %1 : i32
}
}

View File

@@ -1,5 +1,7 @@
import triton
import triton.language as tl
import triton._C.libtriton.triton as _triton
import torch
@@ -91,5 +93,14 @@ mod, ctx = matmul_kernel.compile_to_ttir(
64, 64, 32,
8, grid=(2,)
)
assert mod.verify()
mod.dump()
pm = _triton.ir.pass_manager(ctx)
pm.add_inliner_pass()
pm.add_canonicalizer_pass()
pm.run(mod)
assert mod.verify()
mod.dump()
mod.verify()