Use TableGen to define new types

This commit is contained in:
Yan Da
2022-04-08 16:32:46 +08:00
parent 6002340456
commit 13aead4808
6 changed files with 74 additions and 138 deletions

View File

@@ -23,49 +23,3 @@ void TritonDialect::initialize() {
// We can also add interface here.
}
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
// pointer-type ::= `!triton.ptr<` element-type ` >`
static Type parsePointerType(TritonDialect const &dialect,
DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
return PointerType::get(pointeeType);
}
// trtion-type ::= pointer-type
Type TritonDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "ptr")
return parsePointerType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword;
return Type();
}
//===----------------------------------------------------------------------===//
// Type Printing
//===----------------------------------------------------------------------===//
static void print(PointerType type, DialectAsmPrinter &os) {
os << "ptr<" << type.getPointeeType() << ">";
}
void TritonDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<PointerType>( [&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled Triton type"); });
}

View File

@@ -1,55 +1,38 @@
#include "triton/ir/Dialect.h"
#include "triton/ir/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
using namespace mlir;
using namespace mlir::triton;
// F8 & BF8
Float8Type Float8Type::get(MLIRContext *context) {
return Base::get(context);
}
BFloat8Type BFloat8Type::get(MLIRContext *context) {
return Base::get(context);
}
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
struct triton::detail::PointerTypeStorage : public TypeStorage {
using KeyTy = std::pair<Type, unsigned>;
static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<PointerTypeStorage>()) PointerTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
return key == KeyTy(pointeeType, addressSpace);
}
PointerTypeStorage(const KeyTy &key)
: pointeeType(key.first), addressSpace(key.second) {}
Type pointeeType;
unsigned addressSpace;
};
PointerType PointerType::get(Type pointeeType) {
return Base::get(pointeeType.getContext(), pointeeType, 0);
}
PointerType PointerType::get(Type pointeeType, unsigned addressSpace) {
return Base::get(pointeeType.getContext(), pointeeType, addressSpace);
}
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; }
#define GET_TYPEDEF_CLASSES
#include "triton/ir/Types.cpp.inc"
//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
void TritonDialect::registerTypes() {
addTypes<Float8Type, BFloat8Type, PointerType>();
addTypes<
#define GET_TYPEDEF_LIST
#include "triton/ir/Types.cpp.inc"
>();
}
Type PointerType::parse(AsmParser &parser) {
if (parser.parseLess())
return Type();
Type pointeeType;
if (parser.parseType(pointeeType))
return Type();
if (parser.parseGreater())
return Type();
return PointerType::get(pointeeType, 0);
}
void PointerType::print(AsmPrinter &printer) const {
printer << "<" << getPointeeType() << ">";
}