Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability, and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
lib/Dialect/Triton/IR/CMakeLists.txt (new file, 20 lines)

```cmake
add_mlir_dialect_library(TritonIR
  Interfaces.cpp
  Dialect.cpp
  Ops.cpp
  Types.cpp
  Traits.cpp

  DEPENDS
  TritonTableGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRArithmetic
  MLIRSCF

  # Since LLVM 15
  # MLIRFunc
  # else
  MLIRStandard
)
```
lib/Dialect/Triton/IR/Dialect.cpp (new file, 51 lines)

```cpp
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/IR/DialectImplementation.h"

#include "mlir/Transforms/InliningUtils.h"
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"

using namespace mlir;
using namespace mlir::triton;

//===----------------------------------------------------------------------===//
// TritonDialect Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
struct TritonInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    return true;
  }
  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};
} // namespace

void TritonDialect::initialize() {
  registerTypes();

  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"
      >();

  // We can also add interfaces here.
  addInterfaces<TritonInlinerInterface>();
}

Operation *TritonDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}
```
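The dialect only becomes usable once a context loads it. Below is a minimal sketch of how a standalone tool might wire this up, using only core MLIR APIs (`DialectRegistry`, `MLIRContext`); the `main` function is illustrative and not part of this commit:

```cpp
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

int main() {
  mlir::DialectRegistry registry;
  registry.insert<mlir::triton::TritonDialect>();
  mlir::MLIRContext context(registry);
  // Loading the dialect runs TritonDialect::initialize() above, which
  // registers the types, the ops from Ops.cpp.inc, and the inliner interface.
  context.getOrLoadDialect<mlir::triton::TritonDialect>();
  return 0;
}
```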
lib/Dialect/Triton/IR/Interfaces.cpp (new file, empty)
lib/Dialect/Triton/IR/Ops.cpp (new file, 346 lines)

```cpp
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"

namespace mlir {
namespace triton {

// Type inference
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type,
                                 tensorType.getEncoding());
  return i1Type;
}

static Type getI32SameShape(Type type) {
  auto i32Type = IntegerType::get(type.getContext(), 32);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i32Type,
                                 tensorType.getEncoding());
  return i32Type;
}

static Type getPointerTypeSameShape(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    Type elementType = tensorType.getElementType();
    auto shape = tensorType.getShape();
    PointerType ptrType = PointerType::get(elementType, 1);
    return RankedTensorType::get(shape, ptrType, tensorType.getEncoding());
  } else {
    return PointerType::get(type, 1);
  }
}

// Parser & printer for assembly forms
ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> allOperands;
  Type resultTypes[1];
  SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(allOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseCustomTypeWithFallback(resultTypes[0]))
    return failure();

  result.addTypes(resultTypes);

  SmallVector<Type> operandTypes;
  operandTypes.push_back(getPointerTypeSameShape(resultTypes[0])); // ptr
  int hasMask = 0, hasOther = 0;
  if (allOperands.size() >= 2) {
    operandTypes.push_back(getI1SameShape(resultTypes[0])); // mask
    hasMask = 1;
  }
  if (allOperands.size() >= 3) {
    operandTypes.push_back(resultTypes[0]); // other
    hasOther = 1;
  }

  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                             result.operands))
    return failure();
  // Deduce operand_segment_sizes from the number of the operands.
  auto operand_segment_sizesAttrName =
      LoadOp::operand_segment_sizesAttrName(result.name);
  result.addAttribute(
      operand_segment_sizesAttrName,
      parser.getBuilder().getI32VectorAttr({1, hasMask, hasOther}));
  return success();
}

void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
  printer << " ";
  printer << loadOp.getOperation()->getOperands();
  // "operand_segment_sizes" can be deduced, so we don't print it.
  printer.printOptionalAttrDict(loadOp->getAttrs(),
                                {loadOp.operand_segment_sizesAttrName()});
  printer << " : ";
  printer.printStrippedAttrOrType(loadOp.result().getType());
}

ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> allOperands;
  Type valueType;
  SMLoc allOperandLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(allOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseCustomTypeWithFallback(valueType))
    return failure();

  SmallVector<Type> operandTypes;
  operandTypes.push_back(getPointerTypeSameShape(valueType)); // ptr
  operandTypes.push_back(valueType);                          // value
  if (allOperands.size() >= 3)
    operandTypes.push_back(getI1SameShape(valueType)); // mask

  if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
                             result.operands))
    return failure();
  return success();
}

void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
  printer << " ";
  printer << storeOp.getOperation()->getOperands();
  printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
  printer << " : ";
  printer.printStrippedAttrOrType(storeOp.value().getType());
}

} // namespace triton
} // namespace mlir

#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.cpp.inc"

// enum attribute definitions
#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc"

namespace mlir {
namespace triton {

//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
                                 ::mlir::TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  auto srcEltType = inputs.front();
  auto dstEltType = outputs.front();
  auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
  auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
  if (srcTensorType && dstTensorType) {
    srcEltType = srcTensorType.getElementType();
    dstEltType = dstTensorType.getElementType();
  }
  // Check whether fp8 <=> fp16, bf16, f32, f64
  // Make `srcEltType` always the fp8 side
  if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
    std::swap(srcEltType, dstEltType);
  if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
    return false;
  return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
         dstEltType.isF64();
}

//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value) {
  StoreOp::build(builder, state, ptr, value, mlir::Value());
}

//-- LoadOp --
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
  auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
  if (!ptrTensorType)
    return ptrType.cast<PointerType>().getPointeeType();
  auto shape = ptrTensorType.getShape();
  Type elementType =
      ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
  return RankedTensorType::get(shape, elementType);
}

void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  LoadOp::build(builder, state, ptr, mlir::Value(), mlir::Value(), cache, evict,
                isVolatile);
}

void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  LoadOp::build(builder, state, ptr, mask, mlir::Value(), cache, evict,
                isVolatile);
}

void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
  Type resultType = getLoadOpResultType(builder, ptr.getType());

  state.addOperands(ptr);
  if (mask) {
    state.addOperands(mask);
    if (other) {
      state.addOperands(other);
    }
  }
  state.addAttribute(
      operand_segment_sizesAttrName(state.name),
      builder.getI32VectorAttr({1, (mask ? 1 : 0), (other ? 1 : 0)}));
  state.addAttribute(
      cacheAttrName(state.name),
      ::mlir::triton::CacheModifierAttr::get(builder.getContext(), cache));
  state.addAttribute(
      evictAttrName(state.name),
      ::mlir::triton::EvictionPolicyAttr::get(builder.getContext(), evict));
  state.addAttribute(isVolatileAttrName(state.name),
                     builder.getBoolAttr(isVolatile));
  state.addTypes({resultType});
}

//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // type is the same as the accumulator
  auto accTy = operands[2].getType().cast<RankedTensorType>();
  inferredReturnTypes.push_back(accTy);

  // verify encodings
  auto aEnc = operands[0].getType().cast<RankedTensorType>().getEncoding();
  auto bEnc = operands[1].getType().cast<RankedTensorType>().getEncoding();
  auto retEnc = accTy.getEncoding();
  if (aEnc) {
    assert(bEnc);
    Dialect &dialect = aEnc.getDialect();
    auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
    if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
      return mlir::failure();
    if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
      return mlir::failure();
  }
  return mlir::success();
}

//-- ReduceOp --
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  Value arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto argEltTy = argTy.getElementType();
  auto i32Ty = IntegerType::get(argEltTy.getContext(), 32);
  auto redOp =
      attributes.get("redOp").cast<mlir::triton::RedOpAttr>().getValue();
  bool withIndex = mlir::triton::ReduceOp::withIndex(redOp);
  auto retEltTy = withIndex ? i32Ty : argEltTy;
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.erase(retShape.begin() + axis);
  if (retShape.empty()) {
    // 0d-tensor -> scalar
    inferredReturnTypes.push_back(retEltTy);
  } else {
    // nd-tensor where n >= 1
    // infer encoding
    Attribute argEncoding = argTy.getEncoding();
    Attribute retEncoding;
    if (argEncoding) {
      Dialect &dialect = argEncoding.getDialect();
      auto inferLayoutInterface =
          dyn_cast<DialectInferLayoutInterface>(&dialect);
      if (inferLayoutInterface
              ->inferReduceOpEncoding(argEncoding, axis, retEncoding)
              .failed()) {
        llvm::report_fatal_error("failed to infer layout for ReduceOp");
        return mlir::failure();
      }
    }
    // create type
    inferredReturnTypes.push_back(
        RankedTensorType::get(retShape, retEltTy, retEncoding));
  }
  return mlir::success();
}

bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
  return redOp == mlir::triton::RedOp::ARGMIN ||
         redOp == mlir::triton::RedOp::ARGMAX ||
         redOp == mlir::triton::RedOp::ARGUMIN ||
         redOp == mlir::triton::RedOp::ARGUMAX ||
         redOp == mlir::triton::RedOp::ARGFMIN ||
         redOp == mlir::triton::RedOp::ARGFMAX;
}

//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
  if (!constOperand)
    return {};
  auto shapedType = getType().cast<ShapedType>();
  auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
  return ret;
}

//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> loc, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // infer shape
  auto arg = operands[0];
  auto argTy = arg.getType().cast<RankedTensorType>();
  auto retShape = argTy.getShape().vec();
  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
  retShape.insert(retShape.begin() + axis, 1);
  // infer encoding
  Attribute argEncoding = argTy.getEncoding();
  Attribute retEncoding;
  if (argEncoding) {
    Dialect &dialect = argEncoding.getDialect();
    auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
    if (inferLayoutInterface
            ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
            .failed())
      return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
  }
  // create type
  auto argEltTy = argTy.getElementType();
  inferredReturnTypes.push_back(
      RankedTensorType::get(retShape, argEltTy, retEncoding));
  return mlir::success();
}

//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
  if (!constOperand)
    return {};

  auto shapedType = getType().cast<ShapedType>();
  auto value = constOperand.getValue();
  if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
    if (!denseElemsAttr.isSplat())
      return {};
    return SplatElementsAttr::get(shapedType,
                                  denseElemsAttr.getSplatValue<Attribute>());
  } else if (value.getType().isIntOrIndexOrFloat()) {
    return SplatElementsAttr::get(shapedType, value);
  } else {
    return {};
  }
}

} // namespace triton
} // namespace mlir
```
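The three `LoadOp::build` overloads let lowering code omit `mask` and `other`; whichever operands are present are recorded in `operand_segment_sizes`, and the result type is derived from the pointee type by `getLoadOpResultType`. A hedged usage sketch follows; the helper name and the `NONE`/`NORMAL` enumerators are assumptions, not taken from this diff:

```cpp
#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Hypothetical helper: creates a masked load with no `other` value, so the
// builder overload above records operand_segment_sizes = {1, 1, 0}.
mlir::Value buildMaskedLoad(mlir::OpBuilder &builder, mlir::Location loc,
                            mlir::Value ptr, mlir::Value mask) {
  auto load = builder.create<mlir::triton::LoadOp>(
      loc, ptr, mask, mlir::triton::CacheModifier::NONE,
      mlir::triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);
  return load.getResult();
}
```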
lib/Dialect/Triton/IR/Traits.cpp (new file, 71 lines)

```cpp
#include "triton/Dialect/Triton/IR/Traits.h"

static mlir::LogicalResult verifySameEncoding(mlir::Type tyA, mlir::Type tyB) {
  using namespace mlir;
  auto encA = tyA.dyn_cast<RankedTensorType>();
  auto encB = tyB.dyn_cast<RankedTensorType>();
  if (!encA || !encB)
    return success();
  return encA.getEncoding() == encB.getEncoding() ? success() : failure();
}

mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsAndResultEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)) ||
      failed(verifyAtLeastNResults(op, 1)))
    return failure();

  auto type = op->getOperand(0).getType();
  for (auto resultType : op->getResultTypes())
    if (failed(verifySameEncoding(resultType, type)))
      return op->emitOpError()
             << "requires the same encoding for all operands and results";
  return verifySameOperandsEncoding(op);
}

mlir::LogicalResult
mlir::OpTrait::impl::verifySameOperandsEncoding(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  auto type = op->getOperand(0).getType();
  for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
    if (failed(verifySameEncoding(opType, type)))
      return op->emitOpError() << "requires the same encoding for all operands";

  return success();
}

mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
  for (auto opType : op->getOperandTypes()) {
    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
      int64_t numElements = 1;
      for (int64_t s : tensorType.getShape())
        numElements *= s;
      if (numElements > maxTensorNumElements)
        return op->emitError("Maximum allowed number of elements is ")
               << maxTensorNumElements << ", but " << *op
               << " has more than that";
      if ((numElements & (numElements - 1)) != 0)
        return op->emitError("Number of elements must be power-of-two, but ")
               << *op << " doesn't follow the rule (" << numElements << ")"
               << " elements";
    }
  }
  for (auto opType : op->getResultTypes()) {
    if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
      int64_t numElements = 1;
      for (int64_t s : tensorType.getShape())
        numElements *= s;
      if (numElements > maxTensorNumElements)
        return op->emitError("Maximum allowed number of elements is ")
               << maxTensorNumElements << ", but " << *op
               << " has more than that";
      if ((numElements & (numElements - 1)) != 0)
        return op->emitError("Number of elements must be power-of-two, but ")
               << *op << " doesn't follow the rule (" << numElements << ")"
               << " elements";
    }
  }
  return success();
}
```
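`verifyTensorSize` relies on the classic bit trick: a positive integer is a power of two exactly when it has a single set bit, so `n & (n - 1)` clears the lowest set bit and yields zero only for powers of two. A small self-contained illustration (LLVM's `llvm::isPowerOf2_64` encapsulates the same test):

```cpp
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

void demoPowerOfTwoCheck() {
  int64_t ok = 1024; // 0b100'0000'0000: a single set bit
  int64_t bad = 96;  // 0b110'0000 = 64 + 32: two set bits
  assert((ok & (ok - 1)) == 0);   // accepted by verifyTensorSize
  assert((bad & (bad - 1)) != 0); // rejected: not a power of two
  // The same predicate as provided by LLVM (treats 0 as not a power of two).
  assert(llvm::isPowerOf2_64(ok) && !llvm::isPowerOf2_64(bad));
}
```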
lib/Dialect/Triton/IR/Types.cpp (new file, 39 lines)

```cpp
#include "triton/Dialect/Triton/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`

using namespace mlir;
using namespace mlir::triton;

#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.cpp.inc"

//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
void TritonDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "triton/Dialect/Triton/IR/Types.cpp.inc"
      >();
}

Type PointerType::parse(AsmParser &parser) {
  if (parser.parseLess())
    return Type();

  Type pointeeType;
  if (parser.parseType(pointeeType))
    return Type();

  if (parser.parseGreater())
    return Type();

  // TODO: also parse the address space?
  return PointerType::get(pointeeType, 1);
}

void PointerType::print(AsmPrinter &printer) const {
  printer << "<" << getPointeeType() << ">";
}
```
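Because `parse` reads only `<` pointee-type `>` and `print` emits only the angle-bracketed pointee, pointer types round-trip with the address space silently defaulting to 1. A hedged sketch of constructing and printing such a type; the `!tt.ptr<f32>` spelling assumes the dialect and type mnemonics, which are not shown in this diff, and assumes the Triton dialect is already loaded in `context`:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/Triton/IR/Types.h"

void demoPointerTypeRoundTrip(mlir::MLIRContext &context) {
  auto f32 = mlir::FloatType::getF32(&context);
  // Matches the default used by PointerType::parse above: address space 1.
  auto ptrTy = mlir::triton::PointerType::get(f32, /*addressSpace=*/1);
  ptrTy.dump(); // expected textual form, assuming the mnemonic: !tt.ptr<f32>
}
```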