Use TableGen to define new types

This commit is contained in:
Yan Da
2022-04-08 16:32:46 +08:00
parent 6002340456
commit 13aead4808
6 changed files with 74 additions and 138 deletions

View File

@@ -5,4 +5,6 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(TritonTableGen)

View File

@@ -35,6 +35,8 @@ def Triton_Dialect : Dialect {
// "func::FuncDialect"
];
// let useDefaultTypePrinterParser = 0;
let extraClassDeclaration = [{
void registerTypes();
}];

View File

@@ -12,13 +12,22 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
// Types
//
// FloatType
def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
/*descr*/"8bit float",
/*cppClassName*/"::mlir::triton::Float8Type">;
// def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
// /*descr*/"8bit float",
// /*cppClassName*/"::mlir::triton::Float8Type">;
def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
/*descr*/"8bit bfloat",
/*cppClassName*/"::mlir::triton::BFloat8Type">;
// def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
// /*descr*/"8bit bfloat",
// /*cppClassName*/"::mlir::triton::BFloat8Type">;
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
@@ -29,15 +38,35 @@ def TT_IntegerTensor : TensorOf<[TT_Int]>;
def TT_I1Tensor : TensorOf<[I1]>;
// PointerType
def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
def TT_PtrTensor : TensorOf<[TT_AnyPtr]>;
// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
// def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
TODO
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
TT_AnyPtr, TT_PtrTensor]>;
TT_Pointer, TT_PtrTensor]>;
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
//
@@ -147,6 +176,8 @@ def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
let results = (outs TT_Type:$result);
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
}
@@ -273,7 +304,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
return $old
}];
let arguments = (ins TT_AnyPtr:$ptr, TT_Type:$cmp, TT_Type:$val);
let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}

View File

@@ -4,43 +4,7 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace triton {
namespace detail {
struct PointerTypeStorage;
} // namespace detail
// TODO: Should be base class be FloatType?
class Float8Type : public Type::TypeBase<Float8Type, Type, TypeStorage> {
public:
using Base::Base;
static Float8Type get(MLIRContext *context);
};
class BFloat8Type : public Type::TypeBase<BFloat8Type, Type, TypeStorage> {
public:
using Base::Base;
static BFloat8Type get(MLIRContext *context);
};
class PointerType : public Type::TypeBase<PointerType, Type,
detail::PointerTypeStorage> {
public:
using Base::Base;
static PointerType get(Type pointeeType);
static PointerType get(Type pointeeType, unsigned addressSpace);
Type getPointeeType() const;
unsigned getAddressSpace() const;
};
} // namespace triton
} // namespace mlir
#define GET_TYPEDEF_CLASSES
#include "triton/ir/Types.h.inc"
#endif // TRITON_IR_TYPES_H_