diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
index bc2adfd94..60e644cbb 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -8,6 +8,8 @@ def TritonGPU_Dialect : Dialect {
 
   let cppNamespace = "::mlir::triton::gpu";
 
+  let hasOperationAttrVerify = 1;
+
   let description = [{
     Triton GPU Dialect.
   }];
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
index e82a3fd67..faef0f1b4 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -9,6 +9,8 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages);
 namespace triton {
 namespace gpu {
 std::unique_ptr<Pass> createCombineOpsPass();
+
+std::unique_ptr<Pass> createTritonGPUVerifier();
 }
 }
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index e038a4d89..af117f328 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -51,4 +51,14 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
+  let summary = "verify TritonGPU IR";
+
+  let description = [{}];
+
+  let constructor = "mlir::triton::gpu::createTritonGPUVerifier";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+}
+
 #endif
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index d8249ba97..0bb7ece43 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -62,3 +62,12 @@ void TritonGPUDialect::initialize() {
 
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
+
+
+// verify TritonGPU ops
+mlir::LogicalResult
+TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
+                                           mlir::NamedAttribute attr) {
+  // TODO: fill this.
+  return success();
+}
diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
index b803bd30b..d110608be 100644
--- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_public_tablegen_target(TritonGPUCombineIncGen)
 add_mlir_dialect_library(TritonGPUTransforms
   Combine.cpp
   Pipeline.cpp
+  Verifier.cpp
   TritonGPUConversion.cpp
 
   DEPENDS
diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp
new file mode 100644
index 000000000..139d3f3ae
--- /dev/null
+++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp
@@ -0,0 +1,99 @@
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+
+#include <memory>
+
+using namespace mlir;
+
+
+#define GEN_PASS_CLASSES
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
+
+class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+
+    // The idea is similar to mlir/lib/IR/Verifier.cpp
+    verifyImpl(m.getOperation());
+  }
+
+private:
+  LogicalResult verifySingleOp(Operation *op) {
+    if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
+      Type aType = dotOp.a().getType();
+      Type bType = dotOp.b().getType();
+      Type cType = dotOp.c().getType();
+      Type dType = dotOp.d().getType();
+      for (auto it : llvm::zip(llvm::SmallVector<Type>{aType, bType},
+                               llvm::SmallVector<char>{'a', 'b'})) {
+        Type type = std::get<0>(it);
+        char name = std::get<1>(it);
+        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+          Attribute encoding = tensorType.getEncoding();
+          if (!encoding)
+            return dotOp.emitError() << name << " should have encoding";
+          if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
+            return dotOp.emitError() << name << " should be of shared layout";
+        } else
+          return dotOp.emitError() << name << "'s type should be of RankedTensorType";
+      }
+
+      for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
+                               llvm::SmallVector<char>{'c', 'd'})) {
+        Type type = std::get<0>(it);
+        char name = std::get<1>(it);
+        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+          Attribute encoding = tensorType.getEncoding();
+          if (!encoding)
+            return dotOp.emitError() << name << " should have encoding";
+          if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>())
+            return dotOp.emitError() << name << " should be of mma layout";
+        } else
+          return dotOp.emitError() << name
+                                   << "'s type should be of RankedTensorType";
+      }
+
+      // signalPassFailure();
+    }
+    if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
+      // TODO: fill this
+    }
+    if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
+      // TODO: fill this
+    }
+    if (auto gepOp = llvm::dyn_cast<triton::GEPOp>(op)) {
+      // TODO: fill this
+    }
+    // Triton builtin Ops
+    if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp>(op)) {
+      // TODO: fill this
+    }
+    if (auto atomicRmw = llvm::dyn_cast<triton::AtomicRMWOp>(op)) {
+      // TODO: fill this
+    }
+    if (auto atomicCas = llvm::dyn_cast<triton::AtomicCASOp>(op)) {
+      // TODO: fill this
+    }
+
+    // TODO: Arithmetic, SCF, TritonGPU ops
+    return success();
+  }
+
+  void verifyImpl(Operation *op) {
+    if (verifySingleOp(op).failed())
+      signalPassFailure();
+
+    // verify that all child regions are ok
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &childOp : block)
+          verifyImpl(&childOp);
+  }
+};
+
+std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
+  return std::make_unique<TritonGPUVerifier>();
+}
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 85ed9566a..42d7fd312 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1347,6 +1347,9 @@ void init_triton_ir(py::module &&m) {
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::triton::gpu::createCombineOpsPass()); }) + .def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::gpu::createTritonGPUVerifier()); + }) ; } diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 1516b841e..87df9f89e 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -1320,8 +1320,8 @@ class JITFunction: pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_triton_gpu_combine_pass() - pm.run(mod) - return mod + pm.add_triton_gpu_verifier_pass() + return pm.run(mod) def __getitem__(self, grid): diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py index 9bd54fa81..821098f09 100644 --- a/rewrite-test/jit/matmul/matmul.py +++ b/rewrite-test/jit/matmul/matmul.py @@ -98,7 +98,8 @@ mod, ctx = matmul_kernel.compile_to_ttir( assert mod.verify() mod.dump() -mod = matmul_kernel.compile_ttir_to_llir(mod, ctx) +res = matmul_kernel.compile_ttir_to_llir(mod, ctx) assert mod.verify() +assert res mod.dump() diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py index 758b99572..11a99517f 100644 --- a/rewrite-test/jit/vecadd.py +++ b/rewrite-test/jit/vecadd.py @@ -43,5 +43,5 @@ z = torch.empty_like(x) mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,)) assert mod.verify() mod.dump() -mod = add_kernel.compile_ttir_to_llir(mod, ctx) +add_kernel.compile_ttir_to_llir(mod, ctx) mod.dump()