[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)
@@ -1,7 +1,9 @@
add_mlir_dialect_library(TritonIR
+  Interfaces.cpp
  Dialect.cpp
  Ops.cpp
  Types.cpp
+  Traits.cpp

  DEPENDS
  TritonTableGen
@@ -1,6 +1,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

+#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
lib/Dialect/Triton/IR/Interfaces.cpp (new file, 0 lines)
@@ -160,18 +160,88 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
}

//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // the result type is the same as the accumulator operand's type
  auto accTy = operands[2].getType().cast<RankedTensorType>();
  inferredReturnTypes.push_back(accTy);
  return mlir::success();
}
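
Because DotOp implements MLIR's InferTypeOpInterface, callers can re-derive
the result type instead of spelling it out. A minimal sketch of invoking the
hook directly (hypothetical call site; `dot` is assumed to be an existing
triton::DotOp, which carries no regions):

  SmallVector<Type, 1> inferred;
  if (succeeded(triton::DotOp::inferReturnTypes(
          dot.getContext(), dot.getLoc(), dot->getOperands(),
          dot->getAttrDictionary(), /*regions=*/{}, inferred)))
    assert(inferred[0] == dot.c().getType()); // result type == accumulator type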

//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  Value arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.erase(retShape.begin() + axis);
  // infer encoding
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
    // fail gracefully if the encoding's dialect cannot infer layouts
    if (!inferLayoutInterface ||
        failed(inferLayoutInterface->inferReduceOpEncoding(argEncoding, axis,
                                                           retEncoding)))
      return emitOptionalError(location, "failed to infer layout for ReduceOp");
  }
  // create type
  auto argEltTy = argTy.getElementType();
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, argEltTy, retEncoding));
  return mlir::success();
}
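
The shape rule above simply deletes the reduced dimension. The same vector
manipulation in isolation:

  // Reducing tensor<8x16x32> along axis = 1 yields tensor<8x32>.
  SmallVector<int64_t> retShape = {8, 16, 32};
  int axis = 1;
  retShape.erase(retShape.begin() + axis); // retShape == {8, 32}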

//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
  if (!constOperand)
    return {};

  auto shapedType = getType().cast<ShapedType>();
  auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
  return ret;
}
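
The net effect of this fold is to replace a splat of a scalar constant with a
single dense attribute. A short sketch of the attribute it builds, assuming an
OpBuilder `b` (hypothetical values):

  auto shapedType = RankedTensorType::get({128}, b.getI32Type());
  // SplatElementsAttr stores one value replicated across the whole shape.
  auto attr = SplatElementsAttr::get(shapedType, {b.getI32IntegerAttr(42)});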

//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  auto arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.insert(retShape.begin() + axis, 1);
  // infer encoding
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
    if (!inferLayoutInterface ||
        failed(inferLayoutInterface->inferExpandDimsOpEncoding(argEncoding, axis,
                                                               retEncoding)))
      return emitOptionalError(location,
                               "failed to infer layout for ExpandDimsOp");
  }
  // create type
  auto argEltTy = argTy.getElementType();
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, argEltTy, retEncoding));
  return mlir::success();
}
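
Symmetrically to ReduceOp, the shape rule inserts a unit dimension at `axis`.
In isolation:

  // Expanding tensor<8x32> at axis = 1 yields tensor<8x1x32>.
  SmallVector<int64_t> retShape = {8, 32};
  retShape.insert(retShape.begin() + 1, 1); // retShape == {8, 1, 32}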

//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = src().getDefiningOp<arith::ConstantOp>();

lib/Dialect/Triton/IR/Traits.cpp (new file, 69 lines)
@@ -0,0 +1,69 @@
#include "triton/Dialect/Triton/IR/Traits.h"

static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
  using namespace mlir;
  auto encA = tyA.dyn_cast<RankedTensorType>();
  auto encB = tyB.dyn_cast<RankedTensorType>();
  if (!encA || !encB)
    return success();
  return encA.getEncoding() == encB.getEncoding() ? success() : failure();
}

mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  auto type = op->getOperand(0).getType();
  for (auto resultType : op->getResultTypes())
    if (failed(verifySameEncoding(resultType, type)))
      return op->emitOpError()
             << "requires the same encoding for all operands and results";
  return verifySameOperandsEncoding(op);
}

mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  auto type = op->getOperand(0).getType();
  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (failed(verifySameEncoding(opType, type)))
      return op->emitOpError() << "requires the same encoding for all operands";

  return success();
}
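
These trait hooks are ordinary free functions, so they can also be called
directly, e.g. from a hand-written verifier. A hedged sketch, assuming `op` is
any mlir::Operation*:

  // Reject ops whose operand and result encodings disagree.
  if (failed(mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(op)))
    return mlir::failure();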

mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
  // The same size rules apply to operand and result tensors alike.
  auto verifyType = [op](mlir::Type type) -> mlir::LogicalResult {
    if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
      int64_t numElements = 1;
      for (int64_t s : tensorType.getShape())
        numElements *= s;
      if (numElements > maxTensorNumElements)
        return op->emitError("Maximum allowed number of elements is ")
               << maxTensorNumElements << ", but " << *op
               << " has more than that";
      if ((numElements & (numElements - 1)) != 0)
        return op->emitError("Number of elements must be power-of-two, but ")
               << *op << " doesn't follow the rule";
    }
    return success();
  };
  for (auto opType : op->getOperandTypes())
    if (failed(verifyType(opType)))
      return failure();
  for (auto resultType : op->getResultTypes())
    if (failed(verifyType(resultType)))
      return failure();
  return success();
}
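
The size check relies on a standard bit trick: a positive n is a power of two
exactly when n & (n - 1), which clears the lowest set bit, yields zero. As a
self-contained helper:

  // isPowerOfTwo(1) == true, (3) == false, (64) == true, (0) == false
  static bool isPowerOfTwo(int64_t n) { return n > 0 && (n & (n - 1)) == 0; }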

@@ -3,6 +3,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"

@@ -288,8 +289,9 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
    }
  }

-  return parser.getChecked<BlockedEncodingAttr>(
+  auto ret = parser.getChecked<BlockedEncodingAttr>(
      parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
+  return ret;
}

void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
@@ -346,27 +348,13 @@ void MmaEncodingAttr::print(AsmPrinter &printer) const {
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
-  // Parse the data as a dictionary
-  DictionaryAttr dict;
-  if (parser.parseAttribute(dict).failed())
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

-  unsigned dim = 0;
-  Attribute parent;
-
-  for (const NamedAttribute &attr : dict) {
-    if (attr.getName() == "dim") {
-      if (parseUInt(parser, attr, dim, "dim").failed())
-        return {};
-    }
-    if (attr.getName() == "parent") {
-      if (parser.parseAttribute(parent).failed())
-        return {};
-    }
-  }
-
+  unsigned dim = attrs.get("dim").cast<IntegerAttr>().getInt();
+  Attribute parent = attrs.get("parent");
  return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
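
With this change the slice attribute is parsed as a plain attribute dictionary
between angle brackets, so the accepted textual form is something like
#triton_gpu.slice<{dim = 1, parent = #triton_gpu.blocked<...>}> (the exact
printed syntax is an assumption inferred from the parser, not shown in this
diff).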

@@ -522,6 +510,35 @@ public:
  }
};

struct TritonGPUInferLayoutInterface
    : public triton::DialectInferLayoutInterface {
  using DialectInferLayoutInterface::DialectInferLayoutInterface;

  LogicalResult inferReduceOpEncoding(Attribute operandEncoding, int axis,
                                      Attribute &resultEncoding) const {
    resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
                                            operandEncoding);
    return success();
  }

  LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, int axis,
                                          Attribute &resultEncoding) const {
    auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
    if (!sliceEncoding) {
      llvm::report_fatal_error(
          "ExpandDimsOp operand encoding must be SliceEncodingAttr");
      return failure();
    }
    if (sliceEncoding.getDim() != axis) {
      llvm::report_fatal_error(
          "Incompatible slice dimension for ExpandDimsOp operand");
      return failure();
    }
    resultEncoding = sliceEncoding.getParent();
    return success();
  }
};
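
Together the two hooks form a round trip: reducing along an axis wraps the
operand encoding in a SliceEncodingAttr, and expanding the same axis unwraps
it back to the parent. A hedged sketch, assuming `iface` is an instance of
this interface and `blocked` is some existing BlockedEncodingAttr:

  Attribute sliced, restored;
  (void)iface.inferReduceOpEncoding(blocked, /*axis=*/1, sliced);
  (void)iface.inferExpandDimsOpEncoding(sliced, /*axis=*/1, restored);
  assert(restored == blocked); // expand_dims undoes reduce on encodings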

void TritonGPUDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
@@ -532,6 +549,7 @@ void TritonGPUDialect::initialize() {
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
  addInterfaces<TritonGPUOpAsmInterface>();
+  addInterfaces<TritonGPUInferLayoutInterface>();
}

//===----------------------------------------------------------------------===//

@@ -568,4 +586,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                         NamedAttribute attr) {
  // TODO: fill this.
  return success();
}

@@ -8,7 +8,6 @@ add_mlir_dialect_library(TritonGPUTransforms
  Combine.cpp
  Pipeline.cpp
  Swizzle.cpp
-  Verifier.cpp
  TritonGPUConversion.cpp

  DEPENDS

lib/Dialect/TritonGPU/Transforms/Verifier.cpp (deleted, 106 lines)
@@ -1,106 +0,0 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include <memory>

using namespace mlir;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUVerifier : public TritonGPUVerifierBase<TritonGPUVerifier> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    // The idea is similar to mlir/lib/IR/Verifier.cpp
    verifyImpl(m.getOperation());
  }

private:
  LogicalResult verifySingleOp(Operation *op) {
    if (auto dotOp = llvm::dyn_cast<triton::DotOp>(op)) {
      Type aType = dotOp.a().getType();
      Type bType = dotOp.b().getType();
      Type cType = dotOp.c().getType();
      Type dType = dotOp.d().getType();
      for (auto it : llvm::zip(llvm::SmallVector<Type>{aType, bType},
                               llvm::SmallVector<char>{'a', 'b'})) {
        Type type = std::get<0>(it);
        char name = std::get<1>(it);
        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
          Attribute encoding = tensorType.getEncoding();
          if (!encoding)
            return dotOp.emitError() << name << " should have encoding";
          if (!encoding.isa<triton::gpu::SharedEncodingAttr>())
            return dotOp.emitError() << name << " should be of shared layout";
        } else
          return dotOp.emitError()
                 << name << "'s type should be of RankedTensorType";
      }

      Attribute cLayout;
      for (auto it : llvm::zip(llvm::SmallVector<Type>{cType, dType},
                               llvm::SmallVector<char>{'c', 'd'})) {
        Type type = std::get<0>(it);
        char name = std::get<1>(it);
        if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
          Attribute encoding = tensorType.getEncoding();
          if (!encoding)
            return dotOp.emitError() << name << " should have encoding";
          if (!encoding.isa<triton::gpu::MmaEncodingAttr>() &&
              !encoding.isa<triton::gpu::BlockedEncodingAttr>())
            return dotOp.emitError()
                   << name << " should be of distributed layout";
          if (name == 'c')
            cLayout = encoding;
          else if (encoding != cLayout)
            return dotOp.emitError() << "d & c should have the same layout";
        } else
          return dotOp.emitError()
                 << name << "'s type should be of RankedTensorType";
      }

      // signalPassFailure();
    }
    if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(op)) {
      // TODO: fill this
    }
    if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
      // TODO: fill this
    }
    if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
      // TODO: fill this
    }
    // Triton builtin Ops
    if (llvm::isa<triton::GetProgramIdOp, triton::GetNumProgramsOp,
                  triton::MakeRangeOp>(op)) {
      // TODO: fill this
    }
    if (auto atomicRmw = llvm::dyn_cast<triton::AtomicRMWOp>(op)) {
      // TODO: fill this
    }
    if (auto atomicCas = llvm::dyn_cast<triton::AtomicCASOp>(op)) {
      // TODO: fill this
    }

    // TODO: Arithmetic, SCF, TritonGPU ops
    return success();
  }

  void verifyImpl(Operation *op) {
    if (verifySingleOp(op).failed())
      signalPassFailure();

    // verify that all child regions are ok
    for (Region &region : op->getRegions())
      for (Block &block : region)
        for (Operation &childOp : block)
          verifyImpl(&childOp);
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
  return std::make_unique<TritonGPUVerifier>();
}