[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)

This commit is contained in:
Philippe Tillet
2022-10-11 18:16:41 -07:00
committed by GitHub
parent b6e5a231e5
commit 623c99609f
27 changed files with 494 additions and 348 deletions

View File

@@ -3,6 +3,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<BlockedEncodingAttr>(
auto ret = parser.getChecked<BlockedEncodingAttr>(
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
return ret;
}
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
@@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const {
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned dim = 0;
Attribute parent;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "dim") {
if (parseUInt(parser, attr, dim, "dim").failed())
return {};
}
if (attr.getName() == "parent") {
if (parser.parseAttribute(parent).failed())
return {};
}
}
unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
@@ -522,6 +510,35 @@ public:
}
};
struct TritonGPUInferLayoutInterface
: public triton::DialectInferLayoutInterface {
using DialectInferLayoutInterface::DialectInferLayoutInterface;
LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
Attribute &resultEncoding) const {
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
operandEncoding);
return success();
}
LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
Attribute &resultEncoding) const {
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
if (!sliceEncoding) {
llvm::report_fatal_error(
"ExpandDimsOp operand encoding must be SliceEncodingAttr");
return failure();
}
if (sliceEncoding.getDim() != axis) {
llvm::report_fatal_error(
"Incompatible slice dimension for ExpandDimsOp operand");
return failure();
}
resultEncoding = sliceEncoding.getParent();
return success();
}
};
void TritonGPUDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
@@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() {
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
>();
addInterfaces<TritonGPUOpAsmInterface>();
addInterfaces<TritonGPUInferLayoutInterface>();
}
//===----------------------------------------------------------------------===//
@@ -568,4 +586,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}
}

View File

@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
Combine.cpp
Pipeline.cpp
Swizzle.cpp
Verifier.cpp
TritonGPUConversion.cpp
DEPENDS

View File

@@ -1,106 +0,0 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
// The idea is similar to mlir/lib/IR/Verifier.cpp
verifyImpl(m.getOperation());
}
private:
LogicalResult verifySingleOp(Operation *op) {
if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
Type aType = dotOp.a().getType();
Type bType = dotOp.b().getType();
Type cType = dotOp.c().getType();
Type dType = dotOp.d().getType();
for (auto it : llvm::zip(llvm::SmallVector<Type>{aType, bType},
llvm::SmallVector<char>{'a', 'b'})) {
Type type = std::get<0>(it);
char name = std::get<1>(it);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::SharedEncodingAttr>())
return dotOp.emitError() << name << " should be of shared layout";
} else
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
Attribute cLayout;
for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
llvm::SmallVector<char>{'c', 'd'})) {
Type type = std::get<0>(it);
char name = std::get<1>(it);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::MmaEncodingAttr>() &&
!encoding.isa<triton::gpu::BlockedEncodingAttr>())
return dotOp.emitError()
<< name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;
else if (encoding != cLayout)
return dotOp.emitError() << "d & c should have the same layout";
} else
return dotOp.emitError()
<< name << "'s type should be of RankedTensorType";
}
// signalPassFailure();
}
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
// TODO: fill this
}
if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
// TODO: fill this
}
if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
// TODO: fill this
}
// Triton builtin Ops
if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp,
triton::MakeRangeOp>(op)) {
// TODO: fill this
}
if (auto atomicRmw = llvm::dyn_cast<triton::AtomicRMWOp>(op)) {
// TODO: fill this
}
if (auto atomicCas = llvm::dyn_cast<triton::AtomicCASOp>(op)) {
// TODO: fill this
}
// TODO: Arithmetic, SCF, TritonGPU ops
return success();
}
void verifyImpl(Operation *op) {
if (verifySingleOp(op).failed())
signalPassFailure();
// verify that all child regions are ok
for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &childOp : block)
verifyImpl(&childOp);
}
};
std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
return std::make_unique<TritonGPUVerifier>();
}