From 13aead480825f07f5f70120e853431c14029872c Mon Sep 17 00:00:00 2001 From: Yan Da Date: Fri, 8 Apr 2022 16:32:46 +0800 Subject: [PATCH] Use TableGen to define new types --- include/triton/ir/CMakeLists.txt | 2 + include/triton/ir/TritonDialect.td | 2 + include/triton/ir/TritonOps.td | 53 ++++++++++++++++++----- include/triton/ir/Types.h | 40 +---------------- lib/ir/Dialect.cpp | 46 -------------------- lib/ir/Types.cpp | 69 +++++++++++------------------- 6 files changed, 74 insertions(+), 138 deletions(-) diff --git a/include/triton/ir/CMakeLists.txt b/include/triton/ir/CMakeLists.txt index 2036df67a..46573add6 100644 --- a/include/triton/ir/CMakeLists.txt +++ b/include/triton/ir/CMakeLists.txt @@ -5,4 +5,6 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/ir/TritonDialect.td b/include/triton/ir/TritonDialect.td index 2b2a57124..100982365 100644 --- a/include/triton/ir/TritonDialect.td +++ b/include/triton/ir/TritonDialect.td @@ -35,6 +35,8 @@ def Triton_Dialect : Dialect { // "func::FuncDialect" ]; + // let useDefaultTypePrinterParser = 0; + let extraClassDeclaration = [{ void registerTypes(); }]; diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 36fbdd2d1..6ae554928 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -12,13 +12,22 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType // Types // // FloatType -def F8 : Type">, - /*descr*/"8bit float", - /*cppClassName*/"::mlir::triton::Float8Type">; +// def F8 : Type">, +// /*descr*/"8bit float", +// /*cppClassName*/"::mlir::triton::Float8Type">; -def BF8 : Type()">, - /*descr*/"8bit bfloat", - /*cppClassName*/"::mlir::triton::BFloat8Type">; +// def BF8 : Type()">, +// /*descr*/"8bit bfloat", +// /*cppClassName*/"::mlir::triton::BFloat8Type">; + +class TritonTypeDef + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +def F8 : TritonTypeDef<"Float8", "f8">; +def BF8 : TritonTypeDef<"BFloat8", "bf8">; def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : TensorOf<[TT_Float]>; @@ -29,15 +38,35 @@ def TT_IntegerTensor : TensorOf<[TT_Int]>; def TT_I1Tensor : TensorOf<[I1]>; // PointerType -def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; -def TT_AnyPtr : DialectType; -def TT_PtrTensor : TensorOf<[TT_AnyPtr]>; +// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; +// def TT_AnyPtr : DialectType; +def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> { + let summary = "pointer type"; + + let description = [{ + TODO + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let skipDefaultBuilders = 1; +} +def TT_PtrTensor : TensorOf<[TT_Pointer]>; def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>; def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>; def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor, - TT_AnyPtr, TT_PtrTensor]>; + TT_Pointer, TT_PtrTensor]>; def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; // @@ -147,6 +176,8 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape] let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset); let results = (outs TT_Type:$result); + + // let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)"; } @@ -273,7 +304,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> { return $old }]; - let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val); + let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val); let results = (outs TT_Type:$result); } diff --git a/include/triton/ir/Types.h b/include/triton/ir/Types.h index d2e59eb8d..2f94cf573 100644 --- a/include/triton/ir/Types.h +++ b/include/triton/ir/Types.h @@ -4,43 +4,7 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" -namespace mlir { -namespace triton { - -namespace detail { -struct PointerTypeStorage; -} // namespace detail - -// TODO: Should be base class be FloatType? -class Float8Type : public Type::TypeBase { -public: - using Base::Base; - - static Float8Type get(MLIRContext *context); -}; - -class BFloat8Type : public Type::TypeBase { -public: - using Base::Base; - - static BFloat8Type get(MLIRContext *context); -}; - -class PointerType : public Type::TypeBase { -public: - using Base::Base; - - static PointerType get(Type pointeeType); - - static PointerType get(Type pointeeType, unsigned addressSpace); - - Type getPointeeType() const; - - unsigned getAddressSpace() const; -}; - -} // namespace triton -} // namespace mlir +#define GET_TYPEDEF_CLASSES +#include "triton/ir/Types.h.inc" #endif // TRITON_IR_TYPES_H_ diff --git a/lib/ir/Dialect.cpp b/lib/ir/Dialect.cpp index be34f094c..bd5dbea78 100644 --- a/lib/ir/Dialect.cpp +++ b/lib/ir/Dialect.cpp @@ -23,49 +23,3 @@ void TritonDialect::initialize() { // We can also add interface here. } - -//===----------------------------------------------------------------------===// -// Type Parsing -//===----------------------------------------------------------------------===// -// pointer-type ::= `!triton.ptr<` element-type ` >` -static Type parsePointerType(TritonDialect const &dialect, - DialectAsmParser &parser) { - if (parser.parseLess()) - return Type(); - - - Type pointeeType; - if (parser.parseType(pointeeType)) - return Type(); - - if (parser.parseGreater()) - return Type(); - - return PointerType::get(pointeeType); -} - -// trtion-type ::= pointer-type -Type TritonDialect::parseType(DialectAsmParser &parser) const { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - - if (keyword == "ptr") - return parsePointerType(*this, parser); - - parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword; - return Type(); -} - -//===----------------------------------------------------------------------===// -// Type Printing -//===----------------------------------------------------------------------===// -static void print(PointerType type, DialectAsmPrinter &os) { - os << "ptr<" << type.getPointeeType() << ">"; -} - -void TritonDialect::printType(Type type, DialectAsmPrinter &os) const { - TypeSwitch(type) - .Case( [&](auto type) { print(type, os); }) - .Default([](Type) { llvm_unreachable("unhandled Triton type"); }); -} diff --git a/lib/ir/Types.cpp b/lib/ir/Types.cpp index ce5b014a2..ff2db0b1b 100644 --- a/lib/ir/Types.cpp +++ b/lib/ir/Types.cpp @@ -1,55 +1,38 @@ #include "triton/ir/Dialect.h" #include "triton/ir/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` using namespace mlir; using namespace mlir::triton; -// F8 & BF8 -Float8Type Float8Type::get(MLIRContext *context) { - return Base::get(context); -} - -BFloat8Type BFloat8Type::get(MLIRContext *context) { - return Base::get(context); -} - -//===----------------------------------------------------------------------===// -// PointerType -//===----------------------------------------------------------------------===// -struct triton::detail::PointerTypeStorage : public TypeStorage { - using KeyTy = std::pair; - - static PointerTypeStorage *construct(TypeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate()) PointerTypeStorage(key); - } - - bool operator==(const KeyTy &key) const { - return key == KeyTy(pointeeType, addressSpace); - } - - PointerTypeStorage(const KeyTy &key) - : pointeeType(key.first), addressSpace(key.second) {} - - Type pointeeType; - unsigned addressSpace; -}; - -PointerType PointerType::get(Type pointeeType) { - return Base::get(pointeeType.getContext(), pointeeType, 0); -} - -PointerType PointerType::get(Type pointeeType, unsigned addressSpace) { - return Base::get(pointeeType.getContext(), pointeeType, addressSpace); -} - -Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } - -unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; } +#define GET_TYPEDEF_CLASSES +#include "triton/ir/Types.cpp.inc" //===----------------------------------------------------------------------===// // Triton Dialect //===----------------------------------------------------------------------===// void TritonDialect::registerTypes() { - addTypes(); + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/ir/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, 0); +} + +void PointerType::print(AsmPrinter &printer) const { + printer << "<" << getPointeeType() << ">"; }