TritonGPU verifier

This commit is contained in:
Yan Da
2022-05-24 19:48:56 +08:00
parent 36c45ec687
commit a2c9f919a8
10 changed files with 131 additions and 4 deletions

View File

@@ -8,6 +8,8 @@ def TritonGPU_Dialect : Dialect {
let cppNamespace = "::mlir::triton::gpu";
let hasOperationAttrVerify = 1;
let description = [{
Triton GPU Dialect.
}];

View File

@@ -9,6 +9,8 @@ std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages);
namespace triton {
namespace gpu {
std::unique_ptr<Pass> createCombineOpsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
}
}

View File

@@ -51,4 +51,14 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
"mlir::triton::TritonDialect"];
}
// Standalone IR verifier for the TritonGPU dialect, registered as a
// module-level pass under the pipeline name "tritongpu-verifier".
// Implemented in lib/Dialect/TritonGPU/Transforms/Verifier.cpp.
def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
let summary = "verify TritonGPU IR";
// TODO(review): description intentionally empty for now — fill it in once
// the set of checks performed by the verifier is finalized.
let description = [{}];
let constructor = "mlir::triton::gpu::createTritonGPUVerifier";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif

View File

@@ -62,3 +62,12 @@ void TritonGPUDialect::initialize() {
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
// Dialect hook invoked by the generic MLIR verifier for every attribute in
// the TritonGPU namespace attached to an operation. Currently a stub that
// accepts everything; real attribute checks are yet to be written.
mlir::LogicalResult
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
mlir::NamedAttribute attr) {
// TODO: fill this.
return success();
}

View File

@@ -5,6 +5,7 @@ add_public_tablegen_target(TritonGPUCombineIncGen)
add_mlir_dialect_library(TritonGPUTransforms
Combine.cpp
Pipeline.cpp
Verifier.cpp
TritonGPUConversion.cpp
DEPENDS

View File

@@ -0,0 +1,99 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
// Standalone verification pass for TritonGPU IR. The traversal mirrors
// mlir/lib/IR/Verifier.cpp: every operation in the module, including all
// operations nested in child regions, is checked, and the pass signals
// failure on the first violation found along each branch of the walk.
class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
public:
  void runOnOperation() override {
    ModuleOp m = getOperation();
    verifyImpl(m.getOperation());
  }

private:
  // Checks that every type in `types` is a RankedTensorType carrying an
  // encoding of kind `EncodingT`. `names` supplies the operand letter
  // ('a'..'d') and `layoutName` the layout keyword ("shared"/"mma") used
  // in the emitted diagnostics. Returns failure on the first offender.
  template <typename EncodingT>
  LogicalResult verifyDotOperandLayout(triton::DotOp dotOp,
                                       const llvm::SmallVector<Type> &types,
                                       const llvm::SmallVector<char> &names,
                                       const char *layoutName) {
    for (auto it : llvm::zip(types, names)) {
      Type type = std::get<0>(it);
      char name = std::get<1>(it);
      auto tensorType = type.dyn_cast<RankedTensorType>();
      if (!tensorType)
        return dotOp.emitError()
               << name << "'s type should be of RankedTensorType";
      Attribute encoding = tensorType.getEncoding();
      if (!encoding)
        return dotOp.emitError() << name << " should have encoding";
      if (!encoding.isa<EncodingT>())
        return dotOp.emitError()
               << name << " should be of " << layoutName << " layout";
    }
    return success();
  }

  // Structural checks for a single operation, beyond what the generic MLIR
  // verifier already enforces.
  LogicalResult verifySingleOp(Operation *op) {
    if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
      // tt.dot requires its a/b operands in shared layout and its
      // c operand / d result in mma layout.
      Type aType = dotOp.a().getType();
      Type bType = dotOp.b().getType();
      Type cType = dotOp.c().getType();
      Type dType = dotOp.d().getType();
      LogicalResult abResult =
          verifyDotOperandLayout<triton::gpu::TritonGPUSharedEncodingAttr>(
              dotOp, {aType, bType}, {'a', 'b'}, "shared");
      if (abResult.failed())
        return abResult;
      return verifyDotOperandLayout<triton::gpu::TritonGPUMmaEncodingAttr>(
          dotOp, {cType, dType}, {'c', 'd'}, "mma");
    }
    // TODO: add checks for triton::LoadOp, triton::StoreOp, triton::GEPOp,
    // triton::GetProgramIdOp, triton::GetNumProgramsOp, triton::MakeRangeOp,
    // triton::AtomicRMWOp, triton::AtomicCASOp, as well as Arithmetic, SCF
    // and TritonGPU ops.
    return success();
  }

  // Recursively verifies `op` and, only if it passes, every operation
  // nested in its regions. Not descending below a failed op avoids
  // cascading diagnostics for children of already-broken operations.
  void verifyImpl(Operation *op) {
    if (verifySingleOp(op).failed()) {
      signalPassFailure();
      return;
    }
    for (Region &region : op->getRegions())
      for (Block &block : region)
        for (Operation &childOp : block)
          verifyImpl(&childOp);
  }
};
// Factory for the TritonGPU verifier pass (pipeline name
// "tritongpu-verifier"); declared in TritonGPU/Transforms/Passes.h and
// referenced by the tablegen Pass definition's `constructor` field.
std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
return std::make_unique<TritonGPUVerifier>();
}

View File

@@ -1347,6 +1347,9 @@ void init_triton_ir(py::module &&m) {
.def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::gpu::createCombineOpsPass());
})
.def("add_triton_gpu_verifier_pass", [](mlir::PassManager &self) {
self.addPass(mlir::triton::gpu::createTritonGPUVerifier());
})
;
}

View File

@@ -1320,8 +1320,8 @@ class JITFunction:
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_canonicalizer_pass()
pm.add_triton_gpu_combine_pass()
pm.run(mod)
return mod
pm.add_triton_gpu_verifier_pass()
return pm.run(mod)
def __getitem__(self, grid):

View File

@@ -98,7 +98,8 @@ mod, ctx = matmul_kernel.compile_to_ttir(
assert mod.verify()
mod.dump()
mod = matmul_kernel.compile_ttir_to_llir(mod, ctx)
res = matmul_kernel.compile_ttir_to_llir(mod, ctx)
assert mod.verify()
assert res
mod.dump()

View File

@@ -43,5 +43,5 @@ z = torch.empty_like(x)
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
assert mod.verify()
mod.dump()
mod = add_kernel.compile_ttir_to_llir(mod, ctx)
add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump()