#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include <memory>
#include <utility>

using namespace mlir;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

/// Pass that checks TritonGPU-specific invariants on every operation nested
/// under the module. Each violation is reported through the op's
/// `emitError()` and marks the pass as failed; the traversal keeps going so
/// that all violations are reported in a single run.
class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
public:
  void runOnOperation() override {
    ModuleOp m = getOperation();

    // The idea is similar to mlir/lib/IR/Verifier.cpp: visit every op
    // (pre-order, recursively) and verify each one in isolation.
    verifyImpl(m.getOperation());
  }

private:
  /// Returns the encoding attribute of dot operand `name` (whose type is
  /// `type`). If `type` is not a RankedTensorType, or the tensor has no
  /// encoding, an error is emitted on `dotOp` and a null Attribute is
  /// returned.
  static Attribute getDotOperandEncoding(triton::DotOp dotOp, Type type,
                                         char name) {
    auto tensorType = type.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      dotOp.emitError() << name << "'s type should be of RankedTensorType";
      return Attribute();
    }
    Attribute encoding = tensorType.getEncoding();
    if (!encoding)
      dotOp.emitError() << name << " should have encoding";
    return encoding;
  }

  /// Returns true for the encodings this verifier accepts as a
  /// "distributed" layout (mma or blocked).
  static bool isDistributedLayout(Attribute encoding) {
    return encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() ||
           encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>();
  }

  /// Checks the layout contract of a dot op:
  ///  - operands a and b are ranked tensors with a shared encoding;
  ///  - c and d are ranked tensors with a distributed encoding, and d's
  ///    encoding is identical to c's.
  /// Emits a diagnostic and returns failure on the first violation found.
  LogicalResult verifyDotOp(triton::DotOp dotOp) {
    const std::pair<Type, char> sharedOperands[] = {
        {dotOp.a().getType(), 'a'}, {dotOp.b().getType(), 'b'}};
    for (const auto &operand : sharedOperands) {
      Attribute encoding =
          getDotOperandEncoding(dotOp, operand.first, operand.second);
      if (!encoding)
        return failure(); // diagnostic already emitted
      if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
        return dotOp.emitError()
               << operand.second << " should be of shared layout";
    }

    // c is visited before d, so cLayout is always set when d is compared.
    Attribute cLayout;
    const std::pair<Type, char> accOperands[] = {
        {dotOp.c().getType(), 'c'}, {dotOp.d().getType(), 'd'}};
    for (const auto &operand : accOperands) {
      Attribute encoding =
          getDotOperandEncoding(dotOp, operand.first, operand.second);
      if (!encoding)
        return failure(); // diagnostic already emitted
      if (!isDistributedLayout(encoding))
        return dotOp.emitError()
               << operand.second << " should be of distributed layout";
      if (operand.second == 'c')
        cLayout = encoding;
      else if (encoding != cLayout)
        return dotOp.emitError() << "d & c should have the same layout";
    }
    return success();
  }

  /// Verifies `op` alone (not its nested ops). Returns failure if a check
  /// failed; the corresponding diagnostic has already been emitted.
  LogicalResult verifySingleOp(Operation *op) {
    if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
      if (failed(verifyDotOp(dotOp)))
        return failure();
    }
    // Memory ops.
    if (llvm::isa<triton::LoadOp, triton::StoreOp, triton::GEPOp>(op)) {
      // TODO: fill this
    }
    // Triton builtin Ops
    if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp>(op)) {
      // TODO: fill this
    }
    // Atomics.
    if (llvm::isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
      // TODO: fill this
    }

    // TODO: Arithmetic, SCF, TritonGPU ops
    return success();
  }

  /// Verifies `op`, then recursively every op nested in its regions.
  /// A failure marks the pass as failed but does not stop the traversal.
  void verifyImpl(Operation *op) {
    if (failed(verifySingleOp(op)))
      signalPassFailure();

    // verify that all child regions are ok
    for (Region &region : op->getRegions())
      for (Block &block : region)
        for (Operation &childOp : block)
          verifyImpl(&childOp);
  }
};

std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
  return std::make_unique<TritonGPUVerifier>();
}