From 366dddc3bcaf20f6ce9d0ed2dce889f4cebb3c80 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Mon, 6 Jun 2022 21:03:58 +0800 Subject: [PATCH] update mma encoding & triton-opt --- bin/CMakeLists.txt | 12 ++-- bin/triton-opt.cpp | 6 ++ include/triton/Dialect/Triton/IR/TritonOps.td | 2 +- .../triton/Dialect/Triton/Transforms/Passes.h | 4 ++ .../Dialect/Triton/Transforms/Passes.td | 2 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 2 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 2 + .../Dialect/TritonGPU/Transforms/Passes.h | 14 ++--- .../Dialect/TritonGPU/Transforms/Passes.td | 6 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 58 ++++++++++++++++++- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 2 +- lib/Dialect/TritonGPU/Transforms/Verifier.cpp | 2 +- python/src/triton.cc | 4 +- 13 files changed, 88 insertions(+), 28 deletions(-) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 0d7fd269a..d1cb8c4c2 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -4,17 +4,15 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) add_llvm_executable(triton-opt triton-opt.cpp) # TODO: what's this? -llvm_update_compile_flags(triton-opt) +# llvm_update_compile_flags(triton-opt) target_link_libraries(triton-opt PRIVATE + TritonTransforms + TritonGPUTransforms ${dialect_libs} ${conversion_libs} MLIROptLib - - TritonIR - TritonTransforms - - TritonGPUIR - TritonGPUTransforms + MLIRPass + MLIRTransforms ) mlir_check_all_link_libraries(triton-opt) diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 5b8f76a1e..2a41739d7 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -1,6 +1,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + #include "mlir/IR/Dialect.h" #include "mlir/InitAllPasses.h" #include "mlir/Support/MlirOptMain.h" @@ -8,7 +11,10 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); + mlir::registerTritonPasses(); + mlir::registerTritonGPUPasses(); + // TODO: register Triton & TritonGPU passes mlir::DialectRegistry registry; registry.insert createCombineOpsPass(); } + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + } #endif diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index bd568cb3f..2515057fe 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -14,7 +14,7 @@ def TritonCombineOps : Pass load(ptrs, broadcast(cond), other) }]; - let constructor = "mlir::triton::createCombineOpsPass"; + let constructor = "mlir::triton::createCombineOpsPass()"; let dependentDialects = ["mlir::arith::ArithmeticDialect", /*SelectOp*/"mlir::StandardOpsDialect"]; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 383271e1a..38c33c5a4 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -130,7 +130,7 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { // TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout) ArrayRefParameter<"unsigned">:$shapePerTile, // TODO: should Distributed layout also - ArrayRefParameter<"unsigned">:$reptitions, + ArrayRefParameter<"unsigned">:$repetitions, ArrayRefParameter<"unsigned">:$contigPerThread // "AffineMap":$warpOrdering, // "AffineMap":$blockOrdering diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 62e2684a9..4a5309ae2 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -26,6 +26,8 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", let arguments = (ins TT_Tensor:$src); let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index faef0f1b4..3c79ab320 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -4,19 +4,15 @@ #include "mlir/Pass/Pass.h" namespace mlir { -std::unique_ptr createTritonGPUPipelinePass(int numStages); +std::unique_ptr createTritonGPUPipelinePass(int numStages = 2); -namespace triton { -namespace gpu { -std::unique_ptr createCombineOpsPass(); +std::unique_ptr createTritonGPUCombineOpsPass(); std::unique_ptr createTritonGPUVerifier(); -} -} -// /// Generate the code for registering passes. -// #define GEN_PASS_REGISTRATION -// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" } // namespace mlir #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index af117f328..74a3413f0 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -19,7 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { ... }]; - let constructor = "mlir::triton::gpu::createPipelinePass"; + let constructor = "mlir::createTritonGPUPipelinePass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::scf::SCFDialect", @@ -45,7 +45,7 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> { convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT }]; - let constructor = "mlir::triton::gpu::createCombineOpsPass"; + let constructor = "mlir::createTritonGPUCombineOpsPass()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; @@ -56,7 +56,7 @@ def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> { let description = [{}]; - let constructor = "mlir::triton::gpu::createTritonGPUVerifier"; + let constructor = "mlir::createTritonGPUVerifier()"; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index e1cf43c65..733e54e86 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -112,11 +112,65 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { - llvm_unreachable("Not implemented"); + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector fragmentPerWarp; + SmallVector shapePerWarp; + SmallVector warpPerTile; + SmallVector shapePerTile; + SmallVector repetitions; + SmallVector contigPerThread; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "fragmentPerWarp") { + if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed()) + return {}; + } else if (attr.getName() == "shapePerWarp") { + if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed()) + return {}; + } else if (attr.getName() == "warpPerTile") { + if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed()) + return {}; + } else if (attr.getName() == "shapePerTile") { + if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed()) + return {}; + } else if (attr.getName() == "repetitions") { + if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed()) + return {}; + } else if (attr.getName() == "contigPerThread") { + if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + return parser.getChecked(parser.getContext(), + fragmentPerWarp, + shapePerWarp, + warpPerTile, + shapePerTile, + repetitions, + contigPerThread); } void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { - llvm_unreachable("Not implemented"); + printer << "<{" + << "fragmentPerWarp = [" << getFragmentPerWarp() << "]" + << ", shapePerWarp = [" << getShapePerWarp() << "]" + << ", warpPerTile = [" << getWarpPerTile() << "]" + << ", shapePerTile = [" << getShapePerTile() << "]" + << ", repetitions = [" << getRepetitions() << "]" + << ", contigPerThread = [" << getContigPerThread() << "]" + << "}>"; } Attribute diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index ada418f5c..0052a3975 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -45,6 +45,6 @@ public: } }; -std::unique_ptr triton::gpu::createCombineOpsPass() { +std::unique_ptr mlir::createTritonGPUCombineOpsPass() { return std::make_unique(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index 4a6c4b645..16e1d3ec6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -100,6 +100,6 @@ private: } }; -std::unique_ptr triton::gpu::createTritonGPUVerifier() { +std::unique_ptr mlir::createTritonGPUVerifier() { return std::make_unique(); } diff --git a/python/src/triton.cc b/python/src/triton.cc index b583be726..1029ea1fd 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1353,10 +1353,10 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::createTritonGPUPipelinePass(numStages)); }) .def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) { - self.addPass(mlir::triton::gpu::createCombineOpsPass()); + self.addPass(mlir::createTritonGPUCombineOpsPass()); }) .def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) { - self.addPass(mlir::triton::gpu::createTritonGPUVerifier()); + self.addPass(mlir::createTritonGPUVerifier()); }) ; }