Use TableGen to define new types
This commit is contained in:
@@ -5,4 +5,6 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
|||||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||||
|
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||||
|
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||||
add_public_tablegen_target(TritonTableGen)
|
add_public_tablegen_target(TritonTableGen)
|
||||||
|
@@ -35,6 +35,8 @@ def Triton_Dialect : Dialect {
|
|||||||
// "func::FuncDialect"
|
// "func::FuncDialect"
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// let useDefaultTypePrinterParser = 0;
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
void registerTypes();
|
void registerTypes();
|
||||||
}];
|
}];
|
||||||
|
@@ -12,13 +12,22 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
|||||||
// Types
|
// Types
|
||||||
//
|
//
|
||||||
// FloatType
|
// FloatType
|
||||||
def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
// def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
||||||
/*descr*/"8bit float",
|
// /*descr*/"8bit float",
|
||||||
/*cppClassName*/"::mlir::triton::Float8Type">;
|
// /*cppClassName*/"::mlir::triton::Float8Type">;
|
||||||
|
|
||||||
def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
// def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
||||||
/*descr*/"8bit bfloat",
|
// /*descr*/"8bit bfloat",
|
||||||
/*cppClassName*/"::mlir::triton::BFloat8Type">;
|
// /*cppClassName*/"::mlir::triton::BFloat8Type">;
|
||||||
|
|
||||||
|
class TritonTypeDef<string name, string _mnemonic>
|
||||||
|
: TypeDef<Triton_Dialect, name> {
|
||||||
|
// Used by printer/parser
|
||||||
|
let mnemonic = _mnemonic;
|
||||||
|
}
|
||||||
|
|
||||||
|
def F8 : TritonTypeDef<"Float8", "f8">;
|
||||||
|
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
|
||||||
|
|
||||||
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||||
@@ -29,15 +38,35 @@ def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
|||||||
def TT_I1Tensor : TensorOf<[I1]>;
|
def TT_I1Tensor : TensorOf<[I1]>;
|
||||||
|
|
||||||
// PointerType
|
// PointerType
|
||||||
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
||||||
def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
// def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
||||||
def TT_PtrTensor : TensorOf<[TT_AnyPtr]>;
|
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||||
|
let summary = "pointer type";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
TODO
|
||||||
|
}];
|
||||||
|
|
||||||
|
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
TypeBuilderWithInferredContext<(ins
|
||||||
|
"Type":$pointeeType,
|
||||||
|
"int":$addressSpace
|
||||||
|
), [{
|
||||||
|
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
}
|
||||||
|
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
|
||||||
|
|
||||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
||||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||||
|
|
||||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||||
TT_AnyPtr, TT_PtrTensor]>;
|
TT_Pointer, TT_PtrTensor]>;
|
||||||
|
|
||||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||||
//
|
//
|
||||||
@@ -147,6 +176,8 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
|
|||||||
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
|
|
||||||
|
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -273,7 +304,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
|||||||
return $old
|
return $old
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||||
|
|
||||||
let results = (outs TT_Type:$result);
|
let results = (outs TT_Type:$result);
|
||||||
}
|
}
|
||||||
|
@@ -4,43 +4,7 @@
|
|||||||
#include "mlir/IR/TypeSupport.h"
|
#include "mlir/IR/TypeSupport.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
namespace mlir {
|
#define GET_TYPEDEF_CLASSES
|
||||||
namespace triton {
|
#include "triton/ir/Types.h.inc"
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
struct PointerTypeStorage;
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
// TODO: Should be base class be FloatType?
|
|
||||||
class Float8Type : public Type::TypeBase<Float8Type, Type, TypeStorage> {
|
|
||||||
public:
|
|
||||||
using Base::Base;
|
|
||||||
|
|
||||||
static Float8Type get(MLIRContext *context);
|
|
||||||
};
|
|
||||||
|
|
||||||
class BFloat8Type : public Type::TypeBase<BFloat8Type, Type, TypeStorage> {
|
|
||||||
public:
|
|
||||||
using Base::Base;
|
|
||||||
|
|
||||||
static BFloat8Type get(MLIRContext *context);
|
|
||||||
};
|
|
||||||
|
|
||||||
class PointerType : public Type::TypeBase<PointerType, Type,
|
|
||||||
detail::PointerTypeStorage> {
|
|
||||||
public:
|
|
||||||
using Base::Base;
|
|
||||||
|
|
||||||
static PointerType get(Type pointeeType);
|
|
||||||
|
|
||||||
static PointerType get(Type pointeeType, unsigned addressSpace);
|
|
||||||
|
|
||||||
Type getPointeeType() const;
|
|
||||||
|
|
||||||
unsigned getAddressSpace() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace triton
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // TRITON_IR_TYPES_H_
|
#endif // TRITON_IR_TYPES_H_
|
||||||
|
@@ -23,49 +23,3 @@ void TritonDialect::initialize() {
|
|||||||
|
|
||||||
// We can also add interface here.
|
// We can also add interface here.
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Type Parsing
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// pointer-type ::= `!triton.ptr<` element-type ` >`
|
|
||||||
static Type parsePointerType(TritonDialect const &dialect,
|
|
||||||
DialectAsmParser &parser) {
|
|
||||||
if (parser.parseLess())
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
|
|
||||||
Type pointeeType;
|
|
||||||
if (parser.parseType(pointeeType))
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
if (parser.parseGreater())
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
return PointerType::get(pointeeType);
|
|
||||||
}
|
|
||||||
|
|
||||||
// trtion-type ::= pointer-type
|
|
||||||
Type TritonDialect::parseType(DialectAsmParser &parser) const {
|
|
||||||
StringRef keyword;
|
|
||||||
if (parser.parseKeyword(&keyword))
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
if (keyword == "ptr")
|
|
||||||
return parsePointerType(*this, parser);
|
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc(), "unknown Triton type: ") << keyword;
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Type Printing
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
static void print(PointerType type, DialectAsmPrinter &os) {
|
|
||||||
os << "ptr<" << type.getPointeeType() << ">";
|
|
||||||
}
|
|
||||||
|
|
||||||
void TritonDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|
||||||
TypeSwitch<Type>(type)
|
|
||||||
.Case<PointerType>( [&](auto type) { print(type, os); })
|
|
||||||
.Default([](Type) { llvm_unreachable("unhandled Triton type"); });
|
|
||||||
}
|
|
||||||
|
@@ -1,55 +1,38 @@
|
|||||||
#include "triton/ir/Dialect.h"
|
#include "triton/ir/Dialect.h"
|
||||||
#include "triton/ir/Types.h"
|
#include "triton/ir/Types.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
|
||||||
|
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
|
|
||||||
// F8 & BF8
|
#define GET_TYPEDEF_CLASSES
|
||||||
Float8Type Float8Type::get(MLIRContext *context) {
|
#include "triton/ir/Types.cpp.inc"
|
||||||
return Base::get(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
BFloat8Type BFloat8Type::get(MLIRContext *context) {
|
|
||||||
return Base::get(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// PointerType
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
struct triton::detail::PointerTypeStorage : public TypeStorage {
|
|
||||||
using KeyTy = std::pair<Type, unsigned>;
|
|
||||||
|
|
||||||
static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
|
|
||||||
const KeyTy &key) {
|
|
||||||
return new (allocator.allocate<PointerTypeStorage>()) PointerTypeStorage(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator==(const KeyTy &key) const {
|
|
||||||
return key == KeyTy(pointeeType, addressSpace);
|
|
||||||
}
|
|
||||||
|
|
||||||
PointerTypeStorage(const KeyTy &key)
|
|
||||||
: pointeeType(key.first), addressSpace(key.second) {}
|
|
||||||
|
|
||||||
Type pointeeType;
|
|
||||||
unsigned addressSpace;
|
|
||||||
};
|
|
||||||
|
|
||||||
PointerType PointerType::get(Type pointeeType) {
|
|
||||||
return Base::get(pointeeType.getContext(), pointeeType, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
PointerType PointerType::get(Type pointeeType, unsigned addressSpace) {
|
|
||||||
return Base::get(pointeeType.getContext(), pointeeType, addressSpace);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
|
|
||||||
|
|
||||||
unsigned PointerType::getAddressSpace() const { return getImpl()->addressSpace; }
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Triton Dialect
|
// Triton Dialect
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
void TritonDialect::registerTypes() {
|
void TritonDialect::registerTypes() {
|
||||||
addTypes<Float8Type, BFloat8Type, PointerType>();
|
addTypes<
|
||||||
|
#define GET_TYPEDEF_LIST
|
||||||
|
#include "triton/ir/Types.cpp.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type PointerType::parse(AsmParser &parser) {
|
||||||
|
if (parser.parseLess())
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
Type pointeeType;
|
||||||
|
if (parser.parseType(pointeeType))
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
if (parser.parseGreater())
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
return PointerType::get(pointeeType, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PointerType::print(AsmPrinter &printer) const {
|
||||||
|
printer << "<" << getPointeeType() << ">";
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user