Add Triton CombineOps
This commit is contained in:
@@ -131,6 +131,7 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
|
|||||||
|
|
||||||
def TT_LoadOp : TT_Op<"load",
|
def TT_LoadOp : TT_Op<"load",
|
||||||
[SameOperandsAndResultShape,
|
[SameOperandsAndResultShape,
|
||||||
|
MemoryEffects<[MemRead]>,
|
||||||
TypesMatchWith<"infer ptr type from result type",
|
TypesMatchWith<"infer ptr type from result type",
|
||||||
"result", "ptr",
|
"result", "ptr",
|
||||||
"getPointerTypeFromTensor($_self)">,
|
"getPointerTypeFromTensor($_self)">,
|
||||||
@@ -161,6 +162,7 @@ def TT_LoadOp : TT_Op<"load",
|
|||||||
|
|
||||||
def TT_StoreOp : TT_Op<"store",
|
def TT_StoreOp : TT_Op<"store",
|
||||||
[SameOperandsShape,
|
[SameOperandsShape,
|
||||||
|
MemoryEffects<[MemWrite]>,
|
||||||
TypesMatchWith<"infer ptr type from value type",
|
TypesMatchWith<"infer ptr type from value type",
|
||||||
"value", "ptr",
|
"value", "ptr",
|
||||||
"getPointerTypeFromTensor($_self)">,
|
"getPointerTypeFromTensor($_self)">,
|
||||||
@@ -214,6 +216,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
|||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
||||||
@@ -259,6 +263,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
|||||||
let results = (outs TT_FpIntTensor:$d);
|
let results = (outs TT_FpIntTensor:$d);
|
||||||
|
|
||||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||||
|
|
||||||
|
// let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduction
|
// reduction
|
||||||
|
12
include/triton/transforms/Passes.h
Normal file
12
include/triton/transforms/Passes.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#ifndef TRITON_TRANSFORMS_PASSES_H_
|
||||||
|
#define TRITON_TRANSFORMS_PASSES_H_
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createCombineOpsPass();
|
||||||
|
|
||||||
|
// // Registration
|
||||||
|
// #define GEN_PASS_REGISTRATION
|
||||||
|
// #include
|
||||||
|
|
||||||
|
#endif // TRITON_TRANSFORMS_PASSES_H_
|
@@ -1,3 +1,4 @@
|
|||||||
# add_subdirectory(codegen)
|
# add_subdirectory(codegen)
|
||||||
add_subdirectory(driver)
|
add_subdirectory(driver)
|
||||||
add_subdirectory(ir)
|
add_subdirectory(ir)
|
||||||
|
# add_subdirectory(transforms)
|
||||||
|
@@ -55,16 +55,12 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
disassociate.run(ir);
|
disassociate.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||||
axes.run(ir);
|
|
||||||
layouts.run(ir);
|
|
||||||
peephole.run(ir);
|
peephole.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
if (target->is_gpu())
|
if (target->is_gpu())
|
||||||
cts.run(ir);
|
cts.run(ir);
|
||||||
align.run(ir);
|
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||||
axes.run(ir);
|
|
||||||
layouts.run(ir);
|
|
||||||
coalesce.run(ir);
|
coalesce.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir);
|
||||||
@@ -72,14 +68,10 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
|
|||||||
if (target->is_gpu())
|
if (target->is_gpu())
|
||||||
cts.run(ir);
|
cts.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||||
axes.run(ir);
|
|
||||||
layouts.run(ir);
|
|
||||||
peephole.run(ir);
|
peephole.run(ir);
|
||||||
dce.run(ir);
|
dce.run(ir);
|
||||||
align.run(ir);
|
align.run(ir); axes.run(ir); layouts.run(ir);
|
||||||
axes.run(ir);
|
|
||||||
layouts.run(ir);
|
|
||||||
swizzle.run(ir);
|
swizzle.run(ir);
|
||||||
liveness.run(ir);
|
liveness.run(ir);
|
||||||
allocation.run(ir);
|
allocation.run(ir);
|
||||||
|
@@ -95,5 +95,18 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
|
|||||||
state.addTypes({resultType});
|
state.addTypes({resultType});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//-- DotOp --
|
||||||
|
|
||||||
|
//-- BroadcastOp --
|
||||||
|
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
|
if (!constOperand)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
43
lib/transforms/CombineOps.cpp
Normal file
43
lib/transforms/CombineOps.cpp
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#include "triton/transforms/Passes.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// <patterns>
|
||||||
|
struct CombineDotOp : public RewritePattern {
|
||||||
|
explicit CombineDotOp(MLIRContext *context)
|
||||||
|
: RewritePattern(/*rootName*/FuncOp::getOperationName(), /*Benefit*/1, context);
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
//
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// </patterns>
|
||||||
|
|
||||||
|
/// Passes.td (?)
|
||||||
|
struct CombineOpsPass { // : public mlir::OperationPass<FuncOp>
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
|
patterns.add<CombineDotOp>();
|
||||||
|
patterns.add<CombineSelectMaskedLoadOp>();
|
||||||
|
patterns.add<CombineGEPOp>();
|
||||||
|
patterns.add<CombineReduceOp>();
|
||||||
|
|
||||||
|
// TODO: populate xxx Patter(?)
|
||||||
|
|
||||||
|
// TODO: should be use applyPartialConversion(...) ?
|
||||||
|
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> mlir::triton::createCombineOpsPass() {
|
||||||
|
return std::make_unique<CombineOpsPass>();
|
||||||
|
}
|
@@ -1254,7 +1254,7 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
|
if (auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>())
|
||||||
return self.create<mlir::triton::BroadcastOp>(
|
return self.createOrFold<mlir::triton::BroadcastOp>(
|
||||||
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
|
loc, mlir::RankedTensorType::get(shape, argType.getElementType()), arg
|
||||||
);
|
);
|
||||||
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
|
throw std::runtime_error("arg is not of RankedTensorType, use create_splat");
|
||||||
@@ -1323,12 +1323,15 @@ void init_triton_ir(py::module &&m) {
|
|||||||
|
|
||||||
py::class_<mlir::PassManager>(m, "pass_manager")
|
py::class_<mlir::PassManager>(m, "pass_manager")
|
||||||
.def(py::init<mlir::MLIRContext *>())
|
.def(py::init<mlir::MLIRContext *>())
|
||||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) {
|
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||||
self.run(mod.getOperation());
|
return mlir::succeeded(self.run(mod.getOperation()));
|
||||||
})
|
})
|
||||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createInlinerPass());
|
self.addPass(mlir::createInlinerPass());
|
||||||
})
|
})
|
||||||
|
.def("add_canonicalizer_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createCanonicalizerPass());
|
||||||
|
})
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,127 +1,251 @@
|
|||||||
module {
|
module {
|
||||||
func @matmul_kernel(%arg0: !triton.ptr<f16>, %arg1: !triton.ptr<f16>, %arg2: !triton.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||||
%0 = triton.get_program_id {axis = 0 : i32} : i32
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||||
%c64_i32 = arith.constant 64 : i32
|
%1 = call @"cdiv__i32__1cconstexpr[64]"(%arg3) : (i32) -> i32
|
||||||
%1 = arith.addi %arg3, %c64_i32 : i32
|
%2 = call @"cdiv__i32__1cconstexpr[64]"(%arg4) : (i32) -> i32
|
||||||
%c1_i32 = arith.constant 1 : i32
|
|
||||||
%2 = arith.subi %1, %c1_i32 : i32
|
|
||||||
%c64_i32_0 = arith.constant 64 : i32
|
|
||||||
%3 = arith.divsi %2, %c64_i32_0 : i32
|
|
||||||
%c64_i32_1 = arith.constant 64 : i32
|
|
||||||
%4 = arith.addi %arg4, %c64_i32_1 : i32
|
|
||||||
%c1_i32_2 = arith.constant 1 : i32
|
|
||||||
%5 = arith.subi %4, %c1_i32_2 : i32
|
|
||||||
%c64_i32_3 = arith.constant 64 : i32
|
|
||||||
%6 = arith.divsi %5, %c64_i32_3 : i32
|
|
||||||
%c8_i32 = arith.constant 8 : i32
|
%c8_i32 = arith.constant 8 : i32
|
||||||
%7 = arith.muli %6, %c8_i32 : i32
|
%3 = arith.muli %2, %c8_i32 : i32
|
||||||
%8 = arith.divsi %0, %7 : i32
|
%4 = arith.divsi %0, %3 : i32
|
||||||
%c8_i32_4 = arith.constant 8 : i32
|
%c8_i32_0 = arith.constant 8 : i32
|
||||||
%9 = arith.muli %8, %c8_i32_4 : i32
|
%5 = arith.muli %4, %c8_i32_0 : i32
|
||||||
%10 = arith.subi %3, %9 : i32
|
%6 = arith.subi %1, %5 : i32
|
||||||
%c8_i32_5 = arith.constant 8 : i32
|
%7 = call @"minimum__i32__1cconstexpr[8]"(%6) : (i32) -> i32
|
||||||
%11 = arith.cmpi slt, %10, %c8_i32_5 : i32
|
%8 = arith.remsi %0, %7 : i32
|
||||||
%c8_i32_6 = arith.constant 8 : i32
|
%9 = arith.addi %5, %8 : i32
|
||||||
%12 = select %11, %10, %c8_i32_6 : i32
|
%10 = arith.remsi %0, %3 : i32
|
||||||
%13 = arith.remsi %0, %12 : i32
|
%11 = arith.divsi %10, %7 : i32
|
||||||
%14 = arith.addi %9, %13 : i32
|
%c64_i32 = arith.constant 64 : i32
|
||||||
%15 = arith.remsi %0, %7 : i32
|
%12 = arith.muli %9, %c64_i32 : i32
|
||||||
%16 = arith.divsi %15, %12 : i32
|
%13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
%c64_i32_7 = arith.constant 64 : i32
|
%14 = tt.broadcast %12 : (i32) -> tensor<64xi32>
|
||||||
%17 = arith.muli %14, %c64_i32_7 : i32
|
%15 = arith.addi %14, %13 : tensor<64xi32>
|
||||||
%18 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
%c64_i32_1 = arith.constant 64 : i32
|
||||||
%19 = triton.broadcast %17 : (i32) -> tensor<64xi32>
|
%16 = arith.muli %11, %c64_i32_1 : i32
|
||||||
%20 = arith.addi %19, %18 : tensor<64xi32>
|
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
%c64_i32_8 = arith.constant 64 : i32
|
%18 = tt.broadcast %16 : (i32) -> tensor<64xi32>
|
||||||
%21 = arith.muli %16, %c64_i32_8 : i32
|
%19 = arith.addi %18, %17 : tensor<64xi32>
|
||||||
%22 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
%20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||||
%23 = triton.broadcast %21 : (i32) -> tensor<64xi32>
|
%21 = tt.reshape %15 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%24 = arith.addi %23, %22 : tensor<64xi32>
|
%22 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
||||||
%25 = triton.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
%23 = arith.muli %21, %22 : tensor<64x1xi32>
|
||||||
%26 = triton.reshape %20 : (tensor<64xi32>) -> tensor<64x1xi32>
|
%24 = tt.reshape %20 : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||||
%27 = triton.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
%c1_i32 = arith.constant 1 : i32
|
||||||
%28 = arith.muli %26, %27 : tensor<64x1xi32>
|
%25 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
|
||||||
%29 = triton.reshape %25 : (tensor<32xi32>) -> tensor<1x32xi32>
|
%26 = arith.muli %24, %25 : tensor<1x32xi32>
|
||||||
%c1_i32_9 = arith.constant 1 : i32
|
%27 = tt.broadcast %23 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
||||||
%30 = triton.broadcast %c1_i32_9 : (i32) -> tensor<1x32xi32>
|
%28 = tt.broadcast %26 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
||||||
%31 = arith.muli %29, %30 : tensor<1x32xi32>
|
%29 = arith.addi %27, %28 : tensor<64x32xi32>
|
||||||
%32 = triton.broadcast %28 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
%30 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
|
||||||
%33 = triton.broadcast %31 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
%31 = tt.getelementptr %30, %29, : tensor<64x32x!tt.ptr<f16>>
|
||||||
%34 = arith.addi %32, %33 : tensor<64x32xi32>
|
%32 = tt.reshape %20 : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||||
%35 = triton.broadcast %arg0 : (!triton.ptr<f16>) -> tensor<64x32x!triton.ptr<f16>>
|
%33 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
||||||
%36 = triton.getelementptr %35, %34, : tensor<64x32x!triton.ptr<f16>>
|
%34 = arith.muli %32, %33 : tensor<32x1xi32>
|
||||||
%37 = triton.reshape %25 : (tensor<32xi32>) -> tensor<32x1xi32>
|
%35 = tt.reshape %19 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%38 = triton.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
%c1_i32_2 = arith.constant 1 : i32
|
||||||
%39 = arith.muli %37, %38 : tensor<32x1xi32>
|
%36 = tt.broadcast %c1_i32_2 : (i32) -> tensor<1x64xi32>
|
||||||
%40 = triton.reshape %24 : (tensor<64xi32>) -> tensor<1x64xi32>
|
%37 = arith.muli %35, %36 : tensor<1x64xi32>
|
||||||
%c1_i32_10 = arith.constant 1 : i32
|
%38 = tt.broadcast %34 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
||||||
%41 = triton.broadcast %c1_i32_10 : (i32) -> tensor<1x64xi32>
|
%39 = tt.broadcast %37 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
||||||
%42 = arith.muli %40, %41 : tensor<1x64xi32>
|
%40 = arith.addi %38, %39 : tensor<32x64xi32>
|
||||||
%43 = triton.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
%41 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
|
||||||
%44 = triton.broadcast %42 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
%42 = tt.getelementptr %41, %40, : tensor<32x64x!tt.ptr<f16>>
|
||||||
%45 = arith.addi %43, %44 : tensor<32x64xi32>
|
|
||||||
%46 = triton.broadcast %arg1 : (!triton.ptr<f16>) -> tensor<32x64x!triton.ptr<f16>>
|
|
||||||
%47 = triton.getelementptr %46, %45, : tensor<32x64x!triton.ptr<f16>>
|
|
||||||
%cst = arith.constant 0.000000e+00 : f32
|
%cst = arith.constant 0.000000e+00 : f32
|
||||||
%48 = triton.broadcast %cst : (f32) -> tensor<64x64xf32>
|
%43 = tt.broadcast %cst : (f32) -> tensor<64x64xf32>
|
||||||
%c0_i32 = arith.constant 0 : i32
|
%c0_i32 = arith.constant 0 : i32
|
||||||
%c32_i32 = arith.constant 32 : i32
|
%c32_i32 = arith.constant 32 : i32
|
||||||
%49 = arith.index_cast %c0_i32 : i32 to index
|
%44 = arith.index_cast %c0_i32 : i32 to index
|
||||||
%50 = arith.index_cast %arg5 : i32 to index
|
%45 = arith.index_cast %arg5 : i32 to index
|
||||||
%51 = arith.index_cast %c32_i32 : i32 to index
|
%46 = arith.index_cast %c32_i32 : i32 to index
|
||||||
%52:3 = scf.for %arg9 = %49 to %50 step %51 iter_args(%arg10 = %48, %arg11 = %36, %arg12 = %47) -> (tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>) {
|
%47:3 = scf.for %arg9 = %44 to %45 step %46 iter_args(%arg10 = %43, %arg11 = %31, %arg12 = %42) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
|
||||||
%cst_14 = arith.constant dense<true> : tensor<64x32xi1>
|
%cst_6 = arith.constant dense<true> : tensor<64x32xi1>
|
||||||
%cst_15 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
||||||
%82 = triton.load %arg11, %cst_14, %cst_15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
%77 = tt.load %arg11, %cst_6, %cst_7 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
||||||
%cst_16 = arith.constant dense<true> : tensor<32x64xi1>
|
%cst_8 = arith.constant dense<true> : tensor<32x64xi1>
|
||||||
%cst_17 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
%cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
||||||
%83 = triton.load %arg12, %cst_16, %cst_17 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
%78 = tt.load %arg12, %cst_8, %cst_9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
||||||
%cst_18 = arith.constant 0.000000e+00 : f32
|
%cst_10 = arith.constant 0.000000e+00 : f32
|
||||||
%84 = triton.broadcast %cst_18 : (f32) -> tensor<64x64xf32>
|
%79 = tt.broadcast %cst_10 : (f32) -> tensor<64x64xf32>
|
||||||
%85 = triton.dot %82, %83, %84 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
%80 = tt.dot %77, %78, %79 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
||||||
%86 = arith.addf %arg10, %85 : tensor<64x64xf32>
|
%81 = arith.addf %arg10, %80 : tensor<64x64xf32>
|
||||||
%c32_i32_19 = arith.constant 32 : i32
|
%c32_i32_11 = arith.constant 32 : i32
|
||||||
%87 = triton.broadcast %c32_i32_19 : (i32) -> tensor<64x32xi32>
|
%82 = tt.broadcast %c32_i32_11 : (i32) -> tensor<64x32xi32>
|
||||||
%88 = triton.getelementptr %arg11, %87, : tensor<64x32x!triton.ptr<f16>>
|
%83 = tt.getelementptr %arg11, %82, : tensor<64x32x!tt.ptr<f16>>
|
||||||
%c32_i32_20 = arith.constant 32 : i32
|
%c32_i32_12 = arith.constant 32 : i32
|
||||||
%89 = arith.muli %arg7, %c32_i32_20 : i32
|
%84 = arith.muli %arg7, %c32_i32_12 : i32
|
||||||
%90 = triton.broadcast %89 : (i32) -> tensor<32x64xi32>
|
%85 = tt.broadcast %84 : (i32) -> tensor<32x64xi32>
|
||||||
%91 = triton.getelementptr %arg12, %90, : tensor<32x64x!triton.ptr<f16>>
|
%86 = tt.getelementptr %arg12, %85, : tensor<32x64x!tt.ptr<f16>>
|
||||||
scf.yield %86, %88, %91 : tensor<64x64xf32>, tensor<64x32x!triton.ptr<f16>>, tensor<32x64x!triton.ptr<f16>>
|
scf.yield %81, %83, %86 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
|
||||||
}
|
}
|
||||||
%53 = arith.truncf %52#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
%48 = arith.truncf %47#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
||||||
%c64_i32_11 = arith.constant 64 : i32
|
%c64_i32_3 = arith.constant 64 : i32
|
||||||
%54 = arith.muli %14, %c64_i32_11 : i32
|
%49 = arith.muli %9, %c64_i32_3 : i32
|
||||||
%55 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
%50 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
%56 = triton.broadcast %54 : (i32) -> tensor<64xi32>
|
%51 = tt.broadcast %49 : (i32) -> tensor<64xi32>
|
||||||
%57 = arith.addi %56, %55 : tensor<64xi32>
|
%52 = arith.addi %51, %50 : tensor<64xi32>
|
||||||
%c64_i32_12 = arith.constant 64 : i32
|
%c64_i32_4 = arith.constant 64 : i32
|
||||||
%58 = arith.muli %16, %c64_i32_12 : i32
|
%53 = arith.muli %11, %c64_i32_4 : i32
|
||||||
%59 = triton.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
%54 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
%60 = triton.broadcast %58 : (i32) -> tensor<64xi32>
|
%55 = tt.broadcast %53 : (i32) -> tensor<64xi32>
|
||||||
%61 = arith.addi %60, %59 : tensor<64xi32>
|
%56 = arith.addi %55, %54 : tensor<64xi32>
|
||||||
%62 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
|
%57 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%63 = triton.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
%58 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
||||||
%64 = arith.muli %63, %62 : tensor<64x1xi32>
|
%59 = arith.muli %58, %57 : tensor<64x1xi32>
|
||||||
%65 = triton.broadcast %arg2 : (!triton.ptr<f16>) -> tensor<64x1x!triton.ptr<f16>>
|
%60 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
|
||||||
%66 = triton.getelementptr %65, %64, : tensor<64x1x!triton.ptr<f16>>
|
%61 = tt.getelementptr %60, %59, : tensor<64x1x!tt.ptr<f16>>
|
||||||
%67 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
|
%62 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%c1_i32_13 = arith.constant 1 : i32
|
%c1_i32_5 = arith.constant 1 : i32
|
||||||
%68 = triton.broadcast %c1_i32_13 : (i32) -> tensor<1x64xi32>
|
%63 = tt.broadcast %c1_i32_5 : (i32) -> tensor<1x64xi32>
|
||||||
%69 = arith.muli %67, %68 : tensor<1x64xi32>
|
%64 = arith.muli %62, %63 : tensor<1x64xi32>
|
||||||
%70 = triton.broadcast %66 : (tensor<64x1x!triton.ptr<f16>>) -> tensor<64x64x!triton.ptr<f16>>
|
%65 = tt.broadcast %61 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
|
||||||
%71 = triton.broadcast %69 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
%66 = tt.broadcast %64 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||||
%72 = triton.getelementptr %70, %71, : tensor<64x64x!triton.ptr<f16>>
|
%67 = tt.getelementptr %65, %66, : tensor<64x64x!tt.ptr<f16>>
|
||||||
%73 = triton.reshape %57 : (tensor<64xi32>) -> tensor<64x1xi32>
|
%68 = tt.reshape %52 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
%74 = triton.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
%69 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
||||||
%75 = arith.cmpi slt, %73, %74 : tensor<64x1xi32>
|
%70 = arith.cmpi slt, %68, %69 : tensor<64x1xi32>
|
||||||
%76 = triton.reshape %61 : (tensor<64xi32>) -> tensor<1x64xi32>
|
%71 = tt.reshape %56 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
%77 = triton.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
%72 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
||||||
%78 = arith.cmpi slt, %76, %77 : tensor<1x64xi32>
|
%73 = arith.cmpi slt, %71, %72 : tensor<1x64xi32>
|
||||||
%79 = triton.broadcast %75 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
%74 = tt.broadcast %70 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||||
%80 = triton.broadcast %78 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
%75 = tt.broadcast %73 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||||
%81 = arith.andi %79, %80 : tensor<64x64xi1>
|
%76 = arith.andi %74, %75 : tensor<64x64xi1>
|
||||||
triton.store %72, %53, %81, : tensor<64x64xf16>
|
tt.store %67, %48, %76, : tensor<64x64xf16>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
|
||||||
|
%c64_i32 = arith.constant 64 : i32
|
||||||
|
%0 = arith.addi %arg0, %c64_i32 : i32
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
%1 = arith.subi %0, %c1_i32 : i32
|
||||||
|
%c64_i32_0 = arith.constant 64 : i32
|
||||||
|
%2 = arith.divsi %1, %c64_i32_0 : i32
|
||||||
|
return %2 : i32
|
||||||
|
}
|
||||||
|
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||||
|
%c8_i32 = arith.constant 8 : i32
|
||||||
|
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||||
|
%c8_i32_0 = arith.constant 8 : i32
|
||||||
|
%1 = select %0, %arg0, %c8_i32_0 : i32
|
||||||
|
return %1 : i32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
module {
|
||||||
|
func @matmul_kernel__Pfp16_Pfp16_Pfp16_i32_i32_i32_i32_i32_i32__7c1_9c1_11c1_12c64_13c64_14c32_15c8(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
|
||||||
|
%c1_i32 = arith.constant 1 : i32
|
||||||
|
%c32_i32 = arith.constant 32 : i32
|
||||||
|
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
|
||||||
|
%cst_0 = arith.constant dense<true> : tensor<32x64xi1>
|
||||||
|
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16>
|
||||||
|
%cst_2 = arith.constant dense<true> : tensor<64x32xi1>
|
||||||
|
%c32 = arith.constant 32 : index
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%cst_3 = arith.constant 0.000000e+00 : f32
|
||||||
|
%c64_i32 = arith.constant 64 : i32
|
||||||
|
%c63_i32 = arith.constant 63 : i32
|
||||||
|
%c8_i32 = arith.constant 8 : i32
|
||||||
|
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||||
|
%1 = arith.addi %arg3, %c63_i32 : i32
|
||||||
|
%2 = arith.divsi %1, %c64_i32 : i32
|
||||||
|
%3 = arith.addi %arg4, %c63_i32 : i32
|
||||||
|
%4 = arith.divsi %3, %c64_i32 : i32
|
||||||
|
%5 = arith.muli %4, %c8_i32 : i32
|
||||||
|
%6 = arith.divsi %0, %5 : i32
|
||||||
|
%7 = arith.muli %6, %c8_i32 : i32
|
||||||
|
%8 = arith.subi %2, %7 : i32
|
||||||
|
%9 = arith.cmpi slt, %8, %c8_i32 : i32
|
||||||
|
%10 = select %9, %8, %c8_i32 : i32
|
||||||
|
%11 = arith.remsi %0, %10 : i32
|
||||||
|
%12 = arith.addi %7, %11 : i32
|
||||||
|
%13 = arith.remsi %0, %5 : i32
|
||||||
|
%14 = arith.divsi %13, %10 : i32
|
||||||
|
%15 = arith.muli %12, %c64_i32 : i32
|
||||||
|
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
|
%17 = tt.broadcast %15 : (i32) -> tensor<64xi32>
|
||||||
|
%18 = arith.addi %17, %16 : tensor<64xi32>
|
||||||
|
%19 = arith.muli %14, %c64_i32 : i32
|
||||||
|
%20 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
|
%21 = tt.broadcast %19 : (i32) -> tensor<64xi32>
|
||||||
|
%22 = arith.addi %21, %20 : tensor<64xi32>
|
||||||
|
%23 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
|
||||||
|
%24 = tt.reshape %18 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
|
%25 = tt.broadcast %arg6 : (i32) -> tensor<64x1xi32>
|
||||||
|
%26 = arith.muli %24, %25 : tensor<64x1xi32>
|
||||||
|
%27 = tt.reshape %23 : (tensor<32xi32>) -> tensor<1x32xi32>
|
||||||
|
%28 = tt.broadcast %c1_i32 : (i32) -> tensor<1x32xi32>
|
||||||
|
%29 = arith.muli %27, %28 : tensor<1x32xi32>
|
||||||
|
%30 = tt.broadcast %26 : (tensor<64x1xi32>) -> tensor<64x32xi32>
|
||||||
|
%31 = tt.broadcast %29 : (tensor<1x32xi32>) -> tensor<64x32xi32>
|
||||||
|
%32 = arith.addi %30, %31 : tensor<64x32xi32>
|
||||||
|
%33 = tt.broadcast %arg0 : (!tt.ptr<f16>) -> tensor<64x32x!tt.ptr<f16>>
|
||||||
|
%34 = tt.getelementptr %33, %32, : tensor<64x32x!tt.ptr<f16>>
|
||||||
|
%35 = tt.reshape %23 : (tensor<32xi32>) -> tensor<32x1xi32>
|
||||||
|
%36 = tt.broadcast %arg7 : (i32) -> tensor<32x1xi32>
|
||||||
|
%37 = arith.muli %35, %36 : tensor<32x1xi32>
|
||||||
|
%38 = tt.reshape %22 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
|
%39 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
|
||||||
|
%40 = arith.muli %38, %39 : tensor<1x64xi32>
|
||||||
|
%41 = tt.broadcast %37 : (tensor<32x1xi32>) -> tensor<32x64xi32>
|
||||||
|
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
|
||||||
|
%43 = arith.addi %41, %42 : tensor<32x64xi32>
|
||||||
|
%44 = tt.broadcast %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
|
||||||
|
%45 = tt.getelementptr %44, %43, : tensor<32x64x!tt.ptr<f16>>
|
||||||
|
%46 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
|
||||||
|
%47 = arith.index_cast %arg5 : i32 to index
|
||||||
|
%48:3 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %46, %arg11 = %34, %arg12 = %45) -> (tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) {
|
||||||
|
%78 = tt.load %arg11, %cst_2, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
|
||||||
|
%79 = tt.load %arg12, %cst_0, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
|
||||||
|
%80 = tt.broadcast %cst_3 : (f32) -> tensor<64x64xf32>
|
||||||
|
%81 = tt.dot %78, %79, %80 {allowTF32 = true} : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
|
||||||
|
%82 = arith.addf %arg10, %81 : tensor<64x64xf32>
|
||||||
|
%83 = tt.broadcast %c32_i32 : (i32) -> tensor<64x32xi32>
|
||||||
|
%84 = tt.getelementptr %arg11, %83, : tensor<64x32x!tt.ptr<f16>>
|
||||||
|
%85 = arith.muli %arg7, %c32_i32 : i32
|
||||||
|
%86 = tt.broadcast %85 : (i32) -> tensor<32x64xi32>
|
||||||
|
%87 = tt.getelementptr %arg12, %86, : tensor<32x64x!tt.ptr<f16>>
|
||||||
|
scf.yield %82, %84, %87 : tensor<64x64xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
|
||||||
|
}
|
||||||
|
%49 = arith.truncf %48#0 : tensor<64x64xf32> to tensor<64x64xf16>
|
||||||
|
%50 = arith.muli %12, %c64_i32 : i32
|
||||||
|
%51 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
|
%52 = tt.broadcast %50 : (i32) -> tensor<64xi32>
|
||||||
|
%53 = arith.addi %52, %51 : tensor<64xi32>
|
||||||
|
%54 = arith.muli %14, %c64_i32 : i32
|
||||||
|
%55 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
|
||||||
|
%56 = tt.broadcast %54 : (i32) -> tensor<64xi32>
|
||||||
|
%57 = arith.addi %56, %55 : tensor<64xi32>
|
||||||
|
%58 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
|
%59 = tt.broadcast %arg8 : (i32) -> tensor<64x1xi32>
|
||||||
|
%60 = arith.muli %59, %58 : tensor<64x1xi32>
|
||||||
|
%61 = tt.broadcast %arg2 : (!tt.ptr<f16>) -> tensor<64x1x!tt.ptr<f16>>
|
||||||
|
%62 = tt.getelementptr %61, %60, : tensor<64x1x!tt.ptr<f16>>
|
||||||
|
%63 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
|
%64 = tt.broadcast %c1_i32 : (i32) -> tensor<1x64xi32>
|
||||||
|
%65 = arith.muli %63, %64 : tensor<1x64xi32>
|
||||||
|
%66 = tt.broadcast %62 : (tensor<64x1x!tt.ptr<f16>>) -> tensor<64x64x!tt.ptr<f16>>
|
||||||
|
%67 = tt.broadcast %65 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||||
|
%68 = tt.getelementptr %66, %67, : tensor<64x64x!tt.ptr<f16>>
|
||||||
|
%69 = tt.reshape %53 : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||||
|
%70 = tt.broadcast %arg3 : (i32) -> tensor<64x1xi32>
|
||||||
|
%71 = arith.cmpi slt, %69, %70 : tensor<64x1xi32>
|
||||||
|
%72 = tt.reshape %57 : (tensor<64xi32>) -> tensor<1x64xi32>
|
||||||
|
%73 = tt.broadcast %arg4 : (i32) -> tensor<1x64xi32>
|
||||||
|
%74 = arith.cmpi slt, %72, %73 : tensor<1x64xi32>
|
||||||
|
%75 = tt.broadcast %71 : (tensor<64x1xi1>) -> tensor<64x64xi1>
|
||||||
|
%76 = tt.broadcast %74 : (tensor<1x64xi1>) -> tensor<64x64xi1>
|
||||||
|
%77 = arith.andi %75, %76 : tensor<64x64xi1>
|
||||||
|
tt.store %68, %49, %77, : tensor<64x64xf16>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func @"cdiv__i32__1cconstexpr[64]"(%arg0: i32) -> i32 {
|
||||||
|
%c63_i32 = arith.constant 63 : i32
|
||||||
|
%c64_i32 = arith.constant 64 : i32
|
||||||
|
%0 = arith.addi %arg0, %c63_i32 : i32
|
||||||
|
%1 = arith.divsi %0, %c64_i32 : i32
|
||||||
|
return %1 : i32
|
||||||
|
}
|
||||||
|
func @"minimum__i32__1cconstexpr[8]"(%arg0: i32) -> i32 {
|
||||||
|
%c8_i32 = arith.constant 8 : i32
|
||||||
|
%0 = arith.cmpi slt, %arg0, %c8_i32 : i32
|
||||||
|
%1 = select %0, %arg0, %c8_i32 : i32
|
||||||
|
return %1 : i32
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
import triton._C.libtriton.triton as _triton
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -91,5 +93,14 @@ mod, ctx = matmul_kernel.compile_to_ttir(
|
|||||||
64, 64, 32,
|
64, 64, 32,
|
||||||
8, grid=(2,)
|
8, grid=(2,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert mod.verify()
|
||||||
|
mod.dump()
|
||||||
|
|
||||||
|
pm = _triton.ir.pass_manager(ctx)
|
||||||
|
pm.add_inliner_pass()
|
||||||
|
pm.add_canonicalizer_pass()
|
||||||
|
pm.run(mod)
|
||||||
|
|
||||||
|
assert mod.verify()
|
||||||
mod.dump()
|
mod.dump()
|
||||||
mod.verify()
|
|
||||||
|
Reference in New Issue
Block a user