Use TableGen to define new types

This commit is contained in:
Yan Da
2022-04-08 16:32:46 +08:00
parent 6002340456
commit 13aead4808
6 changed files with 74 additions and 138 deletions

View File

@@ -5,4 +5,6 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(TritonTableGen) add_public_tablegen_target(TritonTableGen)

View File

@@ -35,6 +35,8 @@ def Triton_Dialect : Dialect {
// "func::FuncDialect" // "func::FuncDialect"
]; ];
// let useDefaultTypePrinterParser = 0;
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void registerTypes(); void registerTypes();
}]; }];

View File

@@ -12,13 +12,22 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
// Types // Types
// //
// FloatType // FloatType
def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">, // def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
/*descr*/"8bit float", // /*descr*/"8bit float",
/*cppClassName*/"::mlir::triton::Float8Type">; // /*cppClassName*/"::mlir::triton::Float8Type">;
def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">, // def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
/*descr*/"8bit bfloat", // /*descr*/"8bit bfloat",
/*cppClassName*/"::mlir::triton::BFloat8Type">; // /*cppClassName*/"::mlir::triton::BFloat8Type">;
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">; def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>; def TT_FloatTensor : TensorOf<[TT_Float]>;
@@ -29,15 +38,35 @@ def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_I1Tensor : TensorOf<[I1]>; def TT_I1Tensor : TensorOf<[I1]>;
// PointerType // PointerType
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">; // def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">; // def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
def TT_PtrTensor : TensorOf<[TT_AnyPtr]>; def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
TODO
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>; def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>; def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor, def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_AnyPtr, TT_PtrTensor]>; TT_Pointer, TT_PtrTensor]>;
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
// //
@@ -147,6 +176,8 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset); let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
let results = (outs TT_Type:$result); let results = (outs TT_Type:$result);
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
} }
@@ -273,7 +304,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
return $old return $old
}]; }];
let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val); let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result); let results = (outs TT_Type:$result);
} }

View File

@@ -4,43 +4,7 @@
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
namespace mlir { #define GET_TYPEDEF_CLASSES
namespace triton { #include "triton/ir/Types.h.inc"
namespace detail {
struct PointerTypeStorage;
} // namespace detail
// TODO: Should be base class be FloatType?
class Float8Type : public Type::TypeBase<Float8Type, Type, TypeStorage> {
public:
using Base::Base;
static Float8Type get(MLIRContext *context);
};
class BFloat8Type : public Type::TypeBase<BFloat8Type, Type, TypeStorage> {
public:
using Base::Base;
static BFloat8Type get(MLIRContext *context);
};
class PointerType : public Type::TypeBase<PointerType, Type,
detail::PointerTypeStorage> {
public:
using Base::Base;
static PointerType get(Type pointeeType);
static PointerType get(Type pointeeType, unsigned addressSpace);
Type getPointeeType() const;
unsigned getAddressSpace() const;
};
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_TYPES_H_ #endif // TRITON_IR_TYPES_H_

View File

@@ -23,49 +23,3 @@ void TritonDialect::initialize() {
// We can also add interface here. // We can also add interface here.
} }
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
// pointer-type ::= `!triton.ptr<` element-type ` >`
static Type parsePointerType(TritonDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
return PointerType::get(pointeeType);
}
// trtion-type ::= pointer-type
Type TritonDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "ptr")
return parsePointerType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword;
return Type();
}
//===----------------------------------------------------------------------===//
// Type Printing
//===----------------------------------------------------------------------===//
static void print(PointerType type, DialectAsmPrinter &os) {
os << "ptr<" << type.getPointeeType() << ">";
}
void TritonDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<PointerType>( [&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled Triton type"); });
}

View File

@@ -1,55 +1,38 @@
#include "triton/ir/Dialect.h" #include "triton/ir/Dialect.h"
#include "triton/ir/Types.h" #include "triton/ir/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir; using namespace mlir;
using namespace mlir::triton; using namespace mlir::triton;
// F8 & BF8 #define GET_TYPEDEF_CLASSES
Float8Type Float8Type::get(MLIRContext *context) { #include "triton/ir/Types.cpp.inc"
return Base::get(context);
}
BFloat8Type BFloat8Type::get(MLIRContext *context) {
return Base::get(context);
}
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
struct triton::detail::PointerTypeStorage : public TypeStorage {
using KeyTy = std::pair<Type, unsigned>;
static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<PointerTypeStorage>()) PointerTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(pointeeType, addressSpace);
}
PointerTypeStorage(const KeyTy &key)
: pointeeType(key.first), addressSpace(key.second) {}
Type pointeeType;
unsigned addressSpace;
};
PointerType PointerType::get(Type pointeeType) {
return Base::get(pointeeType.getContext(), pointeeType, 0);
}
PointerType PointerType::get(Type pointeeType, unsigned addressSpace) {
return Base::get(pointeeType.getContext(), pointeeType, addressSpace);
}
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Triton Dialect // Triton Dialect
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void TritonDialect::registerTypes() { void TritonDialect::registerTypes() {
addTypes<Float8Type, BFloat8Type, PointerType>(); addTypes<
#define GET_TYPEDEF_LIST
#include "triton/ir/Types.cpp.inc"
>();
}
Type PointerType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
return PointerType::get(pointeeType, 0);
}
void PointerType::print(AsmPrinter &printer) const {
printer << "<" << getPointeeType() << ">";
} }