[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
@@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<BlockedEncodingAttr>(
|
||||
auto ret = parser.getChecked<BlockedEncodingAttr>(
|
||||
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
@@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
DictionaryAttr dict;
|
||||
if (parser.parseAttribute(dict).failed())
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseOptionalAttrDict(attrs).failed())
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
unsigned dim = 0;
|
||||
Attribute parent;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "dim") {
|
||||
if (parseUInt(parser, attr, dim, "dim").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "parent") {
|
||||
if (parser.parseAttribute(parent).failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
|
||||
Attribute parent = attrs.get("parent");
|
||||
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
|
||||
}
|
||||
|
||||
@@ -522,6 +510,35 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonGPUInferLayoutInterface
|
||||
: public triton::DialectInferLayoutInterface {
|
||||
using DialectInferLayoutInterface::DialectInferLayoutInterface;
|
||||
|
||||
LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const {
|
||||
resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
|
||||
operandEncoding);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
|
||||
Attribute &resultEncoding) const {
|
||||
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
|
||||
if (!sliceEncoding) {
|
||||
llvm::report_fatal_error(
|
||||
"ExpandDimsOp operand encoding must be SliceEncodingAttr");
|
||||
return failure();
|
||||
}
|
||||
if (sliceEncoding.getDim() != axis) {
|
||||
llvm::report_fatal_error(
|
||||
"Incompatible slice dimension for ExpandDimsOp operand");
|
||||
return failure();
|
||||
}
|
||||
resultEncoding = sliceEncoding.getParent();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
@@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() {
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TritonGPUOpAsmInterface>();
|
||||
addInterfaces<TritonGPUInferLayoutInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -568,4 +586,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user