update mma encoding & triton-opt
This commit is contained in:
@@ -4,17 +4,15 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
|||||||
add_llvm_executable(triton-opt triton-opt.cpp)
|
add_llvm_executable(triton-opt triton-opt.cpp)
|
||||||
|
|
||||||
# TODO: what's this?
|
# TODO: what's this?
|
||||||
llvm_update_compile_flags(triton-opt)
|
# llvm_update_compile_flags(triton-opt)
|
||||||
target_link_libraries(triton-opt PRIVATE
|
target_link_libraries(triton-opt PRIVATE
|
||||||
|
TritonTransforms
|
||||||
|
TritonGPUTransforms
|
||||||
${dialect_libs}
|
${dialect_libs}
|
||||||
${conversion_libs}
|
${conversion_libs}
|
||||||
MLIROptLib
|
MLIROptLib
|
||||||
|
MLIRPass
|
||||||
TritonIR
|
MLIRTransforms
|
||||||
TritonTransforms
|
|
||||||
|
|
||||||
TritonGPUIR
|
|
||||||
TritonGPUTransforms
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mlir_check_all_link_libraries(triton-opt)
|
mlir_check_all_link_libraries(triton-opt)
|
||||||
|
@@ -1,6 +1,9 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
|
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/InitAllPasses.h"
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Support/MlirOptMain.h"
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
@@ -8,7 +11,10 @@
|
|||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
|
mlir::registerTritonPasses();
|
||||||
|
mlir::registerTritonGPUPasses();
|
||||||
|
|
||||||
|
// TODO: register Triton & TritonGPU passes
|
||||||
mlir::DialectRegistry registry;
|
mlir::DialectRegistry registry;
|
||||||
registry.insert<mlir::triton::TritonDialect,
|
registry.insert<mlir::triton::TritonDialect,
|
||||||
mlir::triton::gpu::TritonGPUDialect,
|
mlir::triton::gpu::TritonGPUDialect,
|
||||||
|
@@ -126,7 +126,7 @@ def TT_GEPOp : TT_Op<"getelementptr",
|
|||||||
|
|
||||||
let results = (outs TT_PtrTensor:$result);
|
let results = (outs TT_PtrTensor:$result);
|
||||||
|
|
||||||
let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -9,6 +9,10 @@ namespace triton {
|
|||||||
std::unique_ptr<Pass> createCombineOpsPass();
|
std::unique_ptr<Pass> createCombineOpsPass();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -14,7 +14,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
|||||||
load(ptrs, broadcast(cond), other)
|
load(ptrs, broadcast(cond), other)
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let constructor = "mlir::triton::createCombineOpsPass";
|
let constructor = "mlir::triton::createCombineOpsPass()";
|
||||||
|
|
||||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||||
|
@@ -130,7 +130,7 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
|||||||
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
|
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
|
||||||
ArrayRefParameter<"unsigned">:$shapePerTile,
|
ArrayRefParameter<"unsigned">:$shapePerTile,
|
||||||
// TODO: should Distributed layout also
|
// TODO: should Distributed layout also
|
||||||
ArrayRefParameter<"unsigned">:$reptitions,
|
ArrayRefParameter<"unsigned">:$repetitions,
|
||||||
ArrayRefParameter<"unsigned">:$contigPerThread
|
ArrayRefParameter<"unsigned">:$contigPerThread
|
||||||
// "AffineMap":$warpOrdering,
|
// "AffineMap":$warpOrdering,
|
||||||
// "AffineMap":$blockOrdering
|
// "AffineMap":$blockOrdering
|
||||||
|
@@ -26,6 +26,8 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
|||||||
let arguments = (ins TT_Tensor:$src);
|
let arguments = (ins TT_Tensor:$src);
|
||||||
|
|
||||||
let results = (outs TT_Tensor:$result);
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||||
|
@@ -4,19 +4,15 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages);
|
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||||
|
|
||||||
namespace triton {
|
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
||||||
namespace gpu {
|
|
||||||
std::unique_ptr<Pass> createCombineOpsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// /// Generate the code for registering passes.
|
/// Generate the code for registering passes.
|
||||||
// #define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
// #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
#endif
|
#endif
|
||||||
|
@@ -19,7 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
|||||||
...
|
...
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let constructor = "mlir::triton::gpu::createPipelinePass";
|
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||||
|
|
||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
@@ -45,7 +45,7 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
|||||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let constructor = "mlir::triton::gpu::createCombineOpsPass";
|
let constructor = "mlir::createTritonGPUCombineOpsPass()";
|
||||||
|
|
||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||||
"mlir::triton::TritonDialect"];
|
"mlir::triton::TritonDialect"];
|
||||||
@@ -56,7 +56,7 @@ def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
|
|||||||
|
|
||||||
let description = [{}];
|
let description = [{}];
|
||||||
|
|
||||||
let constructor = "mlir::triton::gpu::createTritonGPUVerifier";
|
let constructor = "mlir::createTritonGPUVerifier()";
|
||||||
|
|
||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||||
}
|
}
|
||||||
|
@@ -112,11 +112,65 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
|||||||
|
|
||||||
Attribute
|
Attribute
|
||||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||||
llvm_unreachable("Not implemented");
|
if (parser.parseLess().failed())
|
||||||
|
return {};
|
||||||
|
DictionaryAttr dict;
|
||||||
|
if (parser.parseAttribute(dict).failed())
|
||||||
|
return {};
|
||||||
|
if (parser.parseGreater().failed())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
SmallVector<unsigned, 2> fragmentPerWarp;
|
||||||
|
SmallVector<unsigned, 2> shapePerWarp;
|
||||||
|
SmallVector<unsigned, 2> warpPerTile;
|
||||||
|
SmallVector<unsigned, 2> shapePerTile;
|
||||||
|
SmallVector<unsigned, 2> repetitions;
|
||||||
|
SmallVector<unsigned, 2> contigPerThread;
|
||||||
|
|
||||||
|
for (const NamedAttribute &attr : dict) {
|
||||||
|
if (attr.getName() == "fragmentPerWarp") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "shapePerWarp") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "warpPerTile") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "shapePerTile") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "repetitions") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "contigPerThread") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
|
||||||
|
return {};
|
||||||
|
} else {
|
||||||
|
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||||
|
<< attr.getName().strref();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||||
|
fragmentPerWarp,
|
||||||
|
shapePerWarp,
|
||||||
|
warpPerTile,
|
||||||
|
shapePerTile,
|
||||||
|
repetitions,
|
||||||
|
contigPerThread);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||||
llvm_unreachable("Not implemented");
|
printer << "<{"
|
||||||
|
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
|
||||||
|
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
|
||||||
|
<< ", warpPerTile = [" << getWarpPerTile() << "]"
|
||||||
|
<< ", shapePerTile = [" << getShapePerTile() << "]"
|
||||||
|
<< ", repetitions = [" << getRepetitions() << "]"
|
||||||
|
<< ", contigPerThread = [" << getContigPerThread() << "]"
|
||||||
|
<< "}>";
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute
|
Attribute
|
||||||
|
@@ -45,6 +45,6 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
|
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||||
}
|
}
|
||||||
|
@@ -100,6 +100,6 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
|
std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
|
||||||
return std::make_unique<TritonGPUVerifier>();
|
return std::make_unique<TritonGPUVerifier>();
|
||||||
}
|
}
|
||||||
|
@@ -1353,10 +1353,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
|
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
|
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::gpu::createCombineOpsPass());
|
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
|
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::triton::gpu::createTritonGPUVerifier());
|
self.addPass(mlir::createTritonGPUVerifier());
|
||||||
})
|
})
|
||||||
;
|
;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user