[Triton-IR] Added type inference and verifier for Triton-IR operations (#767)

This commit is contained in:
Philippe Tillet
2022-10-11 18:16:41 -07:00
committed by GitHub
parent b6e5a231e5
commit 623c99609f
27 changed files with 494 additions and 348 deletions

View File

@@ -1,7 +1,9 @@
add_mlir_dialect_library(TritonIR
Interfaces.cpp
Dialect.cpp
Ops.cpp
Types.cpp
Traits.cpp
DEPENDS
TritonTableGen

View File

@@ -1,6 +1,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

View File

View File

@@ -160,18 +160,88 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
}
//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the accumulator
auto accTy = operands[2].getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(accTy);
return mlir::success();
}
//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
Value arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.erase(retShape.begin() + axis);
// infer encoding
Attribute argEncoding = argTy.getEncoding();
Attribute retEncoding;
if (argEncoding) {
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
.failed()) {
llvm::report_fatal_error("failed to infer layout for ReduceOp");
return mlir::failure();
}
}
// create type
auto argEltTy = argTy.getElementType();
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
return mlir::success();
}
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
return ret;
}
//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
auto arg = operands[0];
auto argTy = arg.getType().cast<RankedTensorType>();
auto retShape = argTy.getShape().vec();
int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
retShape.insert(retShape.begin() + axis, 1);
// infer encoding
Attribute argEncoding = argTy.getEncoding();
Attribute retEncoding;
if (argEncoding) {
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding)
.failed()) {
llvm::report_fatal_error("failed to infer layout for ExpandDimsOp");
return mlir::failure();
}
}
// create type
auto argEltTy = argTy.getElementType();
inferredReturnTypes.push_back(
RankedTensorType::get(retShape, argEltTy, retEncoding));
return mlir::success();
}
//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();

View File

@@ -0,0 +1,69 @@
#include "triton/Dialect/Triton/IR/Traits.h"
static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
using namespace mlir;
auto encA = tyA.dyn_cast<RankedTensorType>();
auto encB = tyA.dyn_cast<RankedTensorType>();
if (!encA || !encB)
return success();
return encA.getEncoding() == encB.getEncoding() ? success() : failure();
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)) ||
failed(verifyAtLeastNResults(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto resultType : op->getResultTypes())
if (failed(verifySameEncoding(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
return verifySameOperandsEncoding(op);
}
mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto type = op->getOperand(0).getType();
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
if (failed(verifySameEncoding(opType, type)))
return op->emitOpError() << "requires the same encoding for all operands";
return success();
}
mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
for (auto opType : op->getOperandTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxTensorNumElements)
return op->emitError("Maximum allowed number of elements is ")
<< maxTensorNumElements << ", but " << *op
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
}
}
for (auto opType : op->getResultTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > maxTensorNumElements)
return op->emitError("Maximum allowed number of elements is ")
<< maxTensorNumElements << ", but " << *op
<< " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
<< *op << " doesn't follow the rule";
}
}
return success();
}