More on type inference & assembly format
This commit is contained in:
@@ -11,15 +11,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
//
|
||||
// Types
|
||||
//
|
||||
// FloatType
|
||||
// def F8 : Type<CPred<"$_self.isa<::mlir::Float8Type()>">,
|
||||
// /*descr*/"8bit float",
|
||||
// /*cppClassName*/"::mlir::triton::Float8Type">;
|
||||
|
||||
// def BF8 : Type<CPred<"$_self.isa<::mlir::triton::BFloat8Type>()">,
|
||||
// /*descr*/"8bit bfloat",
|
||||
// /*cppClassName*/"::mlir::triton::BFloat8Type">;
|
||||
|
||||
class TritonTypeDef<string name, string _mnemonic>
|
||||
: TypeDef<Triton_Dialect, name> {
|
||||
// Used by printer/parser
|
||||
@@ -35,11 +26,8 @@ def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
// IntegerType
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
def TT_I1Tensor : TensorOf<[I1]>;
|
||||
|
||||
// PointerType
|
||||
// def TT_IsPtrType : CPred<"$_self.isa<::mlir::triton::PointerType>()">;
|
||||
// def TT_AnyPtr : DialectType<Triton_Dialect, TT_IsPtrType, "any Triton pointer type">;
|
||||
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
@@ -141,7 +129,17 @@ def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
TypesMatchWith<"infer mask type from result type",
|
||||
"result", "mask",
|
||||
"getI1SameShape($_self)">,
|
||||
TypesMatchWith<"infer other type from result type",
|
||||
"result", "other",
|
||||
"$_self">]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other,
|
||||
@@ -157,27 +155,41 @@ def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
|
||||
];
|
||||
|
||||
let assemblyFormat = "$ptr`,` $mask`,` $other attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store", [SameOperandsShape]> {
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
"getPointerTypeFromTensor($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"value", "mask",
|
||||
"getI1SameShape($_self)">]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$ptr, TT_Type:$value, BoolLike:$mask);
|
||||
|
||||
let builders = [
|
||||
// for args with default values
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict";
|
||||
let assemblyFormat = "$ptr `,` $value `,` $mask `,` attr-dict `:` type($value)";
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let arguments = (ins TT_Type:$ptr, TT_IntegerTensor:$offset);
|
||||
def TT_GEPOp : TT_Op<"getelementptr",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
|
||||
// let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
||||
let assemblyFormat = "$ptr `,` $offset `,` attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
@@ -190,6 +202,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
@@ -198,6 +212,8 @@ def TT_BroadcastOp : TT_Op<"broadcast", [SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
||||
@@ -206,6 +222,8 @@ def TT_CatOp : TT_Op<"cat", [SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
//
|
||||
@@ -223,9 +241,13 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
@@ -235,6 +257,8 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect]> {
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
}
|
||||
|
||||
// reduction
|
||||
@@ -286,7 +310,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw"> {
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
|
||||
TT_Type:$val, TT_I1Tensor:$mask);
|
||||
TT_Type:$val, I1Tensor:$mask);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
@@ -6,6 +6,37 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
auto i32Type = IntegerType::get(type.getContext(), 32);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type);
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getPointerTypeFromTensor(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
Type elementType = tensorType.getElementType();
|
||||
auto shape = tensorType.getShape();
|
||||
PointerType ptrType = PointerType::get(elementType, 1);
|
||||
return RankedTensorType::get(shape, ptrType);
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/ir/Ops.cpp.inc"
|
||||
|
||||
|
Reference in New Issue
Block a user